如何在pysparkml中找到向量的argmax

1cosmwyk  于 2021-05-27  发布在  Spark
关注(0)|答案(2)|浏览(600)

我的模型输出了一个densevector列,我想找到argmax。这个页面建议这个函数应该是可用的,但是我不确定语法应该是什么。
它是 df.select("mycolumn").argmax() ?

uqdfh47h

uqdfh47h1#

我在python中找不到argmax操作的文档。但是你可以把它们转换成数组
对于pyspark 3.0.0

  1. from pyspark.ml.functions import vector_to_array
  2. tst_arr = tst_df.withColumn("arr",vector_to_array(F.col('vector_column')))
  3. tst_max=tst_arr.withColumn("max_value",F.array_max("arr"))
  4. tst_max_exp = tst_max.select('*',F.posexplode("arr"))
  5. tst_fin = tst_max_exp.where('col==max_value')

对于pyspark<3.0.0

  1. from pyspark.sql.functions import udf
  2. @udf
  3. def vect_argmax(row):
  4. row_arr = row.toArray()
  5. max_pos = np.argmax(row_arr)
  6. return(int(max_pos))
  7. tst_fin = tst_df.withColumn("argmax",vect_argmax(F.col('probability')))
展开查看全部
s4chpxco

s4chpxco2#

你试过了吗

  1. from pyspark.sql.functions import col
  2. df.select(col("mycolumn").argmax())

相关问题