Creating a DataFrame with SparseVector in PySpark

drkbr07n · published 2021-05-27 · in Spark

Suppose I have a Spark DataFrame like this:

    Row(Y=a, X1=3.2, X2=4.5)

What I want is:

    Row(Y=a, features=SparseVector(2, {X1: 3.2, X2: 4.5}))

w41d8nur 1#

Maybe this helps. It is written in Scala, but can be done in PySpark with only minor modifications.

VectorAssembler creates a vector column from the input columns:

    val df = spark.sql("select 'a' as Y, 3.2 as X1, 4.5 as X2")
    df.show(false)
    df.printSchema()
    /**
     * +---+---+---+
     * |Y  |X1 |X2 |
     * +---+---+---+
     * |a  |3.2|4.5|
     * +---+---+---+
     *
     * root
     *  |-- Y: string (nullable = false)
     *  |-- X1: decimal(2,1) (nullable = false)
     *  |-- X2: decimal(2,1) (nullable = false)
     */
    import org.apache.spark.ml.feature.VectorAssembler
    val features = new VectorAssembler()
      .setInputCols(Array("X1", "X2"))
      .setOutputCol("features")
      .transform(df)
    features.show(false)
    features.printSchema()
    /**
     * +---+---+---+---------+
     * |Y  |X1 |X2 |features |
     * +---+---+---+---------+
     * |a  |3.2|4.5|[3.2,4.5]|
     * +---+---+---+---------+
     *
     * root
     *  |-- Y: string (nullable = false)
     *  |-- X1: decimal(2,1) (nullable = false)
     *  |-- X2: decimal(2,1) (nullable = false)
     *  |-- features: vector (nullable = true)
     */
