bert TPU训练tf_hub模型崩溃

dffbzjpn  于 3个月前  发布在  其他
关注(0)|答案(1)|浏览(51)

在使用tf_hub BERT模型进行训练时,使用TPU经常会出现崩溃的情况。

根据我的数据集,有时可以正常运行,有时则不行,这取决于--save_checkpoints_steps参数。

要重现这个问题,请使用"Predicting Movie Reviews with BERT on TF Hub"的colab笔记本,选择TPU运行时,使用某个存储桶来存储数据,使用500个训练/测试示例(仅为了加快速度),并将run_config单元格替换为:

然后运行训练单元格,它应该会因为上述错误而崩溃。

完整的单元格输出如下:

dxpyg8gm

dxpyg8gm1#

这个问题是由于在使用 TensorFlow Hub 时,文件系统方案 '[local]' 没有实现导致的。你可以尝试将本地文件系统中的模型文件转换为 HDF5 格式,然后在 TensorFlow Hub 中使用 HDF5 格式的模型文件。以下是将模型文件转换为 HDF5 格式的方法:

  1. 首先,确保已经安装了 tensorflowh5py 库。如果没有安装,可以使用以下命令进行安装:
pip install tensorflow h5py
  1. 然后,使用以下代码将模型文件(例如 model.pb)转换为 HDF5 格式(例如 model.h5):
import tensorflow as tf
from tensorflow import keras
from h5py import File

# 加载模型

model = keras.models.load_model('model.pb')

# 将模型保存为 HDF5 格式

model.save('model.h5')
  1. 最后,在 TensorFlow Hub 中使用转换后的 HDF5 格式的模型文件。
is_predicting, input_ids, input_mask, segment_ids, label_ids, num_labels)
  File "", line 7, in create_model
  trainable=True)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_hub/module.py", line 170, in **init**
  tags=self._tags)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_hub/native_module.py", line 340, in _create_impl
  name=name)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_hub/native_module.py", line 399, in **init**
  self._init_state(name)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_hub/native_module.py", line 407, in _init_state
  self._variable_map)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/training/checkpoint_utils.py", line 291, in init_from_checkpoint
  init_from_checkpoint_fn)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/distribute_lib.py", line 1684, in merge_call
  return self._merge_call(merge_fn, args, kwargs)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/distribute/distribute_lib.py", line 1691, in _merge_call
  return merge_fn(self._strategy, *args, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/training/checkpoint_utils.py", line 286, in 
  ckpt_dir_or_file, assignment_map)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/training/checkpoint_utils.py", line 334, in _init_from_checkpoint
  _set_variable_or_list_initializer(var, ckpt_file, tensor_name_in_ckpt)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/training/checkpoint_utils.py", line 458, in _set_variable_or_list_initializer
  _set_checkpoint_initializer(variable_or_list, ckpt_file, tensor_name, "")
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/training/checkpoint_utils.py", line 412, in _set_checkpoint_initializer
  ckpt_file, [tensor_name], [slice_spec], [base_type], name=name)[0]
  File "/usr/local/lib/python3.6

相关问题