在pyspark对象框架中的数组上迭代,并基于与数组中的值同名的列创建新列

bybem2ql  于 2024-01-06  发布在  Spark
关注(0)|答案(1)|浏览(209)

我有一个这样格式的表格:
| 名称|水果|苹果|香蕉|橙子|
| --|--|--|--|--|
| 爱丽丝|[“苹果”、“香蕉”、“橙子”]| 5 | 8 | 3 |
| 鲍勃|[“苹果”]| 2 | 9 | 1 |
我想创建一个新列,其中包含以下格式的JSON包,其中键是数组的元素,值是列名称的结果值:
| 名称|水果|苹果|香蕉|橙子|新科尔|
| --|--|--|--|--|--|
| 爱丽丝|[“苹果”、“香蕉”、“橙子”]| 5 | 8 | 3 |{“苹果”:5,“香蕉”:8,“橙子”:3}|
| 鲍勃|[“苹果”]| 2 | 9 | 1 |{“apple”:2}|
我假设有一个UDF,但是我不能得到正确的语法。
这是我对代码的理解:

  1. from pyspark.sql.functions import udf, col
  2. from pyspark.sql.types import MapType, StringType
  3. # Create a Spark session
  4. spark = SparkSession.builder.appName("example").getOrCreate()
  5. # Sample data
  6. data = [("Alice", ["apple", "banana", "orange"], 5, 8, 3),
  7. ("Bob", ["apple"], 2, 9, 1)]
  8. # Define the schema
  9. schema = ["name", "fruits", "apple", "banana", "orange"]
  10. # Create a DataFrame
  11. df = spark.createDataFrame(data, schema=schema)
  12. # Show the initial DataFrame
  13. print("Initial DataFrame:")
  14. display(df)
  15. # Define a UDF to create a dictionary
  16. @udf(MapType(StringType(), StringType()))
  17. def json_map(fruits):
  18. result = {}
  19. for i in fruits:
  20. result[i] = col(i)
  21. return result
  22. # Apply the UDF to the 'fruits' column
  23. new_df = df.withColumn('test', json_map(col('fruits')))
  24. # Display the updated DataFrame
  25. display(new_df)

字符串

qyuhtwio

qyuhtwio1#

你可以使用Abdennacer在他的回答中分享的arrays_zip方法,但前提是数组元素应该与你的列对齐,这可能并不总是如此。
另一种方法是为列创建Map数组,并过滤该数组以仅保留fruits数组中可用的键的键值对。
这里有一个例子

  1. # i've changed the input slightly to rearrange the fruits array
  2. # +-----+-----------------------+-----+------+------+
  3. # |name |fruits |apple|banana|orange|
  4. # +-----+-----------------------+-----+------+------+
  5. # |Alice|[orange, banana, apple]|5 |8 |3 |
  6. # |Bob |[apple] |2 |9 |1 |
  7. # +-----+-----------------------+-----+------+------+
  8. data_sdf. \
  9. withColumn('fruitcols_arr',
  10. func.array(*[func.create_map([func.lit(c), func.col(c)]) for c in data_sdf.drop('name', 'fruits').columns])
  11. ). \
  12. withColumn('fruitcols_arr',
  13. func.expr('filter(fruitcols_arr, x -> array_contains(fruits, map_keys(x)[0]))')
  14. ). \
  15. withColumn('new_col',
  16. func.aggregate(func.expr('slice(fruitcols_arr, 2, size(fruitcols_arr))'),
  17. func.col('fruitcols_arr')[0],
  18. lambda x, y: func.map_concat(x, y)
  19. )
  20. ). \
  21. drop('fruitcols_arr'). \
  22. show(truncate=False)
  23. # +-----+-----------------------+-----+------+------+--------------------------------------+
  24. # |name |fruits |apple|banana|orange|new_col |
  25. # +-----+-----------------------+-----+------+------+--------------------------------------+
  26. # |Alice|[orange, banana, apple]|5 |8 |3 |{apple -> 5, banana -> 8, orange -> 3}|
  27. # |Bob |[apple] |2 |9 |1 |{apple -> 2} |
  28. # +-----+-----------------------+-----+------+------+--------------------------------------+

字符串
第一个fruitcols_arr创建一个Map数组(column_name -> column_value)使用每个单独的fruit列.第二个过滤器基于fruits列数组元素的数组.这是基于每个fruits元素的Map的最终数组. new_col是通过使用aggregate高阶函数与map_concat创建的,连接最终Map数组中的所有单个Map。

展开查看全部

相关问题