How can a model that uses parallel encoders be converted with CTranslate2?

bqujaahr · asked 2 months ago · posted in: Other

Framework: OpenNMT-tf
Model:

import opennmt as onmt
from opennmt.utils import misc


class TinyDualSourceTransformer(onmt.models.Transformer):

    def __init__(self):
        super(TinyDualSourceTransformer, self).__init__(
            source_inputter=onmt.inputters.ParallelInputter([
                onmt.inputters.WordEmbedder(embedding_size=256),
                onmt.inputters.WordEmbedder(embedding_size=256)]),
            target_inputter=onmt.inputters.WordEmbedder(embedding_size=256),
            num_layers=4,
            num_units=128,
            num_heads=4,
            ffn_inner_dim=512,
            dropout=0.1,
            attention_dropout=0.1,
            ffn_dropout=0.1,
            share_encoders=True)

    def auto_config(self, num_replicas=1):
        config = super(TinyDualSourceTransformer, self).auto_config(num_replicas=num_replicas)
        max_length = config["train"]["maximum_features_length"]
        return misc.merge_dict(config, {
            "train": {
                "maximum_features_length": [max_length, max_length]
            }
        })

model = TinyDualSourceTransformer
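The `auto_config` override above relies on `misc.merge_dict` from `opennmt.utils` to overlay the dual-source length limit onto the inherited configuration. Conceptually this is a recursive dictionary merge; the following is a minimal sketch of that behavior (a hypothetical reimplementation for illustration, not the library code):

```python
def merge_dict(base, override):
    """Recursively merge `override` into a copy of `base`.

    Nested dicts are merged key by key; any other value in `override`
    replaces the corresponding value in `base`.
    """
    merged = dict(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge_dict(merged[key], value)
        else:
            merged[key] = value
    return merged


# Mirrors what auto_config does: replace the scalar length limit with
# one limit per source.
config = {"train": {"maximum_features_length": 100, "batch_size": 64}}
config = merge_dict(config, {"train": {"maximum_features_length": [100, 100]}})
# → {"train": {"maximum_features_length": [100, 100], "batch_size": 64}}
```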

Command:

onmt-main --config data_tiny_0504.yml --auto_config export --output_dir models/model_v2_ctrans --format ctranslate2

Error:

2023-05-06 15:26:36.878000: I inputter.py:316] Initialized source_1 input layer:
2023-05-06 15:26:36.878000: I inputter.py:316]  - vocabulary size: 50001
2023-05-06 15:26:36.878000: I inputter.py:316]  - special tokens: BOS=no, EOS=no
2023-05-06 15:26:36.904000: I inputter.py:316] Initialized source_2 input layer:
2023-05-06 15:26:36.904000: I inputter.py:316]  - vocabulary size: 4121
2023-05-06 15:26:36.904000: I inputter.py:316]  - special tokens: BOS=no, EOS=no
2023-05-06 15:26:37.075000: I inputter.py:316] Initialized target input layer:
2023-05-06 15:26:37.075000: I inputter.py:316]  - vocabulary size: 50001
2023-05-06 15:26:37.075000: I inputter.py:316]  - special tokens: BOS=yes, EOS=yes
2023-05-06 15:26:37.139000: I runner.py:490] Restored checkpoint run_0504/ckpt-25000
2023-05-06 15:26:38.775468: I tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:428] Loaded cuDNN version 8201
Traceback (most recent call last):
  File "/ldap_home/kyro.wang/miniconda3/envs/nmt-tf/bin/onmt-main", line 8, in <module>
    sys.exit(main())
  File "/ldap_home/kyro.wang/miniconda3/envs/nmt-tf/lib/python3.9/site-packages/opennmt/bin/main.py", line 347, in main
    runner.export(
  File "/ldap_home/kyro.wang/miniconda3/envs/nmt-tf/lib/python3.9/site-packages/opennmt/runner.py", line 496, in export
    model.export(export_dir, exporter=exporter)
  File "/ldap_home/kyro.wang/miniconda3/envs/nmt-tf/lib/python3.9/site-packages/opennmt/models/model.py", line 439, in export
    exporter.export(self, export_dir)
  File "/ldap_home/kyro.wang/miniconda3/envs/nmt-tf/lib/python3.9/site-packages/opennmt/utils/exporters.py", line 22, in export
    self._export_model(model, export_dir)
  File "/ldap_home/kyro.wang/miniconda3/envs/nmt-tf/lib/python3.9/site-packages/opennmt/utils/exporters.py", line 169, in _export_model
    converter.convert(export_dir, quantization=self._quantization, force=True)
  File "/ldap_home/kyro.wang/miniconda3/envs/nmt-tf/lib/python3.9/site-packages/ctranslate2/converters/converter.py", line 89, in convert
    model_spec = self._load()
  File "/ldap_home/kyro.wang/miniconda3/envs/nmt-tf/lib/python3.9/site-packages/ctranslate2/converters/opennmt_tf.py", line 88, in _load
    return spec_builder(self._model)
  File "/ldap_home/kyro.wang/miniconda3/envs/nmt-tf/lib/python3.9/site-packages/ctranslate2/converters/opennmt_tf.py", line 113, in __call__
    check.validate()
  File "/ldap_home/kyro.wang/miniconda3/envs/nmt-tf/lib/python3.9/site-packages/ctranslate2/converters/utils.py", line 87, in validate
    raise_unsupported(self._unsupported_reasons)
  File "/ldap_home/kyro.wang/miniconda3/envs/nmt-tf/lib/python3.9/site-packages/ctranslate2/converters/utils.py", line 74, in raise_unsupported
    raise ValueError(message)
ValueError: The model you are trying to convert is not supported by CTranslate2. We identified the following reasons:

- Parallel encoders are not supported
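The `check.validate()` call in the traceback accumulates every unsupported feature first, then raises a single `ValueError` listing them all. A minimal sketch of that collect-then-raise pattern (a hypothetical reimplementation modeled on the error message above, not the actual `ctranslate2.converters.utils` code):

```python
class UnsupportedCheck:
    """Collect unsupported-feature reasons, raise once at validate()."""

    def __init__(self):
        self._reasons = []

    def expect(self, condition, reason):
        # Record the reason only if the condition does not hold.
        if not condition:
            self._reasons.append(reason)

    def validate(self):
        if self._reasons:
            raise ValueError(
                "The model you are trying to convert is not supported "
                "by CTranslate2. We identified the following reasons:\n\n"
                + "\n".join("- " + reason for reason in self._reasons)
            )


# Usage: each architectural constraint is checked up front, so the user
# sees every problem at once instead of one per conversion attempt.
check = UnsupportedCheck()
check.expect(False, "Parallel encoders are not supported")
```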

wpx232ag 1#

The error message is explicit: this architecture is not supported, so there is currently no way to convert a model with multiple encoders to CTranslate2.
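If retraining is an option, one possible workaround (my own suggestion, not something the answer proposes) is to collapse the two source streams into a single token sequence separated by a special marker, and train a standard single-encoder Transformer, an architecture CTranslate2 does support. A minimal preprocessing sketch, where the helper name and the `<sep>` token are hypothetical and the token must be added to the shared source vocabulary:

```python
def merge_sources(src1_tokens, src2_tokens, sep="<sep>"):
    # Join the two tokenized source streams into one sequence so a
    # single-encoder Transformer (single WordEmbedder source inputter)
    # can consume them.
    return src1_tokens + [sep] + src2_tokens


merged = merge_sources(["the", "cat"], ["NOUN", "NOUN"])
# → ['the', 'cat', '<sep>', 'NOUN', 'NOUN']
```

Whether this matches the dual-encoder model's quality depends on the task, but the resulting single-source checkpoint can then be exported with the same `onmt-main ... export --format ctranslate2` command.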
