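I'd like to know how to refactor some Scala code to make it more elegant and idiomatic.
I have a function
def joinDataFramesOnColumns(joinColumns: Seq[String]) : org.apache.spark.sql.DataFrame
that operates on a Seq[org.apache.spark.sql.DataFrame] and joins its elements together on joinColumns. The function is defined as follows: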
implicit class SequenceOfDataFrames(dataFrames: Seq[DataFrame]){
  def joinDataFramesOnColumns(joinColumns: Seq[String]) : DataFrame = {
    val emptyDataFrame = SparkSession.builder().getOrCreate().emptyDataFrame
    val nonEmptyDataFrames = dataFrames.filter(_ != emptyDataFrame)
    if (nonEmptyDataFrames.isEmpty){
      emptyDataFrame
    }
    else {
      if (joinColumns.isEmpty) {
        return nonEmptyDataFrames.reduce(_.crossJoin(_))
      }
      nonEmptyDataFrames.reduce(_.join(_, joinColumns))
    }
  }
}
I have some unit tests, all of which pass:
class FeatureGeneratorDataFrameExtensionsTest extends WordSpec {
  val fruitValues = Seq(
    Row(0, "BasketA", "Bananas", "Jack"),
    Row(2, "BasketB", "Oranges", "Jack"),
    Row(2, "BasketC", "Oranges", "Jill"),
    Row(3, "BasketD", "Oranges", "Jack"),
    Row(4, "BasketE", "Oranges", "Jack"),
    Row(4, "BasketE", "Apples", "Jack"),
    Row(4, "BasketF", "Bananas", "Jill")
  )
  val schema = List(
    StructField("weeksPrior", IntegerType, true),
    StructField("basket", StringType, true),
    StructField("Product", StringType, true),
    StructField("Customer", StringType, true)
  )
  val fruitDf = spark.createDataFrame(
    spark.sparkContext.parallelize(fruitValues),
    StructType(schema)
  ).withColumn("Date", udfDateSubWeeks(lit(dayPriorToAsAt), col("weeksPrior")))

  "FeatureGenerator.SequenceOfDataFrames" should {
    "join multiple dataframes on a specified set of columns" in {
      val sequenceOfDataFrames = Seq[DataFrame](
        fruitDf.withColumnRenamed("weeksPrior", "weeksPrior1"),
        fruitDf.withColumnRenamed("weeksPrior", "weeksPrior2"),
        fruitDf.withColumnRenamed("weeksPrior", "weeksPrior3"),
        fruitDf.withColumnRenamed("weeksPrior", "weeksPrior4"),
        fruitDf.withColumnRenamed("weeksPrior", "weeksPrior5")
      )
      val joinedDataFrames = sequenceOfDataFrames.joinDataFramesOnColumns(Seq("basket", "Product", "Customer", "Date"))
      assert(joinedDataFrames.columns.length === 9)
      assert(joinedDataFrames.columns.contains("basket"))
      assert(joinedDataFrames.columns.contains("Product"))
      assert(joinedDataFrames.columns.contains("Customer"))
      assert(joinedDataFrames.columns.contains("Date"))
      assert(joinedDataFrames.columns.contains("weeksPrior1"))
      assert(joinedDataFrames.columns.contains("weeksPrior2"))
      assert(joinedDataFrames.columns.contains("weeksPrior3"))
      assert(joinedDataFrames.columns.contains("weeksPrior4"))
      assert(joinedDataFrames.columns.contains("weeksPrior5"))
    }
    "when passed a list of one dataframe return that same dataframe" in {
      val sequenceOfDataFrames = Seq[DataFrame](fruitDf)
      val joinedDataFrame = sequenceOfDataFrames.joinDataFramesOnColumns(Seq("basket", "Product"))
      assert(joinedDataFrame.columns.sorted === fruitDf.columns.sorted)
      assert(joinedDataFrame.count === fruitDf.count)
    }
    "when passed an empty list of dataframes return an empty dataframe" in {
      val joinedDataFrame = Seq[DataFrame]().joinDataFramesOnColumns(Seq("basket"))
      assert(joinedDataFrame === spark.emptyDataFrame)
    }
    "when passed an empty list of joinColumns return the dataframes crossjoined" in {
      val sequenceOfDataFrames = Seq[DataFrame](fruitDf,fruitDf, fruitDf)
      val joinedDataFrame = sequenceOfDataFrames.joinDataFramesOnColumns(Seq[String]())
      assert(joinedDataFrame.count === scala.math.pow(fruitDf.count, sequenceOfDataFrames.size))
      assert(joinedDataFrame.columns.size === fruitDf.columns.size * sequenceOfDataFrames.size)
    }
  }
}
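This all worked fine until it started failing because of this Spark bug: https://issues.apache.org/jira/browse/SPARK-25150, which can cause errors in certain circumstances when the join columns have the same name.
The workaround is to alias the columns to something else, so I rewrote the function as follows. It aliases the join columns, performs the join, then renames them back: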
implicit class SequenceOfDataFrames(dataFrames: Seq[DataFrame]){
  def joinDataFramesOnColumns(joinColumns: Seq[String]) : DataFrame = {
    val emptyDataFrame = SparkSession.builder().getOrCreate().emptyDataFrame
    val nonEmptyDataFrames = dataFrames.filter(_ != emptyDataFrame)
    if (nonEmptyDataFrames.isEmpty){
      emptyDataFrame
    }
    else {
      if (joinColumns.isEmpty) {
        return nonEmptyDataFrames.reduce(_.crossJoin(_))
      }
      /*
      The horrible, gnarly, inelegant code below would ideally exist simply as:
          nonEmptyDataFrames.reduce(_.join(_, joinColumns))
      however that will fail in certain specific circumstances due to a bug in Spark,
      see https://issues.apache.org/jira/browse/SPARK-25150 for details
      */
      val aliasSuffix = "_aliased"
      val aliasedJoinColumns = joinColumns.map(joinColumn => joinColumn+aliasSuffix)
      var aliasedNonEmptyDataFrames: Seq[DataFrame] = Seq()
      nonEmptyDataFrames.foreach(
        nonEmptyDataFrame =>{
          var tempNonEmptyDataFrame = nonEmptyDataFrame
          joinColumns.foreach(
            joinColumn => {
              tempNonEmptyDataFrame = tempNonEmptyDataFrame.withColumnRenamed(joinColumn, joinColumn+aliasSuffix)
            }
          )
          aliasedNonEmptyDataFrames = aliasedNonEmptyDataFrames :+ tempNonEmptyDataFrame
        }
      )
      var joinedAliasedNonEmptyDataFrames = aliasedNonEmptyDataFrames.reduce(_.join(_, aliasedJoinColumns))
      joinColumns.foreach(
        joinColumn => joinedAliasedNonEmptyDataFrames = joinedAliasedNonEmptyDataFrames.withColumnRenamed(
          joinColumn+aliasSuffix, joinColumn
        )
      )
      joinedAliasedNonEmptyDataFrames
    }
  }
}
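The tests still pass, so I'm fairly happy with it, but I'm looking at those vars and the loops that assign a result back to a var on every iteration, and I find them rather inelegant and rather ugly, especially compared to the original version of the function. I feel there must be a way to write this so that I don't have to use vars, but after much trial and error this is the best I could do.
Can anyone suggest a more elegant solution? As a novice Scala developer, it would help me become more familiar with the idiomatic way of solving problems like this.
Any constructive comments on the rest of the code (e.g. the tests) would also be welcome.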
1 Answer
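Thanks to @duelist, whose suggestion to use foldLeft() led me to the question "How does foldLeft in Scala work on DataFrame?", which in turn led me to amend my code like so to eliminate the vars: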
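A minimal sketch of what a foldLeft-based version could look like is given here; it is an illustration of the technique, not necessarily the exact code from this answer. The wrapping object name DataFrameJoinSyntax is hypothetical and is only there because an implicit class has to live inside an object, class, or trait. The idea is that the DataFrame itself becomes the accumulator: each join column is folded into it through withColumnRenamed, so no var is reassigned.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}

// Hypothetical wrapper object; an implicit class cannot be top-level.
object DataFrameJoinSyntax {

  implicit class SequenceOfDataFrames(dataFrames: Seq[DataFrame]) {

    private val aliasSuffix = "_aliased"

    def joinDataFramesOnColumns(joinColumns: Seq[String]): DataFrame = {
      val emptyDataFrame = SparkSession.builder().getOrCreate().emptyDataFrame
      // Same filter as in the question: drops references to the session's emptyDataFrame.
      val nonEmptyDataFrames = dataFrames.filter(_ != emptyDataFrame)

      if (nonEmptyDataFrames.isEmpty) emptyDataFrame
      else if (joinColumns.isEmpty) nonEmptyDataFrames.reduce(_.crossJoin(_))
      else {
        // Workaround for SPARK-25150: join on aliased column names, then rename back.
        val aliasedJoinColumns = joinColumns.map(_ + aliasSuffix)

        // For each DataFrame, fold over the join columns, renaming one column per step.
        val aliasedNonEmptyDataFrames = nonEmptyDataFrames.map { dataFrame =>
          joinColumns.foldLeft(dataFrame) { (acc, joinColumn) =>
            acc.withColumnRenamed(joinColumn, joinColumn + aliasSuffix)
          }
        }

        val joinedAliasedNonEmptyDataFrames =
          aliasedNonEmptyDataFrames.reduce(_.join(_, aliasedJoinColumns))

        // Fold again to restore the original join column names on the joined result.
        joinColumns.foldLeft(joinedAliasedNonEmptyDataFrames) { (acc, joinColumn) =>
          acc.withColumnRenamed(joinColumn + aliasSuffix, joinColumn)
        }
      }
    }
  }
}
```

With import DataFrameJoinSyntax._ in scope, the call site stays the same as in the question, for example sequenceOfDataFrames.joinDataFramesOnColumns(Seq("basket", "Product")). The early return from the question is also replaced by an if / else if / else expression, which is the more usual style in Scala.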
Note: I could have combined the last two statements into one, thereby eliminating val joinedAliasedNonEmptyDataFrames, but I prefer the clarity that the interim val brings.