keras 如何在使用model.predict()时获取标签

pxy2qtax  于 2022-12-04  发布在  其他
关注(0)|答案(1)|浏览(253)

在我的项目中,我有很多这样的情况:我有一个Dataset示例,我需要从某个模型中获取对数据集中每一项的预测。
model.predict() API对此进行了完美的优化,如文档中所示。然而,似乎有一个主要问题。我还碰巧需要标签来与预测值进行比较,即数据集包含x,y对,并且我希望在预测完成后以(y_predicted, y)对结束,我想不出一个清晰的方法来“分割”数据集,以便将x输入模型,而保留y,以便与预测的y重新连接。
编辑:我知道手动迭代数据集并直接调用模型是非常简单的,例如。

for x, y in dataset:
    y_pred = model(x)
    result.append((y, y_pred))

但是,这似乎比使用内置predict()要慢一些,因为Tensorflow无法对输入管道进行多线程/优化。
有没有人有一个很好的方法来完成这一点?

2jcobegt

2jcobegt1#

考虑到您提到的问题,最好覆盖predict以满足您的需要。但您实际上并不需要覆盖该函数,只需覆盖该函数调用的predict_step即可。只需使用此类而不是Model

class MyModel(tf.keras.Model):
    def predict_step(self, data):
        x, y = data
        return self(x, training=False), y

如果您的模型当前是Sequential,则从它继承。基本上,我对默认实现所做的唯一更改是将, y添加到模型调用结果中。请注意,这也做了一些假设,例如您的数据集由(input, label)批处理对组成。您可能需要稍微调整它以满足您的需要。下面是一个最小的示例:

import tensorflow as tf
import numpy as np

(imgs, lbls), (te_imgs, te_lbls) = tf.keras.datasets.mnist.load_data()

imgs = imgs.astype(np.float32).reshape((-1, 784)) / 255.
te_imgs = te_imgs.astype(np.float32).reshape((-1, 784)) / 255.

lbls = lbls.astype(np.int32)
te_lbls = te_lbls.astype(np.int32)

tr_data = tf.data.Dataset.from_tensor_slices((imgs, lbls)).shuffle(60000).batch(128)
te_data = tf.data.Dataset.from_tensor_slices((te_imgs, te_lbls)).batch(128)

class MyModel(tf.keras.Model):
    def predict_step(self, data):
        x, y = data
        return self(x, training=False), y

inp = tf.keras.Input((784,))

logits = tf.keras.layers.Dense(10)(inp)

model = MyModel(inp, logits)

opt = tf.keras.optimizers.Adam()
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

model.compile(loss=loss, optimizer=opt)

something = model.predict(te_data)

print(something[0].shape, something[1].shape)

这显示((10000, 10), (10000,))--predict现在返回outputs, labels的元组(这可以通过检查返回的标签并与测试集中的图像进行比较来确认)。

相关问题