我的整个数据集是8个类,每个类100个被试,我做分类的时候,我把它拆分了,测试集是不平衡的。
我确实分裂了:
_train, X_test, y_train, y_test = train_test_split(data.iloc[:, :-1], data.iloc[:, -1], test_size=0.2, random_state=42)
RFC的混淆矩阵为enter image description here,例如:二等舱只有10个,为什么要平衡呢?
谢谢大家。
我的整个数据集是8个类,每个类100个被试,我做分类的时候,我把它拆分了,测试集是不平衡的。
我确实分裂了:
_train, X_test, y_train, y_test = train_test_split(data.iloc[:, :-1], data.iloc[:, -1], test_size=0.2, random_state=42)
RFC的混淆矩阵为enter image description here,例如:二等舱只有10个,为什么要平衡呢?
谢谢大家。
1条答案
按热度按时间wj8zmpe11#
scikit-learn中的
train_test_split
函数以随机的方式分割类数。要保持测试集和训练集中的类数相等,需要添加“stratify”参数。
See documentation