Spark DataFrame groupBy: collect_list and fill missing values with 0

Posted by 8ljdwjyq on 2021-07-09 in Spark

I have a DataFrame like this:

+------+--------+--------+
|    id|category|quantity|
+------+--------+--------+
|merch1|   fruit|    20.0|
|merch1| veggies|   300.0|
|merch1|   diary|    10.0|
|merch1|organics|    12.0|
|merch1|  frozen|    11.0|
|merch2|   fruit|     6.0|
|merch2|   diary|     6.0|
|merch2|  frozen|     8.0|
|merch3| veggies|    13.0|
|merch3|organics|     4.0|
|merch3|  frozen|    10.0|
|merch4|   fruit|    28.0|
|merch4|organics|    11.0|
+------+--------+--------+

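I want to group by id and produce a vector, as an ordered collect_list, in which any category that is missing for an id is filled with 0.0. The distinct categories are: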

+--------+                                                                      
|category|
+--------+
|  frozen|
|   diary|
| veggies|
|organics|
|   fruit|
+--------+

For example, merch1 has every category, so its vector is [11.0, 10.0, 300.0, 12.0, 20.0]; for merch2 it is [8.0, 6.0, 0.0, 0.0, 6.0]; for merch3 it is [10.0, 0.0, 13.0, 4.0, 0.0]; and for merch4 it is [0.0, 0.0, 0.0, 28.0, 11.0]. So the final DataFrame I am looking for is:

+------+-------------------------------+
|    id|vector                         |
+------+-------------------------------+
|merch1|[11.0, 10.0, 300.0, 12.0, 20.0]|
|merch2|[8.0, 6.0, 0.0, 0.0, 6.0]      |
|merch3|[10.0, 0.0, 13.0, 4.0, 0.0]    |
|merch4|[0.0, 0.0, 0.0, 28.0, 11.0]    |
+------+-------------------------------+
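
To pin down the rule being asked for, here is a minimal plain-Python sketch (no Spark involved). The function name `expected_vectors` and the two-merchant `rows` sample are illustrative assumptions, not part of the original post; the category order is the one listed above.

```python
from collections import defaultdict

# Fixed category order, taken from the distinct-category table above.
CATEGORIES = ["frozen", "diary", "veggies", "organics", "fruit"]


def expected_vectors(rows, categories=CATEGORIES):
    """Group (id, category, quantity) rows by id; emit one quantity per category, 0.0 if absent."""
    by_id = defaultdict(dict)
    for id_, category, quantity in rows:
        by_id[id_][category] = quantity
    return {id_: [q.get(c, 0.0) for c in categories] for id_, q in by_id.items()}


# Hypothetical sample: the merch2 and merch3 rows from the input DataFrame.
rows = [
    ("merch2", "fruit", 6.0), ("merch2", "diary", 6.0), ("merch2", "frozen", 8.0),
    ("merch3", "veggies", 13.0), ("merch3", "organics", 4.0), ("merch3", "frozen", 10.0),
]
print(expected_vectors(rows))
# {'merch2': [8.0, 6.0, 0.0, 0.0, 6.0], 'merch3': [10.0, 0.0, 13.0, 4.0, 0.0]}
```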

mwngjboj #1

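We can do this in two steps: first turn the rows into columns at the group level (here, id) using pivot, then use the array SQL function to build the list as expected.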

from pyspark.sql import functions as f
df.show()
+------+--------+--------+
|    id|category|quantity|
+------+--------+--------+
|merch1|   fruit|    20.0|
|merch1| veggies|   300.0|
|merch1|   diary|    10.0|
|merch1|organics|    12.0|
|merch1|  frozen|    11.0|
|merch2|   fruit|     6.0|
|merch2|   diary|     6.0|
|merch2|  frozen|     8.0|
|merch3| veggies|    13.0|
|merch3|organics|     4.0|
|merch3|  frozen|    10.0|
|merch4|   fruit|    28.0|
|merch4|organics|    11.0|
+------+--------+--------+  

df1 = df.groupby('id').pivot('category').agg(f.first('quantity')).fillna(0)
df1.show()
+------+-----+------+-----+--------+-------+
|    id|diary|frozen|fruit|organics|veggies|
+------+-----+------+-----+--------+-------+
|merch2|  6.0|   8.0|  6.0|     0.0|    0.0|
|merch4|  0.0|   0.0| 28.0|    11.0|    0.0|
|merch1| 10.0|  11.0| 20.0|    12.0|  300.0|
|merch3|  0.0|  10.0|  0.0|     4.0|   13.0|
+------+-----+------+-----+--------+-------+

df1.select('id',f.array(df1.columns[1:]).name('vector')).show(truncate=False)
+------+-------------------------------+
|id    |vector                         |
+------+-------------------------------+
|merch2|[6.0, 8.0, 6.0, 0.0, 0.0]      |
|merch4|[0.0, 0.0, 28.0, 11.0, 0.0]    |
|merch1|[10.0, 11.0, 20.0, 12.0, 300.0]|
|merch3|[0.0, 10.0, 0.0, 4.0, 13.0]    |
+------+-------------------------------+
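
Note that pivot sorts the generated columns alphabetically, so the vectors above follow the order diary, frozen, fruit, organics, veggies rather than the order listed in the question. One possible way to get the requested order, sketched under the assumption that the category list is known up front, is to pass it explicitly as the second argument to pivot. The names `CATEGORY_ORDER` and `to_vectors` are hypothetical helpers introduced for this example.

```python
# Assumption: the desired order is known in advance (copied from the question).
CATEGORY_ORDER = ["frozen", "diary", "veggies", "organics", "fruit"]


def to_vectors(df, categories=CATEGORY_ORDER):
    """Pivot `category` into one column per entry of `categories`, then pack them into an array."""
    # Imported here so the module can be loaded even where Spark is not installed.
    from pyspark.sql import functions as f

    pivoted = (
        df.groupBy("id")
        .pivot("category", categories)  # explicit values fix the column order
        .agg(f.first("quantity"))
        .fillna(0.0)                    # categories missing for an id become 0.0
    )
    # Build the array from the same list, so the element order matches `categories`.
    return pivoted.select("id", f.array(*categories).alias("vector"))
```

With `df` being the DataFrame shown at the top of this answer, `to_vectors(df).show(truncate=False)` should print one vector per id in frozen, diary, veggies, organics, fruit order. Passing the values explicitly also saves Spark the extra pass it otherwise needs to discover the distinct categories.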
