如何避免在pyspark中使用collect()进行自定义的迭代DataFrame转换

gr8qqesn  于 12个月前  发布在  Spark
关注(0)|答案(1)|浏览(117)

我正在努力用pyspark高效地实现一个数据转换,目标是将一个十六进制字符串分解成几个位于不同列中的可理解的数据元素(“子信号”)。
输入DataFrame包含由十六进制字符串组成的StringType()data_element_value
x1c 0d1x的数据
每个data_element_value字符串编码几个 * 子信号 *,如下图所示。



具体 * 如何 * 分解(认为“索引”)和转换每个子信号是由一个查找表(“解码 Dataframe ”)给出的,如下所示:



decode DataFrame中的每一行都告诉我们如何分解和转换十六进制字符串。(回想一下,每个十六进制字符是一个 * 半字节 *,两个半字节代表一个字节),Size列告诉我们要抓取多少字节--这就是分解。当每个子信号以这种方式分解时,我们将十六进制子字符串转换为一个以10为基数的整数,并对SlopeOffset列应用一个线性Map--这就是转换。完全转换后的DataFrame应该如下所示:


我所尝试的

我将split十六进制字符串转换为ArrayType(StringType())类型的新列data_element_value_split。该列提供了一种简单的方法来索引十六进制 * 数组 * 的每个子信号。然后我collect()输入并解码DataFrames并逐行扫描,为每个子信号索引十六进制数组并应用线性Map(在下面的代码中提供了最小的工作示例以容易地再现这一点)。
我的理解是,调用collect()会将所有数据处理限制在驱动节点,这会导致非分布式的,低效的数据转换。由于需要不断转换的数据量很大,这似乎是一个坏主意。我认为可以将collect()用于较小的解码帧,并将broadcast用于工人,但我认为collect()较大的输入DataFrame是不合适的。有没有一种方法可以以分布式的、有效的方式实现这种转换?也许使用用户定义的函数?任何解决方案都有帮助,但使用Spark RDD API不是首选,因为它在使用共享访问模式的Databricks Unity Catalog集群上不受支持。

我的(低效)解决方案

# Create sample input DataFrame
from pyspark.sql.functions import split
from pyspark.sql.types import (
    FloatType,
    IntegerType
    StringType,
    StructField,
    StructType
)
input_data = [
    ("encoded_signal_1", "17080904792000"),
    ("encoded_signal_2", "170809041e2401"),
]
input_schema = StructType(
    [
        StructField("data_element", StringType()),
        StructField("data_element_value", StringType()),
    ]
)
df_input = spark.createDataFrame(input_data, schema=input_schema)

# Create sample decode DataFrame
decode_data = [
    ("sub_signal_1", 1, 0, 1.0, 2000,),
    ("sub_signal_2", 1, 1, 1.0, 0,),
    ("sub_signal_3", 1, 2, 1.0, 0,),
    ("sub_signal_4", 4, 3, 0.015625, 0,),
]
decode_schema = StructType(
    [
        StructField("SignalName", StringType()),
        StructField("Size", IntegerType()),
        StructField("Beginning", IntegerType()),
        StructField("Slope", FloatType()),
        StructField("Offset", IntegerType()),
    ]
)
df_decode = spark.createDataFrame(decode_data, schema=decode_schema)

# Transform using collect()
df_input_split = df_input.withColumn("data_element_value_split", split("data_element_value", r"(?<=\G..)"))  # decompose hex into bytes
decode_rows = df_decode.collect()
input_rows = df_input_split.collect()
input_rows_decoded = []
for input_row in input_rows:

    input_row_dict = input_row.asDict()
    for decode_row in decode_rows:

        start_byte = decode_row.Beginning
        end_byte = start_byte + decode_row.Size
        sub_signal = "".join(input_row.data_element_value_split[start_byte:end_byte])
        sub_signal_decoded = int(sub_signal, 16)*decode_row.Slope + decode_row.Offset
        input_row_dict[decode_row.SignalName] = sub_signal_decoded
    input_rows_decoded.append(input_row_dict)
df_final = spark.createDataFrame(input_rows_decoded)
display(df_final)

字符串

u4vypkhs

u4vypkhs1#

假设df_decodedf_input小得多,我们可以通过迭代df_decode为每个子信号创建Column object

import numpy as np
from pyspark.sql import functions as F

cols = [
    (F.conv(
      F.substring(F.col('data_element_value'),  # take the substring of the data
        int(col[2])*2+1,                        # from start
        int(col[1])*2),                         # with length
      16, 10).cast('double')                    # use conv to convert from hex -> dec
       * float(col[3])                          # multiply with slope
       + float(col[4]))                         # and add the offset
       .alias(f"sub_signal_{(i+1)}")            # rename column
       for i, col in enumerate(np.array(df_decode.collect()))]
cols = df_input.columns + cols                  # add the original columns

df_input.select(cols).show()

字符串
输出量:

+----------------+------------------+------------+------------+------------+--------------+
|    data_element|data_element_value|sub_signal_1|sub_signal_2|sub_signal_3|  sub_signal_4|
+----------------+------------------+------------+------------+------------+--------------+
|encoded_signal_1|    17080904792000|      2023.0|         8.0|         9.0|     1172608.0|
|encoded_signal_2|    170809041e2401|      2023.0|         8.0|         9.0|1079440.015625|
+----------------+------------------+------------+------------+------------+--------------+

相关问题