Dtreeviz - AttributeError: 'DataFrame' object has no attribute 'dtype' (Python, scikit-learn)

Posted by jxct1oxe on 2022-12-27 in Python

I am trying to build a decision tree visualization with dtreeviz.

import pandas as pd
from sklearn import preprocessing, tree
from dtreeviz.trees import dtreeviz

I have a pandas DataFrame like this:
DF1:

id | age | gender | platform | Customer 
1  | 34  | M      | Web      | User 
2  | 37  | F      | App      | Customer

I created some dummy variables:

X = df1[['age', 'gender', 'portfolio_type', 'platform']]
X = pd.get_dummies(data=X, drop_first=True)

Y = df1[[ 'Customer']]
Y = pd.get_dummies(data=Y, drop_first=True)
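
A side note on the target built above: `pd.get_dummies` applied to `df1[['Customer']]` returns a one-column DataFrame rather than a Series, and a DataFrame exposes `dtypes` but not `dtype`. The sketch below uses a two-row stand-in for `df1` (copied from the sample table) to illustrate this. Whether this is what triggers the dtreeviz error is an assumption, not something confirmed in this thread.

```python
import pandas as pd

# Two-row stand-in for df1, copied from the sample table in the question.
df1 = pd.DataFrame({
    "id": [1, 2],
    "age": [34, 37],
    "gender": ["M", "F"],
    "platform": ["Web", "App"],
    "Customer": ["User", "Customer"],
})

# Same encoding as in the question: double brackets select a DataFrame.
Y = pd.get_dummies(data=df1[["Customer"]], drop_first=True)
print(type(Y).__name__, list(Y.columns))

# A DataFrame has `dtypes` (one entry per column) but no `dtype`.
try:
    Y.dtype
    dtype_error = None
except AttributeError as exc:
    dtype_error = str(exc)
print(dtype_error)

# Selecting the single column yields a 1-D Series, which does have `dtype`.
y_series = Y.iloc[:, 0]
print(type(y_series).__name__, y_series.name)
```
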

Then I create the training and test sets.

from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.4, random_state=101)

If I create a decision tree like this, it works:

import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn import tree
from dtreeviz.trees import *

#fit the classifier
clf = tree.DecisionTreeClassifier(max_depth=3, random_state=42)
clf.fit(X_train, y_train)

tree.plot_tree(clf)

If I do this, it also works:

tree.plot_tree(clf,
               feature_names = X.columns, 
               class_names= df['Customer'],
               rounded=True, 
               filled = True,
               fontsize=7
               );

But if I try to use dtreeviz, I get an error:

viz = dtreeviz(classifier, 
               X[["age",    "gender_M", "portfolio_type_esg",   "platform_web"]], 
               Y,
               target_name='Customer',
               feature_names = X.columns, 
               class_names= list(set(df['Customer']))
              )  
              
viz.view()


AttributeError: 'DataFrame' object has no attribute 'dtype'

Why does this happen, and what can I do about it?

Answer #1 (nle07wnf)

I cannot reproduce this. At least with dtreeviz==1.4.1, it appears to work when the scikit-learn model is fitted on a DataFrame.
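
Since the result is tied to a specific dtreeviz release, it may help to check which versions are installed locally before comparing. This is a small sketch using only the standard library; `installed_version` is a hypothetical helper name, and it returns `None` for packages that are not installed.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(package):
    """Return the installed version string of a package, or None if absent."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


# Print the versions relevant to this thread.
for name in ("dtreeviz", "scikit-learn", "pandas"):
    print(name, installed_version(name))
```
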

MRE (minimal reproducible example):

from sklearn.tree import DecisionTreeRegressor
from sklearn.datasets import fetch_california_housing
from dtreeviz.trees import dtreeviz

housing = fetch_california_housing(as_frame=True)
regr = DecisionTreeRegressor(max_depth=2).fit(housing.data, housing.target)

viz = dtreeviz(regr,
               housing.data,               # pandas.DataFrame
               housing.target,             # pandas.Series
               target_name="MedHouseVal",
               feature_names=list(housing.data.columns))
viz.view()
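
The example above passes a `pandas.Series` as the target, while the question passes the DataFrame `Y`, so the shape of the target is one difference worth testing. The following is a minimal sketch under that assumption, not a confirmed fix. The rows are made up, `render_tree` is a hypothetical helper, and the dtreeviz call mirrors the 1.x usage shown in this thread. The import is deferred so the data preparation runs even without dtreeviz installed.

```python
import pandas as pd
from sklearn.tree import DecisionTreeClassifier

# Hypothetical rows shaped like the question's df1 (values are made up).
df1 = pd.DataFrame({
    "age": [34, 37, 25, 52, 41, 29],
    "gender": ["M", "F", "F", "M", "F", "M"],
    "platform": ["Web", "App", "Web", "App", "App", "Web"],
    "Customer": ["User", "Customer", "User", "Customer", "Customer", "User"],
})

X = pd.get_dummies(df1[["age", "gender", "platform"]], drop_first=True)

# Encode the target as a 1-D Series of integer codes instead of a DataFrame.
# Categories are sorted, so code i corresponds to class_names[i].
target = df1["Customer"].astype("category")
y = target.cat.codes.rename("Customer")
class_names = list(target.cat.categories)

clf = DecisionTreeClassifier(max_depth=3, random_state=42).fit(X, y)


def render_tree(model, features, labels, names):
    """Build the dtreeviz view; imported lazily so the rest runs without it."""
    from dtreeviz.trees import dtreeviz  # dtreeviz 1.x API, as in this thread
    return dtreeviz(model,
                    features,
                    labels,                      # pandas.Series, not DataFrame
                    target_name="Customer",
                    feature_names=list(features.columns),
                    class_names=names)

# Usage (requires dtreeviz 1.x and Graphviz):
# render_tree(clf, X, y, class_names).view()
```

Two further details in the question's call may be worth checking: it passes `classifier` although the fitted model is named `clf`, and `list(set(...))` does not guarantee an order for `class_names`, which is why the sketch derives the names from sorted categories.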
