PySpark: finding the maximum gap between two values in a Spark array column

qzwqbdag · posted 2023-10-15 in Spark
Follow (0) | Answers (2) | Views (100)

I have a Spark DataFrame with the following schema:

root
 |-- passengerId: string (nullable = true)
 |-- travelHist: array (nullable = true)
 |    |-- element: integer (containsNull = true)

I want to walk the array elements and find the maximum number of 0 values occurring between a 1 and a 2.
| passengerId | travelHist |
| --|--|
| 1 |1,0,0,0,0,2,1,0,0,0,0,0,0,0,2,1,0|
| 2 |0,0,0,0,0,0,0,0,2,1,0,0,0,0,0,0,0,0|
| 3 |0,0,0,2,1,0,2,1|
The output for the records above should be:
| passengerId | maxStreak |
| --|--|
| 1 | 7 |
| 2 | 3 |
| 3 | 1 |
Assuming the array never holds more than 50 elements, what is the most efficient way to find such a streak?

yqkkidmi1#

Let's do some pattern matching: join the array into a string, pull out every run of 0s that sits between a 1 and a 2 with a regex, and take the length of the longest run.

import pyspark.sql.functions as F

# df is the input DataFrame from the question
df1 = (
    df
    # join the int array into a single string, e.g. "10000210000000210"
    .withColumn('matches', F.expr("array_join(travelHist, '')"))
    # capture every run of 0s enclosed by a 1 on the left and a 2 on the right
    .withColumn('matches', F.expr("regexp_extract_all(matches, '1(0+)2', 1)"))
    # map each captured run to its length
    .withColumn('matches', F.expr("transform(matches, x -> length(x))"))
    .withColumn('maxStreak', F.expr("array_max(matches)"))
)
df1.show()
+-----------+--------------------+-------+---------+
|passengerID|          travelHist|matches|maxStreak|
+-----------+--------------------+-------+---------+
|          1|[1, 0, 0, 0, 0, 2...| [4, 7]|        7|
|          2|[0, 0, 0, 0, 0, 0...|    [3]|        3|
|          3|[0, 0, 0, 2, 1, 0...|    [1]|        1|
+-----------+--------------------+-------+---------+
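
One detail worth noting: if a row contains no complete 1 ... 2 window at all, regexp_extract_all returns an empty array and array_max yields NULL. Here is a minimal null-safe variant of the same idea, assuming the df above (my addition, not part of the original answer):

import pyspark.sql.functions as F

# Null-safe variant: when no 1...2 window exists, regexp_extract_all yields
# an empty array and array_max returns NULL, so coalesce the result to 0.
df2 = (
    df
    .withColumn('joined', F.expr("array_join(travelHist, '')"))
    .withColumn('runs', F.expr("regexp_extract_all(joined, '1(0+)2', 1)"))
    .withColumn('maxStreak', F.expr("coalesce(array_max(transform(runs, x -> length(x))), 0)"))
    .drop('joined', 'runs')
)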

sz81bmfz2#

Here is a solution using a Scala UDF from PySpark. The code for the UDF and the release jar used in the PySpark script below can be found in this repository:
https://github.com/dineshdharme/pyspark-native-udfs
The Scala UDF code is as follows:

package com.help.udf

import org.apache.spark.sql.api.java.UDF1

class CountZeros extends UDF1[Array[Int], Int] {

  override def call(given_array: Array[Int]): Int = {
    // Longest run of 0s seen so far between a 1 and a 2; -1 if none found.
    var maxCount = -1
    // Length of the current run of 0s.
    var runningCount = -1
    // True while we are inside a window opened by a 1 and not yet closed by a 2.
    var insideLoop = false

    for (ele <- given_array) {
      if (ele == 1) {
        // A 1 (re)opens a window: start counting 0s from scratch.
        runningCount = 0
        insideLoop = true
      }
      if (ele == 0 && insideLoop) {
        runningCount += 1
      }
      if (ele == 2 && insideLoop) {
        // A 2 closes the window: keep the run if it is the longest so far.
        insideLoop = false
        if (runningCount > maxCount) {
          maxCount = runningCount
        }
      }
    }

    maxCount
  }
}
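
If shipping a jar is more setup than you want, the same counting logic can also be written as a plain Python UDF. This is a sketch of mine, not code from the linked repository, and it will be slower than the native Scala version because each row crosses the JVM/Python boundary:

import pyspark.sql.functions as F
from pyspark.sql.types import IntegerType

@F.udf(returnType=IntegerType())
def count_zeros_py(arr):
    # Same logic as the Scala UDF above: count 0s inside a window that a 1
    # opens and a 2 closes; return -1 when no window is ever completed.
    max_count, running, inside = -1, 0, False
    for ele in arr:
        if ele == 1:
            running, inside = 0, True
        elif ele == 0 and inside:
            running += 1
        elif ele == 2 and inside:
            inside = False
            max_count = max(max_count, running)
    return max_count

# Usage: df1.withColumn("maxStreak", count_zeros_py("travelHist"))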

Here is the PySpark code that uses the Scala UDF:

import pyspark.sql.functions as F
from pyspark.sql import SparkSession
from pyspark.sql.types import IntegerType

spark = SparkSession.builder \
    .appName("MyApp") \
    .config("spark.jars", "file:/path/to/pyspark-native-udfs/releases/pyspark-native-udfs-assembly-0.1.2.jar") \
    .getOrCreate()

data1 = [
    [1, [1, 0, 0, 0, 0, 2, 1, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0]],
    [2, [0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0, 0, 2, 0, 0, 0, 0]],
    [3, [0, 0, 0, 2, 1, 0, 2, 1, 0]],
]

df1Columns = ["passengerID", "travelHist"]
df1 = spark.createDataFrame(data=data1, schema=df1Columns)
# The UDF expects array<int>, so cast the array elements explicitly.
df1 = df1.withColumn("travelHist", F.col("travelHist").cast("array<int>"))

df1.show(n=100, truncate=False)
df1.printSchema()

# Register the Scala UDF shipped in the jar so it can be called from SQL.
spark.udf.registerJavaFunction("count_zeros_udf", "com.help.udf.CountZeros", IntegerType())

df1.createOrReplaceTempView("given_table")

df1_array = spark.sql("select *, count_zeros_udf(travelHist) as maxStreak from given_table")
print("Dataframe after applying SCALA NATIVE UDF")
df1_array.show(n=100, truncate=False)

Output:

+-----------+------------------------------------------------------+
|passengerID|travelHist                                            |
+-----------+------------------------------------------------------+
|1          |[1, 0, 0, 0, 0, 2, 1, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0]   |
|2          |[0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0, 0, 2, 0, 0, 0, 0]|
|3          |[0, 0, 0, 2, 1, 0, 2, 1, 0]                           |
+-----------+------------------------------------------------------+

root
 |-- passengerID: long (nullable = true)
 |-- travelHist: array (nullable = true)
 |    |-- element: integer (containsNull = true)

Dataframe after applying SCALA NATIVE UDF
+-----------+------------------------------------------------------+---------+
|passengerID|travelHist                                            |maxStreak|
+-----------+------------------------------------------------------+---------+
|1          |[1, 0, 0, 0, 0, 2, 1, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0]   |7        |
|2          |[0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0, 0, 2, 0, 0, 0, 0]|3        |
|3          |[0, 0, 0, 2, 1, 0, 2, 1, 0]                           |1        |
+-----------+------------------------------------------------------+---------+
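
A small behavioural difference between the two answers: for a row with no complete 1 ... 2 window, the regex approach in the first answer returns NULL (array_max over an empty array), while this UDF returns -1. A quick way to check, reusing the registered function on a made-up row:

# Hypothetical row with a window that is opened by a 1 but never closed by a 2.
edge = spark.createDataFrame([[4, [0, 1, 0, 0]]], ["passengerID", "travelHist"])
edge = edge.withColumn("travelHist", F.col("travelHist").cast("array<int>"))
edge.createOrReplaceTempView("edge_table")
# Expect maxStreak = -1 here; the regex approach would produce NULL instead.
spark.sql("select *, count_zeros_udf(travelHist) as maxStreak from edge_table").show()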
