CTranslate2 支持UMT5

xghobddn  于 2个月前  发布在  其他
关注(0)|答案(1)|浏览(37)

来自谷歌的新的UMT5模型是目前最有趣的原始T5s变体。然而,尝试使用transformers转换器将UMT5模型进行转换,运行以下命令:
ct2-transformers-converter --model google/umt5-xl --output_dir ct2-umt5-3b --quantization int8
得到的结果是:

Downloading (…)lve/main/config.json: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 812/812 [00:00<00:00, 4.96MB/s]
Traceback (most recent call last):
  File "/home/user/miniconda3/bin/ct2-transformers-converter", line 8, in <module>
    sys.exit(main())
             ^^^^^^
  File "/home/user/miniconda3/lib/python3.11/site-packages/ctranslate2/converters/transformers.py", line 1719, in main
    converter.convert_from_args(args)
  File "/home/user/miniconda3/lib/python3.11/site-packages/ctranslate2/converters/converter.py", line 50, in convert_from_args
    return self.convert(
           ^^^^^^^^^^^^^
  File "/home/user/miniconda3/lib/python3.11/site-packages/ctranslate2/converters/converter.py", line 89, in convert
    model_spec = self._load()
                 ^^^^^^^^^^^^
  File "/home/user/miniconda3/lib/python3.11/site-packages/ctranslate2/converters/transformers.py", line 106, in _load
    raise ValueError(
ValueError: No conversion is registered for the model configuration UMT5Config (supported configurations are: BartConfig, BertConfig, BloomConfig, CodeGenConfig, DistilBertConfig, FalconConfig, GPT2Config, GPTBigCodeConfig, GPTJConfig, GPTNeoXConfig, LlamaConfig, M2M100Config, MBartConfig, MPTConfig, MT5Config, MarianConfig, OPTConfig, PegasusConfig, RWConfig, T5Config, WhisperConfig, XLMRobertaConfig)

有没有简单的解决方法?这是否应该添加到包中?

az31mfrm

az31mfrm1#

UMT5模型的每个自注意力层具有独特的相对注意力偏置。因此,在the file中的相应转换器代码可以写成如下形式:

@register_loader("UMT5Config")
class UMT5Loader(T5Loader):
@property
def architecture_name(self):
return "UMT5ForConditionalGeneration"

def set_stack(self, spec, module, is_decoder=False):
    self.set_layer_norm(spec.layer_norm, module.final_layer_norm)
    self.set_embeddings(
        spec.embeddings[0] if isinstance(spec.embeddings, list) else spec.embeddings,
        module.embed_tokens,
    )

    spec.scale_embeddings = False

    for layer_spec, block in zip(spec.layer, module.block):
        self.set_self_attention(layer_spec.self_attention, block.layer[0])

        if is_decoder:
            self.set_cross_attention(layer_spec.attention, block.layer[1])

        self.set_ffn(layer_spec.ffn, block.layer[-1])

此外,在计算注意力时,除了第一层之外,position_bias会在各层之间重用,而不是使用相对注意力偏置重新计算。具体代码位于CTranslate2/src/layers/attention.cc的第236行至第253行。

要为每一层使用正确的position_bias,只需禁用下面的if条件即可:

// if (position_bias->empty()) {
const dim_t query_length = queries.dim(2);
const dim_t key_length = keys.dim(2);
*position_bias = compute_relative_bias(*relative_attention_bias,
query_length,
key_length,
maximum_relative_position,
is_decoder,
with_cache ? key_length - 1 : 0);
// }

然而,这种方法可能会导致T5和MT5模型的性能下降。

相关问题