python Sklearn Pipeline连接轴的所有输入数组维度必须完全匹配

myss37ts  于 2023-01-01  发布在  Python
关注(0)|答案(2)|浏览(161)
import pandas as pd
from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
from sklearn.pipeline import Pipeline
from sklearn.svm import LinearSVC
from sklearn.preprocessing import MinMaxScaler
from sklearn.compose import ColumnTransformer

data = [[1, 3, 4, 'text', 'pos'], [9, 3, 6, 'text more', 'neg']]
data = pd.DataFrame(data, columns=['Num1', 'Num2', 'Num3', 'Text field', 'Class'])

tweet_text_transformer = Pipeline(steps=[
    ('count_vectoriser', CountVectorizer()),
    ('tfidf', TfidfTransformer())
])

numeric_transformer = Pipeline(steps=[
    ('scaler', MinMaxScaler())
])

preprocessor = ColumnTransformer(transformers=[
    # (name, transformer, column(s))
    ('tweet', tweet_text_transformer, ['Text field']),
    ('numeric', numeric_transformer, ['Num1', 'Num2', 'Num3'])
])

pipeline = Pipeline(steps=[
    ('preprocessor', preprocessor),
    ('classifier', LinearSVC())
])

X_train = data.loc[:, 'Num1':'Text field']
y_train = data['Class']
pipeline.fit(X_train, y_train)

我不明白这个错误是从哪里来的:
ValueError:连接轴的所有输入数组维度必须完全匹配,但在维度0上,索引0处的数组大小为1,索引1处的数组大小为2

6kkfgxo0

6kkfgxo01#

原因

问题出在preprocessor管道中。此管道的工作方式是将tweet_text_transformer的输出和numeric_transformer的输出水平堆叠。要成功实现这一点,两个输出(tweet_text_transformer和numeric_transformer)必须具有相同的行数(即:轴0或维-0中的元素数)
但是当上述流水线执行tweet_text_processor时,尽管我们期望它实际上给予具有4个元素的2 * 2矩阵,因为CountVectorizer将输出存储为稀疏矩阵,它移除矩阵中的任何零(为了保存内存)这将数组减少到2*2矩阵,但其中只有3个元素,并且当它与numeric_的输出堆叠时transformer它不满足上述条件(因为数值transformer在轴0上有两个元素,而twwet_text_processor没有)
Output of the explination
溶液

  • 创建一个自定义转换器,将此稀疏矩阵转换为numpy数组
  • 此外,由于只有一列,因此挤压Pandas Dataframe 以将其转换为Panadas系列
import pandas as pd
from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.pipeline import Pipeline
from sklearn.svm import LinearSVC
from sklearn.preprocessing import MinMaxScaler
from sklearn.compose import ColumnTransformer

data = [[1, 3, 4, 'text', 'pos'], [9, 3, 6, 'text more', 'neg']]
data = pd.DataFrame(data, columns=['Num1', 'Num2', 'Num3', 'Text field', 'Class'])


class TweetTextProcessor(BaseEstimator, TransformerMixin):
    def __init__(self):
        self.tweet_text_transformer = Pipeline(steps=[
        ('count_vectoriser', CountVectorizer()),
        ('tfidf', TfidfTransformer())    ])
       
        
    def fit(self, X, y=None):
        return self

    def transform(self, X, y=None):
       
        return  self.tweet_text_transformer.fit_transform(X.squeeze()).toarray()
        



numeric_transformer = Pipeline(steps=[
    ('scaler', MinMaxScaler())
])

preprocessor = ColumnTransformer(transformers=[
    ('tweet', TweetTextProcessor(), ['Text field']),
    ('numeric', numeric_transformer, ['Num1', 'Num2', 'Num3'])
])

pipeline = Pipeline(steps=[
    ('preprocessor', preprocessor),
    ('classifier', LinearSVC())
])

X_train = data.loc[:, 'Num1':'Text field']
y_train = data['Class']
pipeline.fit(X_train, y_train)

上述代码应该可以工作,如果解释不清楚(希望如此),请告诉我

pwuypxnk

pwuypxnk2#

我实现了您的代码解决方案,将稀疏矩阵转换为数组,并修复了错误,但是,当我调用predict时,它显示了另一个错误

model = pipeline.fit(X_train,y_train)
y_pred = model.predict(X_test)

它给予了我这个错误
ValueError:X每个样本有574个特征;预期493
我的理解是,在这种情况下,它没有使用已训练的矢量器模型,而是在X_test数据集上训练了一个新模型。
注:需要为BaseEstimator、TransformerMixin添加导入语句

更新:

要解决此问题,请使用FunctionTransformer而不是定义类
使用FunctionTransformer而不是定义类

from sklearn.preprocessing import FunctionTransformer

vectorizer_params = dict(ngram_range=(1, 2), min_df=5, max_df=0.8)

TweetTextProcessor = Pipeline(steps=[
    ("squeez", FunctionTransformer(lambda x: x.squeeze())),
    ("vect", CountVectorizer(**vectorizer_params)),
    ("tfidf", TfidfTransformer()),
    ("toarray", FunctionTransformer(lambda x: x.toarray())),
])

numeric_transformer = Pipeline(steps=[
    ('scaler', MinMaxScaler())
])

preprocessor = ColumnTransformer(transformers=[
    ('tweet', TweetTextProcessor, ['Text field']),
    ('numeric', numeric_transformer, ['Num1', 'Num2', 'Num3'])
])

pipeline = Pipeline(steps=[
    ('preprocessor', preprocessor),
    ('classifier', LinearSVC())
])

相关问题