嘿,大家好,
我在使用UniLM V1时遇到了一个问题,该模型已经针对案例纠正任务进行了微调。案例纠正问题是模型需要找出哪些单词需要大写。模型的输入是所有大写字母开头的单词(单词首字母大写),输出则是正确的大小写形式。
然而,预测结果被截断了。我在摘要生成和问题生成任务中使用了这个模型;但是,当max_seq_length
或max_pred
设置得相当低时,我观察到了截断现象。在这种情况下,输入和输出的最大标记数为70。
我分别使用了以下代码片段进行微调和预测。
微调:
python biunilm/run_seq2seq.py --fp16 --amp --do_train --num_workers 0 \
--bert_model bert-large-cased --new_segment_ids \
--data_dir ${DATA_DIR} --src_file train.src.txt --tgt_file train.tgt.txt \
--output_dir ${OUTPUT_DIR}/bert_save \
--log_dir ${OUTPUT_DIR}/bert_log \
--model_recover_path ${MODEL_RECOVER_PATH} \
--max_seq_length 200 --max_position_embeddings 200 \
--mask_prob 0.7 --max_pred 150 \
--train_batch_size 128 --gradient_accumulation_steps 6 \
--learning_rate 0.00002 --warmup_proportion 0.1 --label_smoothing 0.1 \
--num_train_epochs 10
预测:
python biunilm/decode_seq2seq.py --fp16 --amp --bert_model bert-large-cased \
--new_segment_ids --mode s2s \
--input_file ${DATA_DIR}/${EVAL_SPLIT}.src.txt --split ${EVAL_SPLIT} \
--model_recover_path ${MODEL_RECOVER_PATH} \
--max_seq_length 200 --max_tgt_length 150 \
--batch_size 128 --beam_size 5 --length_penalty 0 \
--not_predict_token "[UNK]" --forbid_duplicate_ngrams --forbid_ignore_word "."
然而,我观察到大约有10%的情况,输出会被截断,如下所示:
模型输出:
Masked against coronavirus, Shanghai's ballet dancers pirouette and
正确输出:
Masked against coronavirus, Shanghai's ballet dancers pirouette and plie
任何帮助都将不胜感激,谢谢!
4条答案
按热度按时间gmol16391#
你好,@aretius ,
我们之前没有遇到过这个问题。在训练数据中,是否存在一些被截断的示例?对于纠正任务,因为我们可以知道输入和输出之间的对齐关系,所以我们可以限制候选预测和长度。模型可以被迫预测所需的令牌数量。另一个解决方案是尝试新发布的用于s2s微调的包https://github.com/microsoft/unilm/tree/master/s2s-ft。
n53p2ov02#
感谢您的回复 @donglixp !
是的,这确实很奇怪。我还附上了输入中token长度的分布。
刚刚有几个后续问题
pepwfjgg3#
@donglixp 我观察到使用更大的数据集可以降低截断率至仅2%。当数据集较小时,您能提供一些关于运行的周期的信息吗?
jaxagkaj4#
对于纠正任务,因为我们可以知道输出是否完全解码,所以我们可以在所有期望的标记都被发射之前禁止结束序列标记。例如,如果我们发现当前预测是
Masked against coronavirus, Shanghai's ballet dancers pirouette and
(即单词plie
缺失),我们可以为分配一个较大的负分数,以便输出不会被截断。