在迭代过程中使用Spark

qoefvg9y  于 12个月前  发布在  Apache
关注(0)|答案(1)|浏览(99)

我是Spark的新手,正在尝试了解建立并行运行的迭代过程的最佳方法。
例如,如果我有一个DataFrame和Collatz Conjecture(也就是说,给定一个函数,如果n是奇数则返回3n+1,如果n是偶数则返回n/2,如果我们重复运行这个函数,它最终总是返回1),我想确定返回1需要多少次迭代--我可以很容易地将其迭代编写为(因为它是这样定义的),如

nums = [(x, 0) for x in range(1, 5)]
schema = ['num', 'iters']
df = spark.createDataFrame(data=nums, schema=schema)
while True:
    checker = df.filter(F.col('num') != 1)
    if (checker.count() == 0):
        break

    df = df.withColumn('num', 
                F.when(
                    F.col('num') == 1,
                    F.col('num')
                )
                .otherwise(
                    F.when(F.col('num') % F.lit(2) != 0, 
                        F.col('num') * F.lit(3) + F.lit(1)
                    )
                    .otherwise(
                        F.col('num') / F.lit(2)
                    )
                )
    )
    df = df.withColumn('iters',
                F.when(F.col('num') != 1.0, F.col('iters') + F.lit(1))
                .otherwise(F.col('iters'))
    )

df.show()

字符串
注意:我知道你可以递归地做这些事情,这样做会更好。我只是把它作为Spark中迭代过程的一个例子。
但这真的很难看,而且没有在Spark上优化。我知道循环是Spark中的一个反模式,但我不知道还能用什么方法来做到这一点。

k2fxgqgv

k2fxgqgv1#

你是对的,使用显式循环可能不是在Spark中实现迭代过程的非常有效的方法。
我认为您可以使用pyspark.sql.functionsudf以更实用的方式定义转换。
假设我理解了你的用例,这个想法是重复地应用转换,直到DataFrame中的所有数字都达到值1。这里有一个方法:

from pyspark.sql import SparkSession
from pyspark.sql import functions as F

# Initialize Spark session
spark = SparkSession.builder.appName("CollatzConjecture").getOrCreate()

# Define a UDF (User-Defined Function) for the Collatz transformation
def collatz_transform(num):
    return num * 3 + 1 if num % 2 != 0 else num // 2

# Register the UDF
collatz_udf = F.udf(collatz_transform)

# Create the initial DataFrame
nums = [(x, 0) for x in range(1, 5)]
schema = ['num', 'iters']
df = spark.createDataFrame(data=nums, schema=schema)

# Define a function to apply the Collatz transformation until all numbers reach 1
def collatz_iteration(df):
    return df.withColumn('num', F.when(F.col('num') == 1, F.col('num')).otherwise(collatz_udf(F.col('num')))) \
             .withColumn('iters', F.col('iters') + F.lit(1))

# Iterate until all numbers reach 1
while df.filter(F.col('num') != 1).count() > 0:
    df = collatz_iteration(df)

df.show()

字符串
告诉我这是否符合你的需求。

相关问题