pandas 目标是二进制的,但是我得到了“ValueError:支持的目标类型包括:(“binary”,“multiclass'),改为”unknown“,”

rta7y2nd  于 2022-12-31  发布在  其他
关注(0)|答案(1)|浏览(208)

我在使用sklearn KFold的简单ML模型中遇到问题
我使用以下代码对目标值进行分类:

# Import the DB
df = pd.read_csv("DB_ML_TJA20182019.csv")
#Transform continuous target into binary
category = pd.cut(df.length,bins=[0,4,100],labels=[0,1])
df.insert(18,"length_over", category)

现在,如果我打开csv,我可以看到一个添加的列(length_over,第18列,从0开始计数),其中的二进制变量是由length列的二进制化生成的。然后,我将数据集保存为一个新文件,并使用以下代码将其拆分为测试验证子集:

# Save the dataset with binary target
df.to_csv(r'DB_ML_TJA20182019_multilabel.csv', index = False)

# Load dataset for ML modeling (already imputed)
url = 'DB_ML_TJA20182019_multilabel.csv'
names = ...
dataset = read_csv(url, names=features, skiprows=1)

# Split-out validation dataset
array = dataset.values
X = array[:,0:18]
y = array[:,18]
X_train, X_validation, Y_train, Y_validation = train_test_split(X, y, test_size=0.30, random_state=1)

然而,在进行模型评估和比较之前,我得到了错误:Out: "ValueError: Supported target types are: ('binary', 'multiclass'). Got 'unknown' instead."
我还检查了目标的类型

#Check the type of target
from sklearn.utils.multiclass import type_of_target
print(type_of_target(y))

结果是unknown
可能是什么问题?当我打开csv时,目标是二进制的,但是函数将其作为未知得到。
数据类型为int64

ryevplcw

ryevplcw1#

我很晚才开始,但在为多标签分类任务准备数据集并使用MultilabelStratifiedKFold()时遇到了这个错误。
本质上,要确保您拥有的每个标签都是具有正确数据类型(例如int)的numpy数组。
在我的例子中,在pd.DataFrame上执行一些操作之后,y label是一个pandas.Series,其中包含list作为元素(=标签),而不是np.array()
我解决了它:

y = df["label_column"].to_numpy() 
y = [np.array(label) for label in y]

相关问题