keras Numpy.int 对象不可迭代

5cg8jx4n  于 2023-02-08  发布在  其他
关注(0)|答案(1)|浏览(161)

我用keras训练了一个模型,当我想显示分数时,我在prediction_cat行得到一个numpy错误,说numpy.int32对象是不可迭代的。任何帮助都是很好的。谢谢。
这是单元格的代码:

from sklearn.metrics import accuracy_score, auc, f1_score, recall_score

prediction = model.predict(test_img_pca)

prediction_cat = [np.where(row == max(row))[0][0] for row in prediction]

acc_krr = accuracy_score(y_test_cat, prediction_cat)
print("Accuracy: ", acc_krr)

rcl_krr = recall_score(y_test_cat, prediction_cat, average = None)
print("Recall: ", rcl_krr)

f1_krr = f1_score(y_test_cat, prediction_cat, average = None)
print("F1: ", f1_krr)

这是我得到的错误:

TypeError                                 Traceback (most recent call last)
~\AppData\Local\Temp\ipykernel_14244\3875264768.py in <module>
      3 prediction = model.predict(test_img_pca)
      4 
----> 5 prediction_cat = [np.where(row == max(row))[0][0] for row in prediction]
      6 
      7 acc_krr = accuracy_score(y_test_cat, prediction_cat)

~\AppData\Local\Temp\ipykernel_14244\3875264768.py in <listcomp>(.0)
      3 prediction = model.predict(test_img_pca)
      4 
----> 5 prediction_cat = [np.where(row == max(row))[0][0] for row in prediction]
      6 
      7 acc_krr = accuracy_score(y_test_cat, prediction_cat)

TypeError: 'numpy.int32' object is not iterable
e5nszbig

e5nszbig1#

这个错误告诉你prediction对象不是一个可迭代的对象,而是一个numpy.int32类型的对象。
尝试打印prediction并查看它的外观以更好地了解它是什么

相关问题