将PySpark和DBSCAN与panda_udf结合使用

xytpbqjk  于 2022-12-13  发布在  Apache
关注(0)|答案(2)|浏览(190)

我正在阅读panda_udf的文档:分组Map
我很好奇如何将sklearn DBSCAN添加到其中,例如,我有一个数据集:

data = [(1, 11.6133, 48.1075),
         (1, 11.6142, 48.1066),
         (1, 11.6108, 48.1061),
         (1, 11.6207, 48.1192),
         (1, 11.6221, 48.1223),
         (1, 11.5969, 48.1276),
         (2, 11.5995, 48.1258),
         (2, 11.6127, 48.1066),
         (2, 11.6430, 48.1275),
         (2, 11.6368, 48.1278),
         (2, 11.5930, 48.1156)]

df = spark.createDataFrame(data, ["id", "X", "Y"])

我希望groupby id,并分别对每个id执行DBSCAN群集。

@pandas_udf("id long, X double, Y double", PandasUDFType.GROUPED_MAP)
def dbscan_udf(...):
    # pdf is a pandas.DataFrame
    v = ...
    return ...

df.groupby("id").apply(dbscan_udf).show()

我要查找的输出是具有cluster列的原始数据集,其中显示了具有相同id的彼此接近的点。
感谢您的帮助!

5vf7fwbs

5vf7fwbs1#

所以我自己设法做到了这一点:

from pyspark.sql.types import StructType, StructField, DoubleType, StringType, IntegerType
from pyspark.sql.functions import *
from sklearn.cluster import DBSCAN
import pandas as pd

data = [(1, 11.6133, 48.1075),
         (1, 11.6142, 48.1066),
         (1, 11.6108, 48.1061),
         (1, 11.6207, 48.1192),
         (1, 11.6221, 48.1223),
         (1, 11.5969, 48.1276),
         (2, 11.5995, 48.1258),
         (2, 11.6127, 48.1066),
         (2, 11.6430, 48.1275),
         (2, 11.6368, 48.1278),
         (2, 11.5930, 48.1156)]

df = spark.createDataFrame(data, ["id", "X", "Y"])

output_schema = StructType(
            [
                StructField('id', IntegerType()),
                StructField('X', DoubleType()),
                StructField('Y', DoubleType()),
                StructField('cluster', IntegerType())
             ]
    )

@pandas_udf(output_schema, PandasUDFType.GROUPED_MAP)
def dbscan_pandas_udf(data):
    data["cluster"] = DBSCAN(eps=5, min_samples=3).fit_predict(data[["X", "Y"]])
    result = pd.DataFrame(data, columns=["id", "X", "Y", "cluster"])
    return result

df.groupby("id").apply(dbscan_pandas_udf).show()

希望它能在未来的某个人身上有所帮助。

f45qwnt8

f45qwnt82#

我相信是这样的。

# Sum
df.groupBy('id').sum().show()

或者,如果你的Spark版本太旧,试试这个。

(df
    .groupBy("id")
    .agg(func.col("id"), func.sum("order_item"))
    .show())

有关DBSCAN的一些信息,请参见下面的链接。
https://github.com/alitouka/spark_dbscan

相关问题