Apache Spark 遍历选择列,检查这些选择列中是否存在特定值,并使用具有该值的列名创建新表

fquxozlt  于 2022-11-16  发布在  Apache
关注(0)|答案(2)|浏览(126)

如果我有下表

+-----+-----+---+-----+-----+-----+-----+-----+-----+-----+
|    a|     b|    id|m2000|m2001|m2002|m2003|m2004|m2005
+-----+-----+---+-----+-----+-----+-----+-----+-----+-----+
|a    |world|      1|    0|    0|    1|    0|    0|    1|   
+-----+-----+---+-----+-----+-----+-----+-----+-----+-----+

如何创建一个如下所示的新 Dataframe ,检查列m2000到m2014,并查看这些字段是否为1。然后创建下表,其中10/10为静态。使用2002和2005,因为m2000和m2014之间只有2列,其中1在上表中。

|id | year      | yearend   |
|1  | 10/10/2002| 12/12/2005|
|1  | 10/10/2002| 12/12/2005|

创建第一个 Dataframe 的代码

from pyspark.shell import spark
from pyspark.sql.types import StructType, StructField, StringType, IntegerType

data2 = [("a", "world", "1", 0, 0, 1,0,0,1),
         ]

schema = StructType([ \
    StructField("a", StringType(), True), \
    StructField("b", StringType(), True), \
    StructField("id", StringType(), True), \
    StructField("m2000", IntegerType(), True), \
    StructField("m2001", IntegerType(), True), \
    StructField("m2002", IntegerType(), True), \
    StructField("m2003", IntegerType(), True), \
    StructField("m2004", IntegerType(), True), \
    StructField("m2005", IntegerType(), True), \
    ])

df = spark.createDataFrame(data=data2, schema=schema)
df.printSchema()
df.show(truncate=False)
4urapxun

4urapxun1#

假设 Dataframe 具有更完整的场景,其中存在年份不为“1”的行和具有更多“1”的行:

from pyspark.shell import spark
from pyspark.sql.types import StructType, StructField, StringType, IntegerType

data2 = [("a", "world", "1", 0, 0, 1,0,0,1),
         ("b", "world", "2", 0, 1, 0,1,0,1),
         ("c", "world", "3", 0, 0, 0,0,0,0)
         ]

schema = StructType([ \
    StructField("a", StringType(), True), \
    StructField("b", StringType(), True), \
    StructField("id", StringType(), True), \
    StructField("m2000", IntegerType(), True), \
    StructField("m2001", IntegerType(), True), \
    StructField("m2002", IntegerType(), True), \
    StructField("m2003", IntegerType(), True), \
    StructField("m2004", IntegerType(), True), \
    StructField("m2005", IntegerType(), True), \
    ])

df = spark.createDataFrame(data=data2, schema=schema)
df.printSchema()
df.show(truncate=False)

| | 一种|B|标识符|M2000计算机|2001年中期|2002年中期|2003年中期|2004年中期|2005年中期|
| - -|- -|- -|- -|- -|- -|- -|- -|- -|- -|
| 第0页|一种|全世界|一个|第0页|第0页|一个|第0页|第0页|一个|
| 一个|B|全世界|2个|第0页|一个|第0页|一个|第0页|一个|
| 2个|C语言|全世界|三个|第0页|第0页|第0页|第0页|第0页|第0页|
为了方便起见,我把你的 Dataframe 传给Pandas,但我会使用简单的迭代结构,你可以把它集成到spark中。

pandas_df = df.toPandas()

我们检索不包括前3列的年份列表:

years = list(pandas_df.columns)[3:]

最后,生成所需 Dataframe 所需的代码如下(行内注解):

tmp_df_data_list = []

# iterate over rows of df
for _, row in pandas_df.iterrows():
    flagged_years=[]
    
    # for each year check if col value is 1
    for y in years:
        if row[y]:  # if is 1, append col name
            flagged_years.append(y)
            
    if len(flagged_years) >= 2:
        # get first occurence as 'year' and last as 'yearend' by removing the first letter
        min_year = flagged_years[0][1:]
        max_year = flagged_years[-1][1:]

        tmp_df_data_list.append([row.id, '10/10/'+min_year, '12/12/'+max_year])

res_df = pd.DataFrame(tmp_df_data_list, columns=['id', 'year', 'yearend'])

输出将为:
| | 标识符|年份|年底|
| - -|- -|- -|- -|
| 第0页|一个|二零零二年十月十日|二OO五年十二月十二日|
| 一个|2个|二零零一年十月十日|二OO五年十二月十二日|

i2byvkas

i2byvkas2#

我们可以使用pyspark本机函数创建一个值为1的列名数组。然后,可以使用该数组获取年份的minmax以及"10/10/"concat
下面是一个示例

data_ls = [
    ("a", "world", "1", 0, 0, 1,0,0,1),
    ("b", "world", "2", 0, 1, 0,1,0,1),
    ("c", "world", "3", 0, 0, 0,0,0,0)
]

data_sdf = spark.sparkContext.parallelize(data_ls). \
    toDF(['a', 'b', 'id', 'm2000', 'm2001', 'm2002', 'm2003', 'm2004', 'm2005'])

# +---+-----+---+-----+-----+-----+-----+-----+-----+
# |  a|    b| id|m2000|m2001|m2002|m2003|m2004|m2005|
# +---+-----+---+-----+-----+-----+-----+-----+-----+
# |  a|world|  1|    0|    0|    1|    0|    0|    1|
# |  b|world|  2|    0|    1|    0|    1|    0|    1|
# |  c|world|  3|    0|    0|    0|    0|    0|    0|
# +---+-----+---+-----+-----+-----+-----+-----+-----+

yearcols = [k for k in data_sdf.columns if k.startswith('m20')]

data_sdf. \
    withColumn('yearcol_structs', 
               func.array(*[func.struct(func.lit(int(c[-4:])).alias('year'), func.col(c).alias('value')) 
                            for c in yearcols]
                          )
               ). \
    withColumn('yearcol_1s', 
               func.expr('transform(filter(yearcol_structs, x -> x.value = 1), f -> f.year)')
               ). \
    filter(func.size('yearcol_1s') >= 1). \
    withColumn('year_start', func.concat(func.lit('10/10/'), func.array_min('yearcol_1s'))). \
    withColumn('year_end', func.concat(func.lit('10/10/'), func.array_max('yearcol_1s'))). \
    show(truncate=False)

# +---+-----+---+-----+-----+-----+-----+-----+-----+------------------------------------------------------------------+------------------+----------+----------+
# |a  |b    |id |m2000|m2001|m2002|m2003|m2004|m2005|yearcol_structs                                                   |yearcol_1s        |year_start|year_end  |
# +---+-----+---+-----+-----+-----+-----+-----+-----+------------------------------------------------------------------+------------------+----------+----------+
# |a  |world|1  |0    |0    |1    |0    |0    |1    |[{2000, 0}, {2001, 0}, {2002, 1}, {2003, 0}, {2004, 0}, {2005, 1}]|[2002, 2005]      |10/10/2002|10/10/2005|
# |b  |world|2  |0    |1    |0    |1    |0    |1    |[{2000, 0}, {2001, 1}, {2002, 0}, {2003, 1}, {2004, 0}, {2005, 1}]|[2001, 2003, 2005]|10/10/2001|10/10/2005|
# +---+-----+---+-----+-----+-----+-----+-----+-----+------------------------------------------------------------------+------------------+----------+----------+

相关问题