unilm 截断的输出段落

v09wglhw  于 4个月前  发布在  其他
关注(0)|答案(4)|浏览(74)

嘿,大家好,

我在使用UniLM V1时遇到了一个问题,该模型已经针对案例纠正任务进行了微调。案例纠正问题是模型需要找出哪些单词需要大写。模型的输入是所有大写字母开头的单词(单词首字母大写),输出则是正确的大小写形式。

然而,预测结果被截断了。我在摘要生成和问题生成任务中使用了这个模型;但是,当max_seq_lengthmax_pred设置得相当低时,我观察到了截断现象。在这种情况下,输入和输出的最大标记数为70。

我分别使用了以下代码片段进行微调和预测。
微调:

python biunilm/run_seq2seq.py --fp16 --amp --do_train --num_workers 0 \
--bert_model bert-large-cased --new_segment_ids \
--data_dir ${DATA_DIR} --src_file train.src.txt --tgt_file train.tgt.txt \
--output_dir ${OUTPUT_DIR}/bert_save \
  --log_dir ${OUTPUT_DIR}/bert_log \
  --model_recover_path ${MODEL_RECOVER_PATH} \
  --max_seq_length 200 --max_position_embeddings 200 \
  --mask_prob 0.7 --max_pred 150 \
  --train_batch_size 128 --gradient_accumulation_steps 6 \
  --learning_rate 0.00002 --warmup_proportion 0.1 --label_smoothing 0.1 \
  --num_train_epochs 10

预测:

python biunilm/decode_seq2seq.py --fp16 --amp --bert_model bert-large-cased \
--new_segment_ids --mode s2s \
  --input_file ${DATA_DIR}/${EVAL_SPLIT}.src.txt --split ${EVAL_SPLIT} \
  --model_recover_path ${MODEL_RECOVER_PATH} \
  --max_seq_length 200 --max_tgt_length 150 \
  --batch_size 128 --beam_size 5 --length_penalty 0 \
  --not_predict_token "[UNK]" --forbid_duplicate_ngrams --forbid_ignore_word "."

然而,我观察到大约有10%的情况,输出会被截断,如下所示:
模型输出:
Masked against coronavirus, Shanghai's ballet dancers pirouette and
正确输出:
Masked against coronavirus, Shanghai's ballet dancers pirouette and plie
任何帮助都将不胜感激,谢谢!

gmol1639

gmol16391#

你好,@aretius ,

我们之前没有遇到过这个问题。在训练数据中,是否存在一些被截断的示例?对于纠正任务,因为我们可以知道输入和输出之间的对齐关系,所以我们可以限制候选预测和长度。模型可以被迫预测所需的令牌数量。另一个解决方案是尝试新发布的用于s2s微调的包https://github.com/microsoft/unilm/tree/master/s2s-ft

n53p2ov0

n53p2ov02#

感谢您的回复 @donglixp !
是的,这确实很奇怪。我还附上了输入中token长度的分布。

刚刚有几个后续问题

  • 模型是否有可能偏向于产生与输入平均token长度相等的输出(因为数据集大小约为8K个数据点)?
  • 由于输入中的所有单词都是大小写敏感的,因此输入和正确输出中的token数量不同。所以我不确定我们如何约束候选预测和长度。您能否帮助我在仓库中的代码片段中实现相同的功能?
  • s2s-ft与UniLM v1代码片段有何不同?
pepwfjgg

pepwfjgg3#

@donglixp 我观察到使用更大的数据集可以降低截断率至仅2%。当数据集较小时,您能提供一些关于运行的周期的信息吗?

jaxagkaj

jaxagkaj4#

对于纠正任务,因为我们可以知道输出是否完全解码,所以我们可以在所有期望的标记都被发射之前禁止结束序列标记。例如,如果我们发现当前预测是 Masked against coronavirus, Shanghai's ballet dancers pirouette and(即单词 plie 缺失),我们可以为分配一个较大的负分数,以便输出不会被截断。

相关问题