我有下面的代码,工作正常。但问题是,它只对一个批处理有效,因为我使用的是next(iter)
我创建了一个TensorFlow数据集,它必须返回3个值(即X: [x,y,z]
)为我的问题。但是我只需要将x
值传递给模型。我需要将所有3个值打包在一起,因为我稍后将使用y
和z
。现在,问题是,当我要调用fit
时,我必须以某种方式分离这3个值,以便正确调用网络架构。因此,我的问题是如何在这种情况下使用PrefetchDataset
调用fit
。
import tensorflow as tf
import numpy as np
from tensorflow.keras.layers import Input, Dense, Activation, \
Conv2DTranspose, Conv2D, Reshape
from tensorflow.keras.models import Model
AUTOTUNE = tf.data.experimental.AUTOTUNE
def scale(X, a=-1, b=1, dtype='float32'):
if a > b:
a, b = b, a
xmin = tf.cast(tf.math.reduce_min(X), dtype=dtype)
xmax = tf.cast(tf.math.reduce_max(X), dtype=dtype)
X = (X - xmin) / (xmax - xmin)
scaled = X * (b - a) + a
return scaled, xmin, xmax
def set_shape_b(x, y, z):
x = tf.reshape(x, [16, 16, 2])
y = tf.reshape(y, [1])
z = tf.reshape(z, [1])
return x, y, z
def set_shape_a(x, y, z):
x = tf.reshape(x, [4, 4, 2])
y = tf.reshape(y, [1])
z = tf.reshape(z, [1])
return x, y, z
def First(lr):
inp = Input(lr)
x = Dense(16)(inp)
x = Reshape((4, 4, 16))(x)
x = Conv2DTranspose(2, kernel_size=3, strides=2, padding='same')(x)
x = Conv2DTranspose(2, kernel_size=3, strides=2, padding='same')(x)
output = Activation('tanh')(x)
model = Model(inp, output, name='First')
return model
def Second(hr):
inp = Input(hr)
x = Dense(16)(inp)
x = Conv2D(2, kernel_size=3, strides=2, padding='same')(x)
x = Conv2D(2, kernel_size=3, strides=2, padding='same')(x)
output = Dense(1, activation='sigmoid')(x)
model = Model(inputs=inp, outputs=output, name='Second')
return model
def build_model(First, Second):
inp = Input(shape=INP)
gen = First(inp)
output = Second(gen)
model = Model(inputs=inp , outputs=[gen, output], name='model')
return model
# Preproces --------------- #
a = np.random.random((20, 4, 4, 2)).astype('float32')
b = np.random.random((20, 16, 16, 2)).astype('float32')
dataset_a = tf.data.Dataset.from_tensor_slices(a)
dataset_b = tf.data.Dataset.from_tensor_slices(b)
dataset_b = dataset_b.map(lambda x: tf.py_function(scale,
[x],
(tf.float32, tf.float32, tf.float32)))
dataset_b = dataset_b.map(set_shape_b)
dataset_a = dataset_a.map(lambda x: tf.py_function(scale,
[x],
(tf.float32, tf.float32, tf.float32)))
dataset_a = dataset_a.map(set_shape_a)
dataset_ones = tf.data.Dataset.from_tensor_slices(tf.ones((len(b), 4, 4, 1)))
dataset = tf.data.Dataset.zip((dataset_a, (dataset_b, dataset_ones)))
dataset = dataset.cache()
dataset = dataset.batch(2)
dataset = dataset.prefetch(buffer_size=AUTOTUNE)
# Prepare models -------------------- #
INP = (4, 4, 2)
OUT = (16, 16, 2)
first = First(INP)
second = Second(OUT)
model = build_model(first, second)
model.compile(loss=['mse', 'binary_crossentropy'],
optimizer= tf.keras.optimizers.Adam(learning_rate=1e-4))
train_l, (train_h, train_ones) = next(iter(dataset))
# train ------------------
model.fit(train_l[0],
[train_h[0], train_ones],
epochs=2)
更新
def rescale(X_scaled, xmin, xmax):
X = (xmax - xmin) * (X_scaled + 1) / 2.0 + xmin
return X
class PlotCallback(tf.keras.callbacks.Callback):
def __init__(self, image, xmin, xmax, model):
self.image = image
self.xmin = xmin
self.xmax = xmax
self.model = model
def on_epoch_end(self, epoch, logs={}):
preds = self.model.predict(self.image)
y_pred = preds[0]
y_pred = rescale(y_pred, self.xmin, self.xmax)
fig, ax = plt.subplots(figsize=(14, 10))
ax.imshow(y_pred[0][:, :, 0])
plt.close()
我正在使用上述函数,当试图适合我想要的东西:
model.fit(
dataset,
validation_data=dataset,
epochs=2,
callbacks=[PlotCallback(here_the_dataset_a_scaled_values,
xmin_from_dataset_a,
xmax_from_dataset_b, model)]
)
1条答案
按热度按时间s71maibg1#
按照上面的注解,要解决您的问题,您可以应用自定义函数来只返回目标值。另外,请查看tf.data.Dataset.map以获取参考资料。
压缩和批处理数据。
现在,我们可以将其传递给fit方法。
更新1
正如评论中提到的,不能使用
.map(only_scale)
来接收(scale, xmin, xmax)
,以便在训练过程中进行解缩放。但是我们不能将这样的数据格式传递给不期望这样的输入规范的模型。换句话说,模型代码不知道xmin
和xmax
。在这种情况下,有两个选项来解决它。一个是在keras中使用自定义训练循环,另一个是覆盖
fit
方法的train_step
函数。我们试试第二个。在这种情况下,我们不需要使用来自数据API的.map(only_scale)
方法。下面是关于覆盖fit方法的references。让我们构建一个自定义模型来覆盖
trian_step
和(还有test_step
用于验证数据)。还有predict_step
。接下来,我们可以做
接下来,我们现在可以拟合datast(不使用
only_scale
方法)。更新2
关于在回调中使用
xmin
和xmax
来重新缩放预测数组和绘图,我们可以做一些事情如下。1.我们将在训练时存储
xmin
和xmax
的值。我们现在将从验证数据集中存储这些值。1.稍后在回调中,我们在
on_epoch_end
处使用此值,并在on_epoch_begin
处重置以用于下一个epoch。首先,我们会做:
现在,在回调中,我们将
接下来,我们可以调用这个回调。注意,我们传递的是2D单输入。如果
PlotCallback(dataset)
通过,请确保实现predict_step
,这将与上面模型代码中的test_step
几乎相同。更新3
正如你在评论中提到的,最初显示的日志名称是
First_loss
,Second_loss
,在更新1/2之后,它变成了output_1_loss
和output_2_loss
。为了解决这个问题,我们可以稍微修改一下模型代码。首先,我们要接下来,我们对
CustomFitter
执行以下操作,删除init和call方法,不再需要。