How do I apply a function to a column of a PySpark DataFrame?

4nkexdtk  posted on 2023-10-15 in Spark

I have the following in Python and want to convert it to PySpark. Here is the Python code.

import pandas as pd

def mapping_func(df, subject_col):
    return {item:i for i, item in enumerate(df[subject_col].unique())}
    
def func(x, mapping, prefix = None):
    if prefix is not None:
        return str(prefix) + str(mapping[x])
    else:
        return str(mapping[x])

data = {'state': ["Alabama", "California", "Maine", "Ohio", "Arizona", "Montana"]}
df1 = pd.DataFrame(data)
        
df1['state_code'] = df1['state'].apply(func, args=(mapping_func(df1, "state"), "S"))

print(df1)

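Here I am simply assigning codes to US states: mapping_func numbers each unique state, and func then writes that code (with a prefix) next to each state in the column. This is fairly simple Python code, but I cannot convert it to PySpark. The expected output was attached as an image.
I tried this: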

import pandas as pd
from pyspark.sql.functions import udf
from pyspark.sql.types import StringType
import json

def mapping_func(df, subject_col):
    unique_values = df.select(subject_col).distinct().rdd.flatMap(lambda x: x).collect()
    mapping = {item: i for i, item in enumerate(unique_values)}
    return json.dumps(mapping)

def func(x, mapping_str, prefix=None):
    mapping = json.loads(mapping_str)
    if prefix is not None:
        return str(prefix) + str(mapping[x])
    else:
        return str(mapping[x])

# Sample data
data = {'state': ["Alabama", "California", "Maine", "Ohio", "Arizona", "Montana"]}
df1 = spark.createDataFrame(pd.DataFrame(data))
df1.show()

mapping = mapping_func(df1, "state")
print(mapping)
df1 = df1.withColumn("state_code", func(df1["state"], mapping, "S"))
df1.show()

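But this raises a type error: TypeError: unhashable type: 'Column'
Basically, the variable x inside func is a Column object rather than a single value, and that is the problem.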

Answer #1, by m1m5dgzv

Instead of enumerate, you can use monotonically_increasing_id together with row_number to number the rows. Here is a sample code snippet:

from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.window import Window
import pandas as pd

# Initialize SparkSession
spark = SparkSession.builder.appName("State_Prefix").getOrCreate()

def get_state_codes(input_df, prefix=None):
    # Cast the integer index to a string so state_code is always a string column
    index_str = F.col('index').cast('string')
    if prefix is not None:
        output_df = input_df.withColumn('state_code', F.concat(F.lit(str(prefix)), index_str))
    else:
        output_df = input_df.withColumn('state_code', index_str)

    return output_df

data = {'state': ["Alabama", "California", "Maine", "Ohio", "Arizona", "Montana"]}
df1 = spark.createDataFrame(pd.DataFrame(data))

df1 = df1.withColumn(
    "index",
    F.row_number().over(Window.orderBy(F.monotonically_increasing_id()))-1
)
# df1.show()

df1 = get_state_codes(df1, "S")
df1 = df1.drop('index')

df1.show()
