我正在尝试使用TensorFlow 2.9.2训练模型。我的模型定义为
import tensorflow as tf
encoder_layers = 1
encoder_bidirectional = False
def get_model():
model = tf.keras.Sequential(name='model')
model.add(tf.keras.layers.Dropout(0.5))
for _ in range(encoder_layers):
rnn = tf.keras.layers.LSTM(2**6, return_sequences=True)
if encoder_bidirectional:
rnn = tf.keras.layers.Bidirectional(rnn)
model.add(rnn)
model.add(tf.keras.layers.Dense(2, activation='softmax'))
return model
def build_model():
model = get_model()
model.build(input_shape=(None, None, 25))
model.compile(
loss='sparse_categorical_crossentropy',
optimizer=tf.keras.optimizers.Adam(0.001),
metrics=['accuracy']
)
model.summary()
return model
然后我使用
# train model
train, dev, test = get_datasets()
model = build_model()
es = EarlyStopping(
monitor='val_accuracy',
mode='max',
verbose=1,
patience=10)
mc = ModelCheckpoint(
'model.h5',
monitor='val_accuracy',
mode='max',
verbose=1,
save_best_only=True)
with tf.device("/GPU:0"):
model.fit(
train,
epochs=500,
steps_per_epoch=32,
validation_data=dev,
callbacks=[es, mc])
best_model = load_model('model.h5')
best_model.evaluate(test)
在best_model = load_model('model.h5')
处,我得到以下错误
Traceback (most recent call last):
File "/usr/lib/python3.8/runpy.py", line 194, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/usr/lib/python3.8/runpy.py", line 87, in _run_code
exec(code, run_globals)
File "/experiments/train.py", line 76, in <module>
app.run(main)
File "/usr/local/lib/python3.8/dist-packages/absl/app.py", line 308, in run
_run_main(main, args)
File "/usr/local/lib/python3.8/dist-packages/absl/app.py", line 254, in _run_main
sys.exit(main(argv))
File "/experiments/train.py", line 70, in main
best_model = load_model(FLAGS.model_path)
File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 67, in error_handler
raise e.with_traceback(filtered_tb) from None
File "/usr/local/lib/python3.8/dist-packages/keras/initializers/initializers_v2.py", line 1056, in _compute_fans
return int(fan_in), int(fan_out)
TypeError: int() argument must be a string, a bytes-like object or a number, not 'NoneType'
找到this post后,我检查了我的model.h5
文件,实际上它有batch_input_shape=[null,null,null]
。但是,如何防止检查点模型与输入形状的空值一起保存?有什么办法可以解决这个问题吗?
编辑:
我只是用这个colab中的一个数据样本来复制错误:https://colab.research.google.com/drive/1z63TN-P_WKtTWTZs2IhGBU0NjD7TE6m_#scrollTo=f1oD4G6QEq4k。
1条答案
按热度按时间5fjcxozz1#
在模型代码中,在开头添加以下层。