PySpark: CumSum with Salting over Window w/ Skew

cgvd09ve · posted 2024-01-06 in Spark

How can I perform a cumulative-sum window operation using salting? The sample below is small, but my id column is heavily skewed and I need to do this efficiently:

window_unsalted = Window.partitionBy("id").orderBy("timestamp")  

# expected value
df = df.withColumn("Expected", F.sum('value').over(window_unsalted))

However, I want to try salting, because at the scale of my data I cannot compute this as-is.
Given the MWE below, how can I reproduce the expected value of 20 using a salting technique?

from pyspark.sql import SparkSession
from pyspark.sql import functions as F  
from pyspark.sql.window import Window  

spark = SparkSession.builder.getOrCreate()

data = [  
    (7329, 1636617182, 1.0),  
    (7329, 1636142065, 1.0),  
    (7329, 1636142003, 1.0),  
    (7329, 1680400388, 1.0),  
    (7329, 1636142400, 1.0),  
    (7329, 1636397030, 1.0),  
    (7329, 1636142926, 1.0),  
    (7329, 1635970969, 1.0),  
    (7329, 1636122419, 1.0),  
    (7329, 1636142195, 1.0),  
    (7329, 1636142654, 1.0),  
    (7329, 1636142484, 1.0),  
    (7329, 1636119628, 1.0),  
    (7329, 1636404275, 1.0),  
    (7329, 1680827925, 1.0),  
    (7329, 1636413478, 1.0),  
    (7329, 1636143578, 1.0),  
    (7329, 1636413800, 1.0),  
    (7329, 1636124556, 1.0),  
    (7329, 1636143614, 1.0),  
    (7329, 1636617778, -1.0),  
    (7329, 1636142155, -1.0),  
    (7329, 1636142061, -1.0),  
    (7329, 1680400415, -1.0),  
    (7329, 1636142480, -1.0),  
    (7329, 1636400183, -1.0),  
    (7329, 1636143444, -1.0),  
    (7329, 1635977251, -1.0),  
    (7329, 1636122624, -1.0),  
    (7329, 1636142298, -1.0),  
    (7329, 1636142720, -1.0),  
    (7329, 1636142584, -1.0),  
    (7329, 1636122147, -1.0),  
    (7329, 1636413382, -1.0),  
    (7329, 1680827958, -1.0),  
    (7329, 1636413538, -1.0),  
    (7329, 1636143610, -1.0),  
    (7329, 1636414011, -1.0),  
    (7329, 1636141936, -1.0),  
    (7329, 1636146843, -1.0)  
]  
  
df = spark.createDataFrame(data, ["id", "timestamp", "value"])  
  
# Define the number of salt buckets  
num_buckets = 100  
  
# Add a salted_id column to the dataframe  
df = df.withColumn("salted_id", (F.concat(F.col("id"),   
                (F.rand(seed=42)*num_buckets).cast("int")).cast("string")))  
  
# Define a window partitioned by the salted_id, and ordered by timestamp  
window = Window.partitionBy("salted_id").orderBy("timestamp")  
  
# Add a cumulative sum column  
df = df.withColumn("cumulative_sum", F.sum("value").over(window))  
  
# Define a window partitioned by the original id, and ordered by timestamp  
window_unsalted = Window.partitionBy("id").orderBy("timestamp")  
  
# Compute the final cumulative sum by adding up the cumulative sums within each original id  
df = df.withColumn("final_cumulative_sum",   
                   F.sum("cumulative_sum").over(window_unsalted))  

# expected value
df = df.withColumn("Expected", F.sum('value').over(window_unsalted))

# incorrect trial
df.agg(F.sum('final_cumulative_sum')).show()

# expected value
df.agg(F.sum('Expected')).show()
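
One way to see why the random salt breaks the result, beyond the aggregate mismatch above, is to inspect how rows of the same id are scattered across buckets regardless of timestamp; each bucket's running sum then covers a time-interleaved subset that cannot simply be summed again. This is just an inspection sketch using the columns already defined in the MWE:

# Inspect how the random salt scatters one id's rows across buckets:
# timestamps inside a single salted_id are not contiguous in time, so the
# bucket-level running sums cannot be combined by another plain window sum.
(df.select("id", "salted_id", "timestamp", "value",
           "cumulative_sum", "final_cumulative_sum", "Expected")
   .orderBy("timestamp")
   .show(40, truncate=False))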


zvms9eto #1

From what I can see, the main problem here is that the timestamps must stay in order within each partition for the partial cumulative sums to be correct; for example, if the sequence is 1, 2, 3, then 2 cannot land in a different partition from 1 and 3.
My suggestion is to derive the salt value from the timestamp so that ordering is preserved. This will not remove the skew entirely, but it still lets you split partitions within the same id:

df = spark.createDataFrame(data, ["id", "timestamp", "value"])

bucket_size = 10000  # the actual size will depend on timestamp distribution

# Add timestamp-based salt column to the dataframe
df = df.withColumn("salt", F.floor(F.col("timestamp") / F.lit(bucket_size)))

# Get partial cumulative sums
window_salted = Window.partitionBy("id", "salt").orderBy("timestamp")
df = df.withColumn("cumulative_sum", F.sum("value").over(window_salted))

# Compute each bucket's total, then the running offset contributed by all previous buckets
df2 = df.groupby("id", "salt").agg(F.sum("value").alias("cumulative_sum_last"))
window_full = Window.partitionBy("id").orderBy("salt")
df2 = df2.withColumn("previous_sum", F.lag("cumulative_sum_last", default=0).over(window_full))
df2 = df2.withColumn("previous_cumulative_sum", F.sum("previous_sum").over(window_full))

# Join previous partial cumulative sums with original data
df = df.join(df2, ["id", "salt"])  # maybe F.broadcast(df2) if it is small enough

# Shift each partial cumulative sum by the combined total of all previous buckets
df = df.withColumn('final_cumulative_sum', F.col('cumulative_sum') + F.col('previous_cumulative_sum'))

# expected value
window_unsalted = Window.partitionBy("id").orderBy("timestamp")
df = df.withColumn("Expected", F.sum('value').over(window_unsalted))

# new calculation
df.agg(F.sum('final_cumulative_sum')).show()

# expected value
df.agg(F.sum('Expected')).show()
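
To confirm the timestamp-salted approach reproduces the plain unsalted window, a row-level comparison is more convincing than the aggregate check; this is only a verification sketch using the columns defined above, and bucket_size would still need tuning against the real timestamp distribution:

# Row-level check: every row should match the unsalted window result.
diff_count = df.filter(F.col("final_cumulative_sum") != F.col("Expected")).count()
print(f"rows differing from the unsalted cumulative sum: {diff_count}")  # expect 0

# Optional skew check: row count per (id, salt) bucket after salting.
df.groupBy("id", "salt").count().orderBy("id", "salt").show()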

