假设我有一个数组列 group_ids
```
+-------+----------+
|user_id|group_ids |
+-------+----------+
|1 |[5, 8] |
|3 |[1, 2, 3] |
|2 |[1, 4] |
+-------+----------+
架构:
root
|-- user_id: integer (nullable = false)
|-- group_ids: array (nullable = false)
| |-- element: integer (containsNull = false)
我想得到所有成对的组合:
+-------+------------------------+
|user_id|group_ids |
+-------+------------------------+
|1 |5, 8 |
|3 |[[1, 2], [1, 3], [2, 3]]|
|2 |1, 4 |
+-------+------------------------+
到目前为止,我用自定义项创建了最简单的解决方案:
spark.udf.register("permutate", udf((xs: Seq[Int]) => xs.combinations(2).toSeq))
dataset.withColumn("group_ids", expr("permutate(group_ids)"))
我要找的是通过spark内置函数实现的东西。有没有一种方法可以在没有自定义项的情况下实现相同的代码?
3条答案
按热度按时间ru9i0ody1#
一些高阶函数可以做到这一点。需要Spark>=2.4。
jyztefdp2#
基于
explode
以及joins
解决方案ikfrs5lh3#
可以得到列的最大大小
group_ids
. 然后,在范围内使用组合(1 - maxSize)
与when
表达式从原始数组创建子数组组合,并最终从结果数组中筛选空元素: