在pyspark中从VectorUDT稀疏向量中提取“值”

smtd7mpg  于 2024-01-06  发布在  Spark
关注(0)|答案(2)|浏览(229)

我有一个带有2个向量列的pyspark嵌套框架。当我在笔记本中显示嵌套框架时,它会像这样打印每个向量:{“vectorType”:“sparse”,“length”:262144,“indices”:[21641],“values”:[1]}
当我打印模式时,它显示为VectorUDT。
我只需要“values”字段值作为列表或数组。我如何保存它作为一个新字段?执行“vector_field”.values似乎不起作用,因为pyspark认为它是一个String.

vsdwdz23

vsdwdz231#

spark有一个内置的ml函数用于向量到数组的转换-vector_to_array。你可以简单地传递向量列来获得与1D数组相同的结果。
这里有一个例子

  1. from pyspark.ml.linalg import SparseVector, DenseVector
  2. import pyspark.ml.functions as mfunc
  3. data_ls = [
  4. (SparseVector(3, [(0, 1.0), (2, 2.0)]),),
  5. (DenseVector([3.0, 0.0, 1.0]),),
  6. (SparseVector(3, [(1, 4.0)]),)
  7. ]
  8. spark.createDataFrame(data_ls, ['vec']). \
  9. withColumn('arr', mfunc.vector_to_array('vec')). \
  10. show(truncate=False)
  11. # +-------------------+---------------+
  12. # |vec |arr |
  13. # +-------------------+---------------+
  14. # |(3,[0,2],[1.0,2.0])|[1.0, 0.0, 2.0]|
  15. # |[3.0,0.0,1.0] |[3.0, 0.0, 1.0]|
  16. # |(3,[1],[4.0]) |[0.0, 4.0, 0.0]|
  17. # +-------------------+---------------+
  18. # root
  19. # |-- vec: vector (nullable = true)
  20. # |-- arr: array (nullable = false)
  21. # | |-- element: double (containsNull = false)

字符串

展开查看全部
kqhtkvqz

kqhtkvqz2#

我尝试使用以下向量值:

  1. smpl_data = [(SparseVector(3, {0: 1.0, 2: 2.0}),),
  2. (DenseVector([3.0, 0.0, 1.0]),),
  3. (SparseVector(3, {1: 4.0}),)]
  4. dilip_df = spark.createDataFrame(data, ["vector_field"])
  5. dilip_df.printSchema()
  6. dilip_df.show()

字符串


的数据
我定义了一个函数,这个函数以一个vector为输入,它会检查这个vector是SparseVector还是DenseVector,如果是SparseVector,它会使用vector.values.tolist()将值转换成一个列表,如果是DenseVector,它也会使用vector.values.tolist()将值转换成一个列表,如果这个vector既不是SparseVector也不是DenseVector,它返回None。使用.withColumn创建新列values_listPySpark的udf函数用于将extract_values函数注册为UDF。UDF被分配给变量extract_values_udf。udf函数的第二个参数指定UDF的返回类型,这表明UDF返回一个双精度值数组。

  1. def extract_values(vector):
  2. if isinstance(vector, SparseVector):
  3. return vector.values.tolist()
  4. elif isinstance(vector, DenseVector):
  5. return vector.values.tolist()
  6. else:
  7. return None
  8. extract_values_udf = udf(extract_values, ArrayType(DoubleType()))
  9. dilip_df = dilip_df.withColumn("values_list", extract_values_udf("vector_field"))
  10. dilip_df.show(truncate=False)
  1. +-------------------+---------------+
  2. |vector_field |values_list |
  3. +-------------------+---------------+
  4. |(3,[0,2],[1.0,2.0])|[1.0, 2.0] |
  5. |[3.0,0.0,1.0] |[3.0, 0.0, 1.0]|
  6. |(3,[1],[4.0]) |[4.0] |
  7. +-------------------+---------------+

的数据

展开查看全部

相关问题