tensorflow 使用tf.data.Dataset预取会使模型性能过拟合?

np8igboo  于 2023-04-21  发布在  其他
关注(0)|答案(1)|浏览(113)

我尝试在Tensorflow 2.5.0中用一些序列图像数据集训练一个简单的LRCN模型。训练性能很好,前5个epoch的训练和验证精度都提高到了0.9倍,训练和验证损失在训练过程中不断下降。
然后,我尝试使用prefetch()优化数据管道。我使用的数据集是标题和信息写入.csv文件的序列图像(.png)。所以我做了如下的数据生成器:

def setData(data):
X, y = [], []

name = data.loc['fileName'].values.tolist()[0]
info1 = data.loc['info1'].values.tolist()[0]
info2 = data.loc['info2'].values.tolist()[0]
info3 = data.loc['info3'].values.tolist()[0]

if os.path.isfile(filepath + name) == False:
    print('No file for img')
    return

try:
    img = np.load(filepath + fName)
except:
    print(name)  

if info1 in info_list:  
    X.append(img)

    if info2 == 'True':
        y.append(0)

    else:
        y.append(1)

X = np.array(X)
X = np.reshape(X, (3, 128, 128, 1)).astype(np.float64)
y = np_utils.to_categorical(y, num_classes = 2)
y = np.reshape(y, (2)).astype(np.float64)

return X, y

我添加了数据生成器加载函数,如下所示:

def generatedata(i):
    i = i.numpy()
    X_batch, y_batch = setData(pd.DataFrame(traindata.iloc[i]))

    return X_batch, y_batch

最后,我使用map预取数据集

z = list(range(len(traindata[])))
trainDataset = tf.data.Dataset.from_generator(lambda: z, tf.uint8)
trainDataset = trainDataset.map(lambda i: tf.py_function(func = generatedata,
                                                         inp = [i],
                                                         Tout = [tf.float32, tf.float32]),
                                                         num_parallel_calls = tf.data.AUTOTUNE)

在我应用了这些步骤之后,训练精度在第一个epoch中达到0.9倍,在前3-5个epoch中达到1.0倍,验证精度保持在0.6倍左右,验证损失在x. x上不断增长。
我相信预取只会改变数据管道,而不会影响模型性能,所以我不确定是什么导致了这种过拟合(可能?)样的结果。我遵循了Tensorflow文档中提到的预取步骤的每一步。尽管如此,由于我对Tensorflow不是很熟悉,可能会有一些错误。
有没有我漏掉的台词?任何意见都是非常好的。提前感谢。

rdlzhqv9

rdlzhqv91#

事实证明,py_function()使tf.graph堆叠在先前的结果上,导致模型过度拟合。
我修改了prefetch函数来获取generator函数,并正常工作。虽然我检查了tensorflow文档,但我并没有完全警告这种情况,而是在tensorflow github页面上发现了这一点。
对于那些和我有同样问题的人,试着仔细复习库模块函数。

相关问题