How to normalize and create a similarity matrix in PySpark?

xytpbqjk posted on 2021-07-13 in Spark

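I have seen many Stack Overflow questions about similarity matrices, but they deal with RDDs or other setups, and I could not find a direct answer to my problem, so I decided to post a new question.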

Problem

import numpy as np
import pandas as pd
import pyspark
from pyspark.sql import functions as F, Window
from pyspark import SparkConf, SparkContext, SQLContext
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.feature import StandardScaler,Normalizer
from pyspark.mllib.linalg.distributed import IndexedRow, IndexedRowMatrix

spark = pyspark.sql.SparkSession.builder.appName('app').getOrCreate()
sc = spark.sparkContext
sqlContext = SQLContext(sc)

# pandas dataframe

pdf = pd.DataFrame({'user_id': ['user_0','user_1','user_2'],
                   'apple': [0,1,5],
                   'good banana': [3,0,1],
                   'carrot': [1,2,2]})

# spark dataframe

df = sqlContext.createDataFrame(pdf)
df.show()

+-------+-----+-----------+------+
|user_id|apple|good banana|carrot|
+-------+-----+-----------+------+
| user_0|    0|          3|     1|
| user_1|    1|          0|     2|
| user_2|    5|          1|     2|
+-------+-----+-----------+------+

Normalizing and creating the similarity matrix with pandas

from sklearn.preprocessing import normalize

pdf = pdf.set_index('user_id')
item_norm = normalize(pdf,axis=0) # normalize each items (NOT users)
item_sim = item_norm.T.dot(item_norm)
df_item_sim = pd.DataFrame(item_sim,index=pdf.columns,columns=pdf.columns)

                apple  good banana    carrot
apple        1.000000     0.310087  0.784465
good banana  0.310087     1.000000  0.527046
carrot       0.784465     0.527046  1.000000
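The pandas result is the cosine similarity between item columns: after normalize(pdf, axis=0) every item column has unit L2 length, so entry (i, j) of item_norm.T.dot(item_norm) equals a_i·a_j / (‖a_i‖ ‖a_j‖). A minimal numpy sketch that computes the same matrix straight from that formula, using only the three-user example data:

```python
import numpy as np

# Rows are users (user_0..user_2), columns are items (apple, good banana, carrot)
X = np.array([[0, 3, 1],
              [1, 0, 2],
              [5, 1, 2]], dtype=float)

# L2 norm of each item column across all users
col_norms = np.linalg.norm(X, axis=0)

# Cosine similarity: column dot products divided by the product of the column norms
item_sim = (X.T @ X) / np.outer(col_norms, col_norms)

print(np.round(item_sim, 6))
```

The printed values should match the pandas table above, e.g. apple vs good banana is 5 / (sqrt(26) * sqrt(10)) ≈ 0.310087.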

Question: how can I get the same similarity matrix using PySpark?

I then want to run KMeans on this data.

from pyspark.ml.feature import VectorAssembler
from pyspark.ml.clustering import KMeans

# I want to do this...

model = KMeans(k=2, seed=1, featuresCol='norm_features').fit(df.select('norm_features'))  # KMeans reads 'features' unless featuresCol is set

df = model.transform(df)
df.show()

References
Cosine similarity of two PySpark DataFrames
apache-spark-python cosine similarity over DataFrames

Answer #1 by hivapdat

import pyspark.sql.functions as F

df.show()
+-------+-----+-----------+------+
|user_id|apple|good banana|carrot|
+-------+-----+-----------+------+
| user_0|    0|          3|     1|
| user_1|    1|          0|     2|
| user_2|    5|          1|     2|
+-------+-----+-----------+------+

Swap rows and columns by unpivoting (with stack) and then pivoting:

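# Unpivot with stack(): one (fruit, items) row per user/fruit pair,
# then pivot on user_id so that each fruit becomes a row
df2 = df.selectExpr(
    'user_id',
    'stack(%d, ' % len(df.columns[1:]) + ', '.join(["'%s', `%s`" % (c, c) for c in df.columns[1:]]) + ') as (fruit, items)'
).groupBy('fruit').pivot('user_id').agg(F.first('items'))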

df2.show()
+-----------+------+------+------+
|      fruit|user_0|user_1|user_2|
+-----------+------+------+------+
|      apple|     0|     1|     5|
|good banana|     3|     0|     1|
|     carrot|     1|     2|     2|
+-----------+------+------+------+
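The selectExpr string is easiest to understand once it is generated. A small sketch that builds it for the three item columns, with the row count passed to stack() derived from the column list instead of hard-coded (item_cols is a stand-in for df.columns[1:]). Each pair of a quoted literal and a backticked column becomes one output row: fruit gets the literal name and items gets that column's value. The backticks let the SQL parser accept the space in good banana.

```python
# Stand-in for df.columns[1:] from the example DataFrame
item_cols = ['apple', 'good banana', 'carrot']

# Each item contributes a string literal (the new `fruit` value) and a
# backticked column reference (the new `items` value); backticks allow spaces
pairs = ', '.join("'%s', `%s`" % (c, c) for c in item_cols)

# stack(n, ...) emits n rows per input row, so n must equal the number of pairs
stack_expr = 'stack(%d, %s) as (fruit, items)' % (len(item_cols), pairs)

print(stack_expr)
# stack(3, 'apple', `apple`, 'good banana', `good banana`, 'carrot', `carrot`) as (fruit, items)
```

On Spark 3.4 and later, DataFrame.unpivot (alias melt) performs the same unpivot step without a hand-built SQL string.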

Normalize each fruit row to unit L2 length:

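# Divide each value by the L2 norm of its fruit row; Python's built-in sum()
# adds up the squared Column expressions for all user columns
df3 = df2.select(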
    'fruit',
    *[
        (
            F.col(c) / 
            F.sqrt(
                sum([F.col(cc)*F.col(cc) for cc in df2.columns[1:]])
            )
        ).alias(c) for c in df2.columns[1:]
    ]
)

df3.show()
+-----------+------------------+-------------------+-------------------+
|      fruit|            user_0|             user_1|             user_2|
+-----------+------------------+-------------------+-------------------+
|      apple|               0.0|0.19611613513818404| 0.9805806756909202|
|good banana|0.9486832980505138|                0.0|0.31622776601683794|
|     carrot|0.3333333333333333| 0.6666666666666666| 0.6666666666666666|
+-----------+------------------+-------------------+-------------------+
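Each value in df3 is the original entry divided by the L2 norm of its fruit row in the transposed matrix. A quick numpy cross-check of that arithmetic, using the values from df2:

```python
import numpy as np

# Rows of df2: apple, good banana, carrot; columns: user_0, user_1, user_2
M = np.array([[0, 1, 5],
              [3, 0, 1],
              [1, 2, 2]], dtype=float)

# Divide each row by its own L2 norm (square root of the sum of squares)
row_norms = np.sqrt((M ** 2).sum(axis=1, keepdims=True))
M_norm = M / row_norms

print(M_norm)
```

For example, the apple row is [0, 1, 5] / sqrt(26), giving 0.196116 and 0.980581 as in df3. The product M_norm @ M_norm.T is then the item-by-item cosine similarity matrix, which is what the next step computes one row pair at a time.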

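Multiply the normalized matrix by its transpose: cross-join the rows with themselves and take the dot product of each pair: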

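# Pair every normalized fruit row with every other one (cross join), take the
# dot product over the user columns, and pivot the t2 fruits into columns
df4 = (df3.alias('t1').repartition(10)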
          .crossJoin(df3.alias('t2').repartition(10))
          .groupBy('t1.fruit')
          .pivot('t2.fruit', df.columns[1:])
          .agg(F.first(sum([F.col('t1.'+c) * F.col('t2.'+c) for c in df3.columns[1:]])))
      )
df4.show()
+-----------+-------------------+-------------------+------------------+
|      fruit|              apple|        good banana|            carrot|
+-----------+-------------------+-------------------+------------------+
|      apple| 1.0000000000000002|0.31008683647302115|0.7844645405527362|
|good banana|0.31008683647302115| 0.9999999999999999|0.5270462766947298|
|     carrot| 0.7844645405527362| 0.5270462766947298|               1.0|
+-----------+-------------------+-------------------+------------------+
