Spark: aggregate values by a given list of years

h9vpoimq · published 2021-07-12 in Spark

I'm new to Scala. Suppose I have this dataset:

>>> ds.show()
+----+---------------+-----------+
|year|nb_product_sold|system_year|
+----+---------------+-----------+
|2010|              1|       2012|
|2012|              2|       2012|
|2012|              4|       2012|
|2015|              3|       2012|
|2019|              4|       2012|
|2021|              5|       2012|
+----+---------------+-----------+

I have a List<Integer> years = {1, 3, 8}, where each element x means "x years after system_year". The goal is to compute, for each x, the total number of products sold from system_year up to system_year + x.
In other words, I have to compute the total products sold up to 2013, 2015 and 2020 respectively.
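
To make the mapping concrete, a tiny sketch (with systemYear hard-coded from the sample data above):

val systemYear = 2012                          // taken from the sample data
val years = List(1, 3, 8)
val targetYears = years.map(systemYear + _)    // List(2013, 2015, 2020)
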
The output dataset should look like this:

+----+------------------+
|year|total_product_sold|
+----+------------------+
|   1|                 6|  -> 2012-2013: 2+4 = 6 products sold
|   3|                 9|  -> 2012-2015: 2+4+3 = 9 products sold
|   8|                13|  -> 2012-2020: 2+4+3+4 = 13 products sold
+----+------------------+

I'd like to know how to do this in Scala. Should I use groupBy() in this case?

gcxthw6b1#

If the year ranges did not overlap, you could use a single groupBy with case/when (a sketch of that variant appears after the output below). Here, however, the ranges overlap, so you need one groupBy per year offset and then a union of the 3 grouped DataFrames:

import org.apache.spark.sql.functions._
import spark.implicits._

val years = List(1, 3, 8)

// df is the sample dataset from the question (called ds there).
// One aggregation per offset, then union the per-offset results.
val result = years.map { y =>
    df.filter($"year".between($"system_year", $"system_year" + y))
      .groupBy(lit(y).as("year"))   // constant key: a single group per offset
      .agg(sum($"nb_product_sold").as("total_product_sold"))
  }.reduce(_ union _)

result.show
//+----+------------------+
//|year|total_product_sold|
//+----+------------------+
//|   1|                 6|
//|   3|                 9|
//|   8|                13|
//+----+------------------+
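
For reference, a minimal sketch of the case/when variant mentioned above, assuming hypothetical disjoint buckets (boundaries below are made up). Note it yields per-bucket, non-cumulative totals, which is exactly why it does not fit the overlapping ranges in this question:

import org.apache.spark.sql.functions._

// Hypothetical disjoint boundaries; each row falls into at most one bucket
val bucketed = df.withColumn("bucket",
  when($"year".between($"system_year",     $"system_year" + 1), 1)
   .when($"year".between($"system_year" + 2, $"system_year" + 3), 3)
   .when($"year".between($"system_year" + 4, $"system_year" + 8), 8))

bucketed.filter($"bucket".isNotNull)
  .groupBy($"bucket".as("year"))
  .agg(sum($"nb_product_sold").as("total_product_sold"))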

rkkpypqq2#

There may be more efficient ways of doing this than what I'm showing you, but it works for your use case.

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import spark.implicits._

// Sample data
val df = Seq((2010,1,2012),(2012,2,2012),(2012,4,2012),(2015,3,2012),(2019,4,2012),(2021,5,2012))
  .toDF("year","nb_product_sold","system_year")

// Difference of each sale year from the system year
val df1 = df.withColumn("Difference", $"year" - $"system_year")

// Total sold per year: all rows in a partition share the same order key,
// so the windowed sum is the per-year total; then keep one row per year
val w = Window.partitionBy("year").orderBy("year")
val df2 = df1.withColumn("runningsum", sum("nb_product_sold").over(w))
  .withColumn("yearlist", lit(0))
  .dropDuplicates("year", "system_year", "Difference")

// List of year offsets
val years = List(1, 3, 8)

// For each offset, take the rows whose Difference lies between the previous
// offset (exclusive) and the current one (inclusive), then union them; the
// running lower bound avoids dropping years that fall between two
// non-consecutive offsets (e.g. Difference = 5)
var df3 = spark.createDataFrame(sc.emptyRDD[Row], df2.schema)
var prev = -1
for (year <- years) {
  val innerdf = df2.filter($"Difference" > prev && $"Difference" <= year)
    .withColumn("yearlist", lit(year))
  df3 = df3.union(innerdf)
  prev = year
}

// Partition by system year again and take the cumulative sum over the offsets
val w1 = Window.partitionBy("system_year").orderBy("year")
val finaldf = df3.withColumn("total_product_sold", sum("runningsum").over(w1))
  .select("yearlist", "total_product_sold")

The output looks like this:

finaldf.show
//+--------+------------------+
//|yearlist|total_product_sold|
//+--------+------------------+
//|       1|                 6|
//|       3|                 9|
//|       8|                13|
//+--------+------------------+
