如何在 PySpark ML 管道(Pipeline)中只对部分列使用 StandardScaler?

yws3nbqq  于 2021-07-12  发布在  Spark
关注(0)|答案(1)|浏览(323)

在我的 DataFrame 中,有些列是连续值,而其他列只有 0/1 值。在使用 Pipeline 进行逻辑回归(logistic regression)之前,我想只对连续值的列使用 StandardScaler。这该如何实现?
我尝试了如下代码:

from pyspark.ml.feature import VectorAssembler,StandardScaler
from pyspark.ml import Pipeline,Transformer
from pyspark.sql.functions import udf,col
from pyspark.sql.types import FloatType, ArrayType
from pyspark.ml.util import DefaultParamsWritable, DefaultParamsReadable
from pyspark.ml.param.shared import HasInputCol, HasOutputCol, Param, Params, TypeConverters

class StandardScalerSubset(Transformer, DefaultParamsReadable, DefaultParamsWritable):
    """
    A custom Transformer that applies StandardScaler to a subset of columns.

    The columns listed in ``to_scale_cols`` are assembled into one vector,
    scaled with ``StandardScaler(withMean=True, withStd=True)`` and unpacked
    back into scalar columns carrying their original names.  The columns
    listed in ``remaining_cols`` are passed through unchanged.  The output
    DataFrame contains ``remaining_cols`` first, followed by the scaled
    columns; every other input column is dropped.

    NOTE(review): the scaler is fitted inside ``_transform``, so the scaling
    statistics are recomputed from whatever DataFrame is being transformed
    rather than learned once during ``Pipeline.fit``.  An Estimator/Model
    pair would be needed to reuse training statistics -- confirm whether
    that is required.

    NOTE(review): ``to_scale_cols`` and ``remaining_cols`` are plain Python
    attributes, not ``Param`` objects, and ``__init__`` has required
    arguments.  This is presumably why loading a saved pipeline fails --
    verify against the DefaultParamsReadable/DefaultParamsWritable docs.

    :param to_scale_cols: list of names of the continuous columns to scale.
    :param remaining_cols: list of names of the columns to keep as they are.
    """
    def __init__(self, to_scale_cols, remaining_cols):
        super(StandardScalerSubset, self).__init__()
        self.to_scale_cols = to_scale_cols  # continuous columns to be scaled
        self.remaining_cols = remaining_cols  # other columns

    def _transform(self, data):
        """
        Return ``data`` with the ``to_scale_cols`` columns standardized.

        :param data: input DataFrame containing every column named in
            ``to_scale_cols`` and ``remaining_cols``.
        :return: DataFrame with ``remaining_cols`` followed by the scaled
            versions of ``to_scale_cols``.
        """
        # StandardScaler works on a single vector column, so pack the
        # columns to scale into one vector first.
        va = VectorAssembler().setInputCols(self.to_scale_cols).setOutputCol("to_scale_vector")
        data_va = va.transform(data)

        scaler = StandardScaler(inputCol="to_scale_vector", outputCol="scaled_vector", withMean=True, withStd=True)
        scaler_model = scaler.fit(data_va)
        data_scaled = scaler_model.transform(data_va)

        # Convert the scaled vector into an array column so that individual
        # elements can be extracted with getItem().
        vector2list = udf(lambda x: x.toArray().tolist(), ArrayType(FloatType()))
        # Element i of the array corresponds to to_scale_cols[i]; give each
        # scaled value back its original column name.
        scaled_columns = [
            col("scaled_list").getItem(i).alias(c)
            for (i, c) in enumerate(self.to_scale_cols)
        ]
        data_res = data_scaled.withColumn("scaled_list", vector2list("scaled_vector")) \
            .select(self.remaining_cols + scaled_columns)
        return data_res

输入:


# +---+---+---+---+
# |  a|  b|  c|  d|
# +---+---+---+---+
# |  1|  5| 10|  0|
# |  0| 10| 20|  1|
# |  1| 15| 25|  0|
# |  0| 30| 40|  1|
# +---+---+---+---+

输出为:


