pyspark - How to optimize my Spark script in Databricks

rvpgvaaj posted on 2023-11-16 in Spark
Follow (0) | Answers (2) | Views (113)

In my script I have a table with 3 columns: vehicleId, tripStartDateTime and correlationId. The table can be partitioned by the vehicleId column.

+------------------+--------------------+------------------------------------+
|vehicleId         |tripStartDateTime   |correlationId                       |
+------------------+--------------------+------------------------------------+
|00045b1b-0ac9-4dce|2023-07-26T16:35:34Z|1f036bb8-cac4-43c1-b29e-a7646884fe2e|
|00045b1b-0ac9-4dce|2023-07-26T17:27:38Z|134b785e-e013-41b1-aabc-094a90b95482|
|00045b1b-0ac9-4dce|2023-07-26T18:04:16Z|51fb0e53-2938-431c-8825-7f461849dfe3|
|00045b1b-0ac9-4dce|2023-07-26T18:32:46Z|954a4f96-2c51-403b-9fd5-d07a7cdc35dd|
|00045b1b-0ac9-4dce|2023-07-26T18:40:18Z|811a1336-27f3-4e8c-99cc-22f5debe21a3|
|8eba-55a058fb4dd0f|2023-07-20T10:35:34Z|1f036bff-cac4-dddd-ddsa-a7646884fe2e|
|8eba-55a058fb4dd0f|2023-07-20T10:65:34Z|23226bff-cac4-dddd-ddsa-a7646884fe2e|
...

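Within each vehicleId (partition), I want to process the rows one by one in tripStartDateTime order. Each row is passed to a custom function that performs complex calculations, and each row's result is saved to another table. The next row uses the results of the previous rows, so the rows must be processed sequentially.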

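How can I write an efficient script so that the partitions are processed concurrently (one vehicle does not affect the others), while the records within a partition are processed sequentially, one at a time?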

My current solution is to write a UDF, pass each row to it, and call it in a for loop. But this is slow, even though I am using a multi-node cluster in Databricks.

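// collect() pulls every row of df into driver memory, so this loop
// runs on the driver alone and the cluster's workers sit idle.
for (row <- df.collect()) {
   processRowUDF(row)
}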


Could anyone give me some advice on how to optimize this? Thanks.
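A sketch of one common way to get this behaviour in PySpark, which is a different technique from the two answers below: a grouped-map pandas UDF via GroupedData.applyInPandas (Spark 3.0+). Spark passes all rows of one vehicleId to the Python function as a single pandas DataFrame, different vehicles are processed in parallel on the executors, and inside the function the rows can be sorted and looped over in order. The names process_vehicle and run_per_vehicle and the output columns tripIndex and previousCorrelationId are hypothetical placeholders for the real calculation, and tripStartDateTime is assumed to be an ISO-8601 string column as in the table above.

```python
import pandas as pd


def process_vehicle(pdf: pd.DataFrame) -> pd.DataFrame:
    """Process one vehicle's trips strictly in time order (hypothetical calculation)."""
    # Spark does not guarantee row order inside a group, so sort first.
    # Plain string sorting is chronological for ISO-8601 UTC strings in one fixed format.
    pdf = pdf.sort_values("tripStartDateTime").reset_index(drop=True)

    trip_index, previous_ids = [], []
    running_count = 0   # state carried from one row to the next
    previous_id = None  # "result" of the previous row
    for row in pdf.itertuples(index=False):
        # Placeholder for the real per-row computation, which may read earlier results.
        running_count += 1
        trip_index.append(running_count)
        previous_ids.append(previous_id)
        previous_id = row.correlationId

    return pdf.assign(tripIndex=trip_index, previousCorrelationId=previous_ids)


def run_per_vehicle(spark_df):
    """Apply process_vehicle to every vehicleId group; groups run in parallel."""
    # DDL schema of the pandas DataFrame returned by process_vehicle.
    schema = (
        "vehicleId string, tripStartDateTime string, correlationId string, "
        "tripIndex long, previousCorrelationId string"
    )
    return spark_df.groupBy("vehicleId").applyInPandas(process_vehicle, schema=schema)
```

run_per_vehicle(df) returns a new Spark DataFrame with one row per input trip, which can be written to the result table in a single write instead of row by row. All rows of one vehicle are loaded into an executor's memory at once, so this assumes no single vehicle has an extremely large number of trips.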


dxxyhpgq1#

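You can use a custom grouped-aggregate pandas UDF (PandasUDFType.GROUPED_AGG; in Spark 3.x this type is normally inferred from Python type hints). It receives the grouped data (here grouped by vehicleId) and returns a list of lists, which can be exploded to insert into another table or to create another DataFrame. Here is an example.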

import pandas as pd
import pyspark.sql.functions as F
from pyspark.sql import SparkSession
from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import ArrayType, StringType

# On Databricks a SparkSession named `spark` already exists; building one is only needed locally.
spark = SparkSession.builder.master("local[*]").getOrCreate()

data1 = [
    ["00045b1b-0ac9-4dce", "2023-07-26T16:35:34Z", "1f036bb8-cac4-43c1-b29e-a7646884fe2e"],
    ["00045b1b-0ac9-4dce", "2023-07-26T17:27:38Z", "134b785e-e013-41b1-aabc-094a90b95482"],
    ["00045b1b-0ac9-4dce", "2023-07-26T18:04:16Z", "51fb0e53-2938-431c-8825-7f461849dfe3"],
    ["00045b1b-0ac9-4dce", "2023-07-26T18:32:46Z", "954a4f96-2c51-403b-9fd5-d07a7cdc35dd"],
    ["00045b1b-0ac9-4dce", "2023-07-26T18:40:18Z", "811a1336-27f3-4e8c-99cc-22f5debe21a3"],
    ["8eba-55a058fb4dd0f", "2023-07-20T10:35:34Z", "1f036bff-cac4-dddd-ddsa-a7646884fe2e"],
    ["8eba-55a058fb4dd0f", "2023-07-20T10:55:34Z", "23226bff-cac4-dddd-ddsa-a7646884fe2e"],
]
columns1 = ["vehicleId", "tripStartDateTime", "correlationId"]

df1 = spark.createDataFrame(data=data1, schema=columns1)
df1 = df1.withColumn("tripStartDateTime", F.to_timestamp(F.col("tripStartDateTime")))
print("Given dataframe")
df1.show(n=100, truncate=False)

print("schema here")
print(df1.schema)


# Grouped-aggregate pandas UDF: in Spark 3.x the pd.Series parameters plus a non-pandas
# return hint make Spark infer GROUPED_AGG. It is called once per vehicleId group,
# and different groups can run in parallel.
@pandas_udf(ArrayType(ArrayType(StringType())))
def custom_sum_udf(col1_series: pd.Series, col2_series: pd.Series, col3_series: pd.Series) -> list:
    concat_df = pd.concat([col1_series, col2_series, col3_series], axis=1)
    concat_df.columns = columns1  # declared above
    # Rows inside a group arrive in no guaranteed order; sort before any sequential processing.
    concat_df = concat_df.sort_values("tripStartDateTime").reset_index(drop=True)
    # concat_df is now a pandas DataFrame for one vehicle, so the complex processing goes here.
    # On a cluster, print() inside a UDF writes to the executor logs, not to the notebook.
    print("Process the vehicle Id")
    print(concat_df)

    max_column = concat_df["tripStartDateTime"].max()
    min_column = concat_df["tripStartDateTime"].min()

    print("max_column", max_column)
    print("min_column", min_column)
    all_result = [[concat_df.iloc[0, 0]], [str(max_column)], [str(min_column)]]

    return all_result


df_new = df1.groupby(F.col("vehicleId")).agg(
    custom_sum_udf(F.col("vehicleId"), F.col("tripStartDateTime"), F.col("correlationId")).alias("reduced_columns")
).cache()
print("Printing the reduced columns")
df_new.show(n=100, truncate=False)

# Split the returned list of lists into separate columns.
df_new_sep = df_new.withColumn("id_recieved", F.col("reduced_columns").getItem(0))
df_new_sep = df_new_sep.withColumn("max_over_timestamp", F.col("reduced_columns").getItem(1))
df_new_sep = df_new_sep.withColumn("min_over_timestamp", F.col("reduced_columns").getItem(2)).drop(F.col("reduced_columns"))
print("Printing the column max, min")
df_new_sep.show(n=100, truncate=False)

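Output (timestamps are displayed in the Spark session time zone, which here appears to be UTC+05:30, so 16:35:34Z is shown as 22:05:34):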

Given dataframe
+------------------+-------------------+------------------------------------+
|vehicleId         |tripStartDateTime  |correlationId                       |
+------------------+-------------------+------------------------------------+
|00045b1b-0ac9-4dce|2023-07-26 22:05:34|1f036bb8-cac4-43c1-b29e-a7646884fe2e|
|00045b1b-0ac9-4dce|2023-07-26 22:57:38|134b785e-e013-41b1-aabc-094a90b95482|
|00045b1b-0ac9-4dce|2023-07-26 23:34:16|51fb0e53-2938-431c-8825-7f461849dfe3|
|00045b1b-0ac9-4dce|2023-07-27 00:02:46|954a4f96-2c51-403b-9fd5-d07a7cdc35dd|
|00045b1b-0ac9-4dce|2023-07-27 00:10:18|811a1336-27f3-4e8c-99cc-22f5debe21a3|
|8eba-55a058fb4dd0f|2023-07-20 16:05:34|1f036bff-cac4-dddd-ddsa-a7646884fe2e|
|8eba-55a058fb4dd0f|2023-07-20 16:25:34|23226bff-cac4-dddd-ddsa-a7646884fe2e|
+------------------+-------------------+------------------------------------+

schema here
StructType([StructField('vehicleId', StringType(), True), StructField('tripStartDateTime', TimestampType(), True), StructField('correlationId', StringType(), True)])
    
    
+------------------+--------------------------------------------------------------------+
|vehicleId         |reduced_columns                                                     |
+------------------+--------------------------------------------------------------------+
|00045b1b-0ac9-4dce|[[00045b1b-0ac9-4dce], [2023-07-27 00:10:18], [2023-07-26 22:05:34]]|
|8eba-55a058fb4dd0f|[[8eba-55a058fb4dd0f], [2023-07-20 16:25:34], [2023-07-20 16:05:34]]|
+------------------+--------------------------------------------------------------------+

Printing the column max, min
+------------------+--------------------+---------------------+---------------------+
|vehicleId         |id_recieved         |max_over_timestamp   |min_over_timestamp   |
+------------------+--------------------+---------------------+---------------------+
|00045b1b-0ac9-4dce|[00045b1b-0ac9-4dce]|[2023-07-27 00:10:18]|[2023-07-26 22:05:34]|
|8eba-55a058fb4dd0f|[8eba-55a058fb4dd0f]|[2023-07-20 16:25:34]|[2023-07-20 16:05:34]|
+------------------+--------------------+---------------------+---------------------+
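The UDF above returns only a per-vehicle summary. If, as in the question, each trip needs its own result that depends on the earlier trips, the same grouped-aggregate body can sort the rows and emit one inner list per trip. Below is a minimal sketch written as a plain function over pandas Series, so it can be tried without Spark; the name per_row_results and the [correlationId, tripNumber] layout are hypothetical. Decorated with @pandas_udf(ArrayType(ArrayType(StringType()))) in the same way as custom_sum_udf, its result column can be passed to F.explode to get one row per trip.

```python
import pandas as pd


def per_row_results(vehicle_ids: pd.Series, starts: pd.Series, correlation_ids: pd.Series) -> list:
    """Return one [correlationId, tripNumber] list per trip, in time order (hypothetical output)."""
    pdf = pd.DataFrame({
        "vehicleId": vehicle_ids.to_numpy(),
        "tripStartDateTime": starts.to_numpy(),
        "correlationId": correlation_ids.to_numpy(),
    })
    # The rows of a group arrive in no guaranteed order, so sort before the sequential loop.
    pdf = pdf.sort_values("tripStartDateTime")

    results = []
    trip_number = 0  # state carried from one row to the next
    for correlation_id in pdf["correlationId"]:
        trip_number += 1
        # Every value is a string to match ArrayType(ArrayType(StringType())).
        results.append([str(correlation_id), str(trip_number)])
    return results
```

Because the whole list for a vehicle is built in one call, this keeps the "one call per vehicle" parallelism of the aggregate UDF while still giving the sequential, row-by-row view the question asks for.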


daolsyd02#

A similar approach in Scala Spark using a Dataset. The Spark documentation has a guide on how to write a custom user-defined aggregate function:
https://spark.apache.org/docs/latest/sql-ref-functions-udf-aggregate.html

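import org.apache.spark.SparkConf
import org.apache.spark.sql.{Encoder, Encoders, SparkSession}
import org.apache.spark.sql.expressions.Aggregator

import scala.collection.mutable.ArrayBuffer

object VehicleDemo {

  case class VehicleInfo(
    vehicleId: String = null,
    tripStartDateTime: String = null,
    correlationId: String = null
  )

  // Buffer and output types for the custom aggregate function.
  case class Trip(tripStartDateTime: String, correlationId: String)
  case class TripBuffer(trips: Seq[Trip])
  case class TripResults(results: Seq[String])

  // Custom aggregate function (see the guide linked above): collects all trips of one
  // vehicle, then processes them in time order in finish(), where the whole group is known.
  object OrderedTripAgg extends Aggregator[VehicleInfo, TripBuffer, TripResults] {
    def zero: TripBuffer = TripBuffer(Seq.empty)

    // reduce and merge may see rows in any order, so they only collect the trips.
    def reduce(buf: TripBuffer, v: VehicleInfo): TripBuffer =
      TripBuffer(buf.trips :+ Trip(v.tripStartDateTime, v.correlationId))

    def merge(b1: TripBuffer, b2: TripBuffer): TripBuffer =
      TripBuffer(b1.trips ++ b2.trips)

    def finish(buf: TripBuffer): TripResults = {
      // ISO-8601 UTC strings in one fixed format sort chronologically as plain strings.
      val sorted = buf.trips.sortBy(_.tripStartDateTime)
      val results = ArrayBuffer[String]()
      var previous = "none" // result carried from one row to the next
      for (trip <- sorted) {
        // Placeholder for the real per-row computation, which can read `previous`.
        results += s"${trip.correlationId} (previous: $previous)"
        previous = trip.correlationId
      }
      TripResults(results.toList)
    }

    def bufferEncoder: Encoder[TripBuffer] = Encoders.product[TripBuffer]
    def outputEncoder: Encoder[TripResults] = Encoders.product[TripResults]
  }

  def main(args: Array[String]): Unit = {

    val sparkConfig = new SparkConf().setMaster("local[*]")

    val spark = SparkSession
      .builder()
      .appName("Vehicle Demo")
      .config(sparkConfig)
      .getOrCreate()
    spark.sparkContext.setLogLevel("WARN")

    import spark.implicits._

    val vehicleInfoDataset = Seq(
      VehicleInfo("00045b1b-0ac9-4dce", "2023-07-26T16:35:34Z", "1f036bb8-cac4-43c1-b29e-a7646884fe2e"),
      VehicleInfo("00045b1b-0ac9-4dce", "2023-07-26T17:27:38Z", "134b785e-e013-41b1-aabc-094a90b95482"),
      VehicleInfo("00045b1b-0ac9-4dce", "2023-07-26T18:04:16Z", "51fb0e53-2938-431c-8825-7f461849dfe3"),
      VehicleInfo("00045b1b-0ac9-4dce", "2023-07-26T18:32:46Z", "954a4f96-2c51-403b-9fd5-d07a7cdc35dd"),
      VehicleInfo("00045b1b-0ac9-4dce", "2023-07-26T18:40:18Z", "811a1336-27f3-4e8c-99cc-22f5debe21a3"),
      VehicleInfo("8eba-55a058fb4dd0f", "2023-07-20T10:35:34Z", "1f036bff-cac4-dddd-ddsa-a7646884fe2e"),
      VehicleInfo("8eba-55a058fb4dd0f", "2023-07-20T10:65:34Z", "23226bff-cac4-dddd-ddsa-a7646884fe2e")
    ).toDS()

    vehicleInfoDataset.show(false)

    // Each vehicleId is aggregated independently, so different vehicles run in parallel.
    // The TripResults column can be exploded and then used to create another DataFrame.
    val groupedDataset = vehicleInfoDataset
      .groupByKey(_.vehicleId)
      .agg(OrderedTripAgg.toColumn)

    groupedDataset.show(false)

    spark.stop()
  }
}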
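Why a custom aggregate function like this has to collect the rows first and sort them only at the end: under the aggregation contract described in the guide linked above, Spark may call reduce on a group's rows in any order and then combine partial buffers from different partitions with merge, so finish is the first point where the complete, sortable set of trips is available. The following is a plain-Python model of that zero/reduce/merge/finish contract, not Spark code; OrderedTripAggModel and aggregate_in_chunks are hypothetical names, and the chunking only simulates how rows might be split across partitions.

```python
from functools import reduce as fold


class OrderedTripAggModel:
    """Plain-Python model of an aggregator's zero/reduce/merge/finish methods."""

    def zero(self):
        return []  # empty buffer for one vehicle

    def reduce(self, buffer, row):
        # Called once per input row, in no guaranteed order.
        return buffer + [row]

    def merge(self, b1, b2):
        # Combines partial buffers built on different partitions.
        return b1 + b2

    def finish(self, buffer):
        # Only here is the full set of trips known, so sort and process sequentially.
        ordered = sorted(buffer, key=lambda row: row["tripStartDateTime"])
        results, previous = [], None
        for row in ordered:
            results.append((row["correlationId"], previous))  # hypothetical per-row result
            previous = row["correlationId"]
        return results


def aggregate_in_chunks(agg, rows, chunk_size):
    """Simulate one group's rows being reduced in separate chunks, then merged."""
    chunks = [rows[i:i + chunk_size] for i in range(0, len(rows), chunk_size)]
    partials = [fold(agg.reduce, chunk, agg.zero()) for chunk in chunks]
    return agg.finish(fold(agg.merge, partials, agg.zero()))
```

Whatever the chunk size or input order, finish produces the same ordered result, which is the property the sequential per-vehicle calculation depends on. The cost is that the whole buffer for one vehicle is held in memory.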

