Pyspark中的点积与MLLIB

i1icjdpr  于 12个月前  发布在  Spark
关注(0)|答案(5)|浏览(116)

我在pyspark中有一个非常简单的框架,类似于这样:

from pyspark.sql import Row
from pyspark.mllib.linalg import DenseVector

row = Row("a", "b")
df = spark.sparkContext.parallelize([
    offer_row(DenseVector([1, 1, 1]), DenseVector([1, 0, 0])),
]).toDF()

字符串
我想计算这些向量的点积,而不需要调用UDF。
spark MLLIB documentation引用了DenseVectors上的dot方法,但如果我尝试按如下方式应用它:

df_offers = df_offers.withColumn("c", col("a").dot(col("b")))


我得到的错误如下:

TypeError: 'Column' object is not callable


有谁知道这些mllib方法是否可以在DataFrame对象上调用?

zf9nrax1

zf9nrax11#

在这里,您将dot方法应用于列而不是DenseVector,这确实不起作用:

df_offers = df_offers.withColumn("c", col("a").dot(col("b")))

字符串
你必须使用一个udf:

from pyspark.sql.functions import udf, array
from pyspark.sql.types import DoubleType

def dot_fun(array):
    return array[0].dot(array[1])

dot_udf = udf(dot_fun, DoubleType())

df_offers = df_offers.withColumn("c", dot_udf(array('a', 'b')))

9rnv2umw

9rnv2umw2#

没有。你必须使用udf:

from pyspark.sql.functions import udf

@udf("double")
def dot(x, y):
    if x is not None and y is not None:
        return float(x.dot(y))

字符串

2vuwiymt

2vuwiymt3#

您可以在不使用UDF的情况下将两列相乘,方法是先将它们转换为BlockMatrix,然后像下面的示例那样将它们相乘

from pyspark.mllib.linalg.distributed import IndexedRow, IndexedRowMatrix

ac = offer_row.select('a')
bc = offer_row.select('a')
mata = IndexedRowMatrix(ac.rdd.map(lambda row: IndexedRow(*row)))
matb = IndexedRowMatrix(bc.rdd.map(lambda row: IndexedRow(*row)))

ma = mata.toBlockMatrix(100,100)
mb = matb.toBlockMatrix(100,100)

ans = ma.multiply(mb.transpose())

字符串

zzlelutf

zzlelutf4#

这是一个hack,但可能比Python udf性能更好。你可以把点积转换成SQL:

import pandas as pd
from pyspark.sql.functions import expr

coefs = pd.Series({'a': 1.0, 'b': 2.0})
dot_sql = ' + '.join(
    '{} * {}'.format(coef, colname)
    for colname, coef
    in coefs.items()
)
dot_expr = expr(dot_sql)

df.withColumn('dot_product', dot_expr)

字符串

hc2pp10m

hc2pp10m5#

作为对第一个答案的评论,我现在得到的是:AttributeError: 'list' object has no attribute 'dot'
即使我调用np.dot(a, b),也总是有一个类型错误。类似于:
Job aborted due to stage failure: Task 0 in stage 225.0 failed 4 times, most recent failure: ... : net.razorvine.pickle.PickleException: expected zero arguments for construction of ClassDict (for numpy.dtype)
我不得不将返回值更改为item(),这样它就不再在numpy中了。工作解决方案:

@F.udf(returnType=T.FloatType())
def dot_udf(arr1, arr2):
    if arr1 is not None and arr2 is not None:
        return np.dot(arr1, arr2).astype(np.float32).item()

df_ = df.withColumn("c", dot_udf(F.col('a'), F.col('b')))

字符串

相关问题