Creating a custom Transformer in PySpark ML

bq9c1y66 · posted 2023-01-01 in Spark

I am new to Spark SQL DataFrames and ML on them (PySpark). How can I create a custom tokenizer that, for example, removes stop words and uses some libraries from nltk? Can I extend the default one?

8hhllhi2 1#

Can I extend the default one?

Not really. The default Tokenizer is a subclass of pyspark.ml.wrapper.JavaTransformer and, like the other transformers and estimators in pyspark.ml.feature, it delegates the actual processing to its Scala counterpart. Since you want to use Python, you should extend pyspark.ml.pipeline.Transformer directly:

import nltk

from pyspark import keyword_only  ## < 2.0 -> pyspark.ml.util.keyword_only
from pyspark.ml import Transformer
from pyspark.ml.param.shared import HasInputCol, HasOutputCol, Param, Params, TypeConverters
# Available in PySpark >= 2.3.0
from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable
from pyspark.sql.functions import udf
from pyspark.sql.types import ArrayType, StringType


class NLTKWordPunctTokenizer(
        Transformer, HasInputCol, HasOutputCol,
        # Credits https://stackoverflow.com/a/52467470
        # by https://stackoverflow.com/users/234944/benjamin-manns
        DefaultParamsReadable, DefaultParamsWritable):

    stopwords = Param(Params._dummy(), "stopwords", "stopwords",
                      typeConverter=TypeConverters.toListString)

    @keyword_only
    def __init__(self, inputCol=None, outputCol=None, stopwords=None):
        super(NLTKWordPunctTokenizer, self).__init__()
        self.stopwords = Param(self, "stopwords", "")
        self._setDefault(stopwords=[])
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    def setParams(self, inputCol=None, outputCol=None, stopwords=None):
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    def setStopwords(self, value):
        return self._set(stopwords=list(value))

    def getStopwords(self):
        return self.getOrDefault(self.stopwords)

    # Required in Spark >= 3.0
    def setInputCol(self, value):
        """
        Sets the value of :py:attr:`inputCol`.
        """
        return self._set(inputCol=value)

    # Required in Spark >= 3.0
    def setOutputCol(self, value):
        """
        Sets the value of :py:attr:`outputCol`.
        """
        return self._set(outputCol=value)

    def _transform(self, dataset):
        stopwords = set(self.getStopwords())

        def f(s):
            tokens = nltk.tokenize.wordpunct_tokenize(s)
            return [t for t in tokens if t.lower() not in stopwords]

        t = ArrayType(StringType())
        out_col = self.getOutputCol()
        in_col = dataset[self.getInputCol()]
        return dataset.withColumn(out_col, udf(f, t)(in_col))
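Note that the usage example below relies on NLTK's English stopword list being available on the driver (it is read once when the transformer is constructed, and wordpunct_tokenize itself is purely regex-based and needs no extra data). A minimal one-time setup, assuming the corpus has not been downloaded yet:

import nltk

# One-time setup: fetch the English stopword list used in the example below.
nltk.download("stopwords")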

Example usage (the data comes from the ML - Features docs):

sentenceDataFrame = spark.createDataFrame([
    (0, "Hi I heard about Spark"),
    (0, "I wish Java could use case classes"),
    (1, "Logistic regression models are neat")
], ["label", "sentence"])

tokenizer = NLTKWordPunctTokenizer(
    inputCol="sentence", outputCol="words",
    stopwords=nltk.corpus.stopwords.words('english'))

tokenizer.transform(sentenceDataFrame).show()
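Because the class mixes in DefaultParamsReadable and DefaultParamsWritable (PySpark >= 2.3.0), a configured instance can be saved and reloaded like any built-in stage. A minimal sketch, assuming a writable path such as /tmp/nltk-tokenizer (hypothetical) and that the class is importable wherever it is loaded:

# Persist only the Params (input/output column names, stopword list).
tokenizer.write().overwrite().save("/tmp/nltk-tokenizer")

# Restore the transformer later from the saved Params.
loaded = NLTKWordPunctTokenizer.load("/tmp/nltk-tokenizer")
loaded.transform(sentenceDataFrame).show()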

For a custom Python Estimator, see How to Roll a Custom Estimator in PySpark mllib.
This answer depends on internal APIs and is compatible with Spark 2.0.3, 2.1.1, 2.2.0 and later (SPARK-19348). For code compatible with earlier Spark versions, see revision 8.
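The transformer also composes with built-in stages inside a regular pyspark.ml.Pipeline. A short sketch (HashingTF is just an arbitrary downstream stage picked for illustration):

from pyspark.ml import Pipeline
from pyspark.ml.feature import HashingTF

# Custom NLTK tokenizer feeding a built-in feature transformer.
hashing_tf = HashingTF(inputCol="words", outputCol="features")
pipeline = Pipeline(stages=[tokenizer, hashing_tf])

# fit() is effectively a no-op here since both stages are pure transformers.
model = pipeline.fit(sentenceDataFrame)
model.transform(sentenceDataFrame).select("words", "features").show(truncate=False)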
