我有一个sparkDataframe,看起来像这样:
+---+-----------+-------------------------+---------------+
| id| Phase | Switch | InputFileName |
+---+-----------+-------------------------+---------------+
| 1| 2| 1| fileA|
| 2| 2| 1| fileA|
| 3| 2| 1| fileA|
| 4| 2| 0| fileA|
| 5| 2| 0| fileA|
| 6| 2| 1| fileA|
| 11| 2| 1| fileB|
| 12| 2| 1| fileB|
| 13| 2| 0| fileB|
| 14| 2| 0| fileB|
| 15| 2| 1| fileB|
| 16| 2| 1| fileB|
| 21| 4| 1| fileB|
| 22| 4| 1| fileB|
| 23| 4| 1| fileB|
| 24| 4| 1| fileB|
| 25| 4| 1| fileB|
| 26| 4| 0| fileB|
| 31| 1| 0| fileC|
| 32| 1| 0| fileC|
| 33| 1| 0| fileC|
| 34| 1| 0| fileC|
| 35| 1| 0| fileC|
| 36| 1| 0| fileC|
+---+-----------+-------------------------+---------------+
对于每个组(一组 InputFileName
以及 Phase
)我需要运行一个验证函数来检查 Switch
在组的开始和结束处等于1,并在这两者之间的任何一点处转换为0。函数应将验证结果添加为新列。预期结果如下:(差距只是为了突出不同的群体)
+---+-----------+-------------------------+---------------+--------+
| id| Phase | Switch | InputFileName | Valid |
+---+-----------+-------------------------+---------------+--------+
| 1| 2| 1| fileA| true |
| 2| 2| 1| fileA| true |
| 3| 2| 1| fileA| true |
| 4| 2| 0| fileA| true |
| 5| 2| 0| fileA| true |
| 6| 2| 1| fileA| true |
| 11| 2| 1| fileB| true |
| 12| 2| 1| fileB| true |
| 13| 2| 0| fileB| true |
| 14| 2| 0| fileB| true |
| 15| 2| 1| fileB| true |
| 16| 2| 1| fileB| true |
| 21| 4| 1| fileB| false|
| 22| 4| 1| fileB| false|
| 23| 4| 1| fileB| false|
| 24| 4| 1| fileB| false|
| 25| 4| 1| fileB| false|
| 26| 4| 0| fileB| false|
| 31| 1| 0| fileC| false|
| 32| 1| 0| fileC| false|
| 33| 1| 0| fileC| false|
| 34| 1| 0| fileC| false|
| 35| 1| 0| fileC| false|
| 36| 1| 0| fileC| false|
+---+-----------+-------------------------+---------------+--------+
我以前用pyspark和pandas udf解决了这个问题:
df = df.groupBy("InputFileName", "Phase").apply(validate_profile)
@pandas_udf(schema, PandasUDFType.GROUPED_MAP)
def validate_profile(df: pd.DataFrame):
first_valid = True if df["Switch"].iloc[0] == 1 else False
during_valid = (df["Switch"].iloc[1:-1] == 0).any()
last_valid = True if df["Switch"].iloc[-1] == 1 else False
df["Valid"] = first_valid & during_valid & last_valid
return df
但是,现在我需要在scala中重写它。我只想知道最好的办法。
我正在尝试使用窗口函数获取每个组的第一个和最后一个ID:
val minIdWindow = Window.partitionBy("InputFileName", "Phase").orderBy("id")
val maxIdWindow = Window.partitionBy("InputFileName", "Phase").orderBy(col("id").desc)
然后我可以将min和max id添加为单独的列并使用 when
获取 Switch
:
df.withColumn("MinId", min("id").over(minIdWindow))
.withColumn("MaxId", max("id").over(maxIdWindow))
.withColumn("Valid", when(
col("id") === col("MinId"), col("Switch")
).when(
col("id") === col("MaxId"), col("Switch")
))
这会得到起始值和结束值,但我不确定如何检查 Switch
中间等于0。我使用窗口函数的方法正确吗?或者你会推荐一个替代方案吗?
1条答案
按热度按时间du7egjpx1#
试试这个,
输出: