CTranslate2 多查询架构损坏 - OpenNMT-py - score_batch

xyhw6mcr  于 2个月前  发布在  其他
关注(0)|答案(1)|浏览(37)

当我在刚刚转换的模型架构上使用 score_batch 时,我遇到了一个 ValueError。我正在使用 score_batch 函数过滤一些翻译数据,以便使用更好的数据继续训练模型。但是,每当该函数在转换后的模型检查点的 Translator 实例上调用时,它都会抛出一个形状错误。x 和 y 似乎根据每个单独的批次而变化,因为当我尝试仅在单个示例上使用 score_batch 时,没有抛出错误。

我已经检查了传入的编码源/目标的结构/值,似乎都处理得很好。传递句子批次为 4096-6048,最大批次大小为 2048 个标记。对于具有相同代码的其他架构模型(没有进行任何更改),此函数的处理工作正常。使用 translate_batch 进行翻译也没有任何问题,所以我不确定发生了什么。

请告诉我是否转换后的模型 / opennmt-py 检查点会有所帮助,或者我可以提供什么帮助。

zqdjd7g9

zqdjd7g91#

进行了一些测试,迄今为止只发现它在一次处理一批数据时才能正常工作。
脚本中的 max_size 会就地截断实际的 built_batches(截断会保留到后续测试),因此下面的测试按 max_size 递减的顺序排列。脚本如下:

import ctranslate2
import sentencepiece

# Language configuration: pivot (source) language and the list of target languages.
pivot_lang = "en"
langs = ["tr"]
# NOTE(review): hard-coded Windows paths — adjust for your environment.
base_folder = "C:/Machine_Learning/NLP"
lang_pair = "middle_east"

# CTranslate2 translator loaded from the converted model directory,
# quantized to int8 and running on the GPU.
ct2_model = ctranslate2.Translator(
        f"{base_folder}/models/{lang_pair}/ct2",
        device="cuda",
        compute_type="int8",
    )
# SentencePiece tokenizers for source and target text. Both load the same
# model file (general_multi.model) — presumably a shared multilingual
# vocabulary; confirm this matches the model's training setup.
source_sentence = sentencepiece.SentencePieceProcessor(
        f"{base_folder}/models/{lang_pair}/general_multi.model"
    )
target_sentence = sentencepiece.SentencePieceProcessor(
        f"{base_folder}/models/{lang_pair}/general_multi.model"
    )

def encode(src_sents, tgt_sents, tgt_lang, src_lang):
    """Tokenize source/target sentences into SentencePiece string pieces.

    Args:
        src_sents: A sentence (str) or list of sentences in the source language.
        tgt_sents: A sentence (str) or list of sentences in the target language.
        tgt_lang: Target language code (currently unused; kept for interface
            compatibility with callers).
        src_lang: Source language code (currently unused).

    Returns:
        Tuple ``(src_tokens, tgt_tokens)`` where each element is a list of
        token lists (``out_type=str`` gives subword strings, the format
        CTranslate2's ``score_batch`` expects).
    """
    # Accept a bare string for convenience; normalize to a one-element list.
    # isinstance() is the idiomatic type check (handles str subclasses too).
    if isinstance(src_sents, str):
        src_sents = [src_sents]
    if isinstance(tgt_sents, str):
        tgt_sents = [tgt_sents]

    src_sents = source_sentence.Encode(src_sents, out_type=str)
    tgt_sents = target_sentence.Encode(tgt_sents, out_type=str)

    return src_sents, tgt_sents

class Test():
    """One score_batch run over a (source, target) sentence batch.

    Args:
        src_lang: Source language code (used only for log messages).
        tgt_lang: Target language code (used only for log messages).
        batch_size: Passed to ``score_batch`` as ``max_batch_size``.
        batch: Two parallel lists ``[source_sentences, target_sentences]``.
        batch_type: ``"tokens"`` or ``"examples"`` (score_batch semantics).
        max_size: Cap on the number of examples kept from each list.
    """

    def __init__(self, src_lang, tgt_lang, batch_size, batch, batch_type="tokens", max_size=2048):
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang
        self.batch_size = batch_size
        self.batch_type = batch_type
        # BUG FIX: the original wrote the truncated rows back into the caller's
        # list (``self.batch[x] = new_x``), so a small ``max_size`` in one Test
        # permanently shrank the shared dataset for every later Test — which is
        # why the test calls had to be ordered by decreasing max_size. Slicing
        # into a fresh list leaves the caller's data untouched.
        self.batch = [rows[:max_size] for rows in batch]

    def run_test(self):
        """Score the batch and print either a pass summary or the error."""
        src_token, tgt_token = encode(self.batch[0], self.batch[1], self.tgt_lang, self.src_lang)
        assert len(src_token) == len(tgt_token), "Source and target example count must be equal"
        src_token_amount = sum(len(x) for x in src_token)
        tgt_token_amount = sum(len(x) for x in tgt_token)
        try:
            results = ct2_model.score_batch(source=src_token, target=tgt_token, max_batch_size=self.batch_size, batch_type=self.batch_type)
        except Exception as e:
            # Report the failing configuration; this is a diagnostic script, so
            # we deliberately swallow the exception after printing it.
            print(f"----\nTest with {self.src_lang}-{self.tgt_lang} failed\nBatch size: {self.batch_size} of {self.batch_type}\nTotal tokens: src - {src_token_amount} tgt - {tgt_token_amount} \
\nExample count: {len(src_token)}\nERROR OUTPUT:", e)
            return
        # Per example: mean of squared |log_prob| over its tokens; then average
        # across examples. (Labelled "Perplexity" in the output, though it is
        # not the standard exp(-mean log p) definition.)
        per_example = [
            sum(abs(lp) ** 2 for lp in res.log_probs) / len(res.log_probs)
            for res in results
        ]
        simplified_perps = sum(per_example) / len(per_example)

        print(f"----\nTEST WITH {self.src_lang}-{self.tgt_lang} PASSED\nBatch size: {self.batch_size} of {self.batch_type}\nTotal tokens: src - {src_token_amount}, tgt - {tgt_token_amount} \
\nExample count: {len(src_token)}\nPerplexity: {simplified_perps}\n----")

src, tgt = pivot_lang, langs[0]
source_file = "C:/TranslationData/flores200_dataset/dev/eng_Latn.dev"
target_file = "C:/TranslationData/flores200_dataset/dev/tur_Latn.dev"

# Read the parallel FLORES-200 dev files; strip the trailing newline from
# each line. built_batches holds [source_sentences, target_sentences].
with open(source_file, encoding="utf8") as src_file, open(target_file, encoding="utf8") as tgt_file:
    original = [row.replace("\n", "") for row in src_file]
    references = [row.replace("\n", "") for row in tgt_file]
built_batches = [original, references]

# One scoring run per (max_batch_size, batch_type, max_size) combination.
# NOTE(review): the caps are ordered non-increasingly because Test appears
# to truncate the shared built_batches in place (see Test.__init__) —
# earlier runs would otherwise constrain later ones.
scoring_configs = [
    (2048, "tokens", 2048),
    (1000, "examples", 2048),
    (4096, "tokens", 2048),
    (100, "examples", 2048),
    (1, "examples", 2048),
    (100, "examples", 100),
    (2048, "tokens", 10),
    (2048, "examples", 2),
    (2, "examples", 2),
    (2, "examples", 1),
]
for max_batch, b_type, cap in scoring_configs:
    Test(src, tgt, max_batch, built_batches, b_type, cap).run_test()

相关问题