我正在使用sklearn。RandomForestClassifier,我有11个类。我的数据在 Dataframe 中,所有变量都是热编码的。类是字符串,如“Potato”,“Tomato”,“Straberry”等。
当我尝试打印混淆矩阵时,我得到了以下内容:
print(pd.crosstab(y_test, y_pred))
Error: If using all scalar values, you must pass an index
传递索引时:
print(pd.crosstab(y_test, y_pred, index = [0]))
Error:crosstab() got multiple values for argument 'index'
解决这一问题的最佳方式是什么?
1条答案
按热度按时间oalqel3c1#
这个错误说你需要传递参数“index”给crosstab,而不是一个索引,它可以帮助你遍历一个列表。你可以找到正确的方法和更多的细节here
您还可以使用以下代码在Sci-Kit Learn中绘制混淆矩阵
这段代码从用于混淆矩阵的训练数据中获取所有标签
此代码导入混淆矩阵并绘制它。
plt.cm.Blues
用于配色方案,clf
是您的分类器,请确保使用您命名的分类器更改它。