python—计算pyspark df列中子字符串列表的出现次数

bz4sfanl  于 2021-06-24  发布在  Hive
关注(0)|答案(2)|浏览(586)

我想计算子字符串列表的出现次数,并基于pyspark df中包含长字符串的列创建一个列。

Input:          
       ID    History

       1     USA|UK|IND|DEN|MAL|SWE|AUS
       2     USA|UK|PAK|NOR
       3     NOR|NZE
       4     IND|PAK|NOR

 lst=['USA','IND','DEN']

Output :
       ID    History                      Count

       1     USA|UK|IND|DEN|MAL|SWE|AUS    3
       2     USA|UK|PAK|NOR                1
       3     NOR|NZE                       0
       4     IND|PAK|NOR                   1
hiz5n14c

hiz5n14c1#


# Importing requisite packages and creating a DataFrame

from pyspark.sql.functions import split, col, size, regexp_replace
values = [(1,'USA|UK|IND|DEN|MAL|SWE|AUS'),(2,'USA|UK|PAK|NOR'),(3,'NOR|NZE'),(4,'IND|PAK|NOR')]
df = sqlContext.createDataFrame(values,['ID','History'])
df.show(truncate=False)
+---+--------------------------+
|ID |History                   |
+---+--------------------------+
|1  |USA|UK|IND|DEN|MAL|SWE|AUS|
|2  |USA|UK|PAK|NOR            |
|3  |NOR|NZE                   |
|4  |IND|PAK|NOR               |
+---+--------------------------+

我们的想法是根据这三条线来分割线 delimiters : lst=['USA','IND','DEN'] 然后计算生成的子字符串数。
例如:;绳子 USA|UK|IND|DEN|MAL|SWE|AUS 分裂成- , , |UK| , | , |MAL|SWE|AUS . 因为,创建了4个子字符串,有3个分隔符匹配,所以 4-1 = 3 给出列字符串中出现的这些字符串的计数。
我不确定spark中是否支持多字符分隔符,因此作为第一步,我们将替换列表中这3个子字符串中的任何一个 ['USA','IND','DEN'] 使用标志/伪值 % . 你也可以用别的东西。下面的代码执行此操作 replacement -

df = df.withColumn('History_X',col('History'))
lst=['USA','IND','DEN']
for i in lst:
    df = df.withColumn('History_X', regexp_replace(col('History_X'), i, '%'))
df.show(truncate=False)
+---+--------------------------+--------------------+
|ID |History                   |History_X           |
+---+--------------------------+--------------------+
|1  |USA|UK|IND|DEN|MAL|SWE|AUS|%|UK|%|%|MAL|SWE|AUS|
|2  |USA|UK|PAK|NOR            |%|UK|PAK|NOR        |
|3  |NOR|NZE                   |NOR|NZE             |
|4  |IND|PAK|NOR               |%|PAK|NOR           |
+---+--------------------------+--------------------+

最后,我们计算 splitting 它首先与 % 作为分隔符,然后计算用 size 函数,最后从中减去1。

df = df.withColumn('Count', size(split(col('History_X'), "%")) - 1).drop('History_X')
df.show(truncate=False)
+---+--------------------------+-----+
|ID |History                   |Count|
+---+--------------------------+-----+
|1  |USA|UK|IND|DEN|MAL|SWE|AUS|3    |
|2  |USA|UK|PAK|NOR            |1    |
|3  |NOR|NZE                   |0    |
|4  |IND|PAK|NOR               |1    |
+---+--------------------------+-----+
3zwjbxry

3zwjbxry2#

如果您使用的是spark 2.4+,则可以尝试使用spark sql高阶函数 filter() :

from pyspark.sql import functions as F

>>> df.show(5,0)
+---+--------------------------+
|ID |History                   |
+---+--------------------------+
|1  |USA|UK|IND|DEN|MAL|SWE|AUS|
|2  |USA|UK|PAK|NOR            |
|3  |NOR|NZE                   |
|4  |IND|PAK|NOR               |
+---+--------------------------+

df_new = df.withColumn('data', F.split('History', '\|')) \
           .withColumn('cnt', F.expr('size(filter(data, x -> x in ("USA", "IND", "DEN")))'))

>>> df_new.show(5,0)
+---+--------------------------+----------------------------------+---+
|ID |History                   |data                              |cnt|
+---+--------------------------+----------------------------------+---+
|1  |USA|UK|IND|DEN|MAL|SWE|AUS|[USA, UK, IND, DEN, MAL, SWE, AUS]|3  |
|2  |USA|UK|PAK|NOR            |[USA, UK, PAK, NOR]               |1  |
|3  |NOR|NZE                   |[NOR, NZE]                        |0  |
|4  |IND|PAK|NOR               |[IND, PAK, NOR]                   |1  |
+---+--------------------------+----------------------------------+---+

我们第一次分田的地方 History 一个名为 data 然后使用filter函数:

filter(data, x -> x in ("USA", "IND", "DEN"))

仅检索满足条件的数组元素: IN ("USA", "IND", "DEN") ,之后,我们用 size() 功能。
更新:添加了另一种使用array_contains()的方法,该方法应适用于旧版本spark:

lst = ["USA", "IND", "DEN"]

df_new = df.withColumn('data', F.split('History', '\|')) \
           .withColumn('Count', sum([F.when(F.array_contains('data',e),1).otherwise(0) for e in lst]))

注意:数组中的重复条目将被跳过,此方法仅计算唯一的国家/地区代码。

相关问题