pyspark窗口,在组之间聚合结果

gr8qqesn  于 2021-05-19  发布在  Spark
关注(0)|答案(1)|浏览(450)

假设我有一个Dataframe,其中包含不同用户通过不同协议发出的请求和记录的度量值:

+---+-----+--------+------------+
| ts| user|protocol|metric_value|
+---+-----+--------+------------+
|  0|user1|     tcp|         197|
|  1|user1|     udp|         155|
|  2|user1|     tcp|         347|
|  3|user1|     tcp|         117|
|  4|user1|     tcp|         230|
|  5|user1|     udp|         225|
|  6|user1|     udp|         297|
|  7|user1|     tcp|         790|
|  8|user1|     udp|         216|
|  9|user1|     udp|         200|
+---+-----+--------+------------+

我需要为当前用户的每个协议添加另一个列,其中有最后记录的平均度量值(在当前时间戳之前,并且不早于当前的\uts-4)。所以,算法是这样的:
对于每行x:
查找row.user==x.user和row.ts<x.ts的所有行
从这些行中提取每个协议的最新度量值(如果相应的记录早于x.ts-4,则抛出它)
计算这些度量值的平均值
将计算的平均值追加到新列中的行
预期结果如下:

+---+-----+--------+------------+-------+
| ts| user|protocol|metric_value|avg_val|
+---+-----+--------+------------+-------+
|  0|user1|     tcp|         197|   null| // no data for user1
|  1|user1|     udp|         155|    197| // only tcp value available
|  2|user1|     tcp|         347|    176| // (197 + 155) / 2
|  3|user1|     tcp|         117|    251| // (347 + 155) / 2
|  4|user1|     tcp|         230|    136| // (117 + 155) / 2
|  5|user2|     udp|         225|   null| // because no data for user2
|  6|user1|     udp|         297|    230| // because record with ts==1 is too old now
|  7|user1|     tcp|         790|  263.5| // (297 + 230) / 2
|  8|user1|     udp|         216|  543.5| // (297 + 790) / 2
|  9|user1|     udp|         200|    503| // (216 + 790) / 2
+---+-----+--------+------------+-------+

请注意,表中可能有任意数量的协议和用户。
如何实现?
我试过使用窗口函数、lag(1)和按协议分区,但是聚合函数只计算单个分区的平均值,而不计算不同分区的结果。最接近的结果是sql请求使用协议分区上的行数,但我无法在那里传播row.ts<x.ts条件。

baubqpgj

baubqpgj1#

这是基于scala的解决方案,您可以将逻辑转换为python/pyspark
样本数据:

val df = Seq((0,"user1","tcp",197),(1,"user1","udp",155),(2,"user1","tcp",347),(3,"user1","tcp",117),(4,"user1","tcp",230),(5,"user2","udp",225),(6,"user1","udp",297),(7,"user1","tcp",790),(8,"user1","udp",216),(9,"user1","udp",200))
.toDF("ts","user","protocol","metric_value")

对于每一行,获取所有行 (protocol,metric_value) 为了 current_row.ts -4 在列表中。

val winspec = Window.partitionBy("user").orderBy("ts").rangeBetween(Window.currentRow - 4, Window.currentRow-1)
val df2 = df.withColumn("recent_list", collect_list(struct($"protocol", $"metric_value")).over(winspec))

df2.orderBy("ts").show(false)
/*

+---+-----+--------+------------+------------------------------------------------+
|ts |user |protocol|metric_value|recent_list                                       |
+---+-----+--------+------------+------------------------------------------------+
|0  |user1|tcp     |197         |[]                                              |
|1  |user1|udp     |155         |[[tcp, 197]]                                    |
|2  |user1|tcp     |347         |[[tcp, 197], [udp, 155]]                        |
|3  |user1|tcp     |117         |[[tcp, 197], [udp, 155], [tcp, 347]]            |
|4  |user1|tcp     |230         |[[tcp, 197], [udp, 155], [tcp, 347], [tcp, 117]]|
|5  |user2|udp     |225         |[]                                              |
|6  |user1|udp     |297         |[[tcp, 347], [tcp, 117], [tcp, 230]]            |
|7  |user1|tcp     |790         |[[tcp, 117], [tcp, 230], [udp, 297]]            |
|8  |user1|udp     |216         |[[tcp, 230], [udp, 297], [tcp, 790]]            |
|9  |user1|udp     |200         |[[udp, 297], [tcp, 790], [udp, 216]]            |
+---+-----+--------+------------+------------------------------------------------+

现在,您在一行中获得了所有必需的信息。您可以编写一个udf来应用获取最新协议类型和平均值的逻辑。

def getAverageValueForUniqRecents(list : Array[StructType]): Double = {
  // you logic goes here. 
  // Loop through your array in REVERSE ORDER
  // maintain a set to check if protocol already visited then skip, otherwise SUM
  //Finally average
}

相关问题