# +---+---+--------+-----+
# |  a|  d|       b|    c|
# +---+---+--------+-----+
# |  1|  0| -0.9258| -1.1|
# |  0|  1| -0.4629| -0.3|
# |  1|  0|     0.0|  0.1|
# |  0|  1|  1.3887|  1.3|
# +---+---+--------+-----+

它可以这样使用:

# Example pipeline: scale the continuous columns, assemble all features
# into one vector, then train a logistic regression and save the fitted
# pipeline.  `trainData` and `modelSavePath` are not defined in this
# snippet and must be supplied by the caller.
from pyspark.ml.classification import LogisticRegression

scalerFeatures = ['xxx']  # continuous columns to standardize
featureArr = ['xxx']  # columns assembled into the "features" vector
remainingFeatures = ['xxx']  # columns passed through unscaled
# The constructor parameter is named `to_scale_cols`.
sss = StandardScalerSubset(to_scale_cols=scalerFeatures, remaining_cols=remainingFeatures)
vectorAssembler = VectorAssembler().setInputCols(featureArr).setOutputCol("features")
lrModel = LogisticRegression(featuresCol="features",regParam=0.1,maxIter=100,family="binomial")
pipeline = Pipeline().setStages([sss, vectorAssembler, lrModel])
pipeline.fit(trainData).write().overwrite().save(modelSavePath)

当我使用 PipelineModel.load(modelSavePath) 加载模型时,出现了错误。我认为应该同时实现 fit 和 transform 两个方法,但我不知道该怎么做。有人能帮我吗?谢谢。

z9zf31ra

z9zf31ra1#

内容太长,无法写在评论里,以下是您可以尝试的方法:

from pyspark.ml.feature import StandardScaler
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.classification import LogisticRegression
from pyspark.ml import Pipeline

# Scale only the continuous columns using built-in stages, so the fitted
# pipeline consists solely of standard Spark ML components.
# `df` is the input DataFrame with columns a, b, c, d.
scalerFeatures = ['b', 'c']  # continuous columns to standardize
remainingFeatures = ['a', 'd']  # 0/1 columns left unscaled
# Final feature order: unscaled columns first, then the scaled ones.
featureArr = remainingFeatures + [('scaled_' + f) for f in scalerFeatures]

# StandardScaler needs a vector input, so wrap each continuous column in
# its own one-element vector ("vec_<name>") and scale it to "scaled_<name>".
va1 = [VectorAssembler(inputCols=[f], outputCol=('vec_' + f)) for f in scalerFeatures]
ss = [StandardScaler(inputCol='vec_' + f, outputCol='scaled_' + f, withMean=True, withStd=True) for f in scalerFeatures]

# Combine the unscaled columns and the scaled vectors into "features".
va2 = VectorAssembler(inputCols=featureArr, outputCol="features")
lr = LogisticRegression()

stages = va1 + ss + [va2]

# I don't have a label column, but if you do, you can put lr stage at the end:
# stages = va1 + ss + [va2, lr]

p = Pipeline(stages=stages)
# truncate=False prints the full cell contents instead of cutting long
# vector values short.
p.fit(df).transform(df).show(truncate=False)
+---+---+---+---+------+------+---------------------+----------------------+--------------------------------------------------+
|a  |b  |c  |d  |vec_b |vec_c |scaled_b             |scaled_c              |features                                          |
+---+---+---+---+------+------+---------------------+----------------------+--------------------------------------------------+
|1  |5  |10 |0  |[5.0] |[10.0]|[-0.9258200997725514]|[-1.0999999999999999] |[1.0,0.0,-0.9258200997725514,-1.0999999999999999] |
|0  |10 |20 |1  |[10.0]|[20.0]|[-0.4629100498862757]|[-0.29999999999999993]|[0.0,1.0,-0.4629100498862757,-0.29999999999999993]|
|1  |15 |25 |0  |[15.0]|[25.0]|[0.0]                |[0.09999999999999998] |[1.0,0.0,0.0,0.09999999999999998]                 |
|0  |30 |40 |1  |[30.0]|[40.0]|[1.3887301496588271] |[1.2999999999999998]  |[0.0,1.0,1.3887301496588271,1.2999999999999998]   |
+---+---+---+---+------+------+---------------------+----------------------+--------------------------------------------------+

相关问题