pytorch 有没有一种方法可以将张纸板SummaryWriter与HuggingFace TrainerAPI结合使用?

ttisahbt  于 2022-11-09  发布在  其他
关注(0)|答案(2)|浏览(270)

我正在使用HF Seq 2SeqTrainingArguments和Seq 2SeqTrainer微调HuggingFace变压器模型(PyTorch版本),并且我希望在Tensorboard中显示训练和验证损耗(在同一图表中)。
据我所知,为了将两个损失绘制在一起,我需要使用SummaryWriter。HF Callbacks文档描述了一个可以接收tb_writer参数的TensorBoardCallback函数:
https://huggingface.co/docs/transformers/v4.21.1/en/main_classes/callback#transformers.integrations.TensorBoardCallback
然而,我不能弄清楚什么是正确的使用方式,如果它甚至应该与教练API一起使用的话。
我的代码看起来像这样:

args = Seq2SeqTrainingArguments(
    output_dir=output_dir,
    evaluation_strategy='epoch',
    learning_rate= 1e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=num_train_epochs,
    predict_with_generate=True,
    logging_steps=logging_steps,
    report_to='tensorboard',
    push_to_hub=False,  
)

trainer = Seq2SeqTrainer(
    model,
    args,
    train_dataset=tokenized_train_data,
    eval_dataset=tokenized_val_data,
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

我会假设我应该在训练器中包含对TensorBoard的回调,例如,

callbacks = [TensorBoardCallback(tb_writer=tb_writer)]

但是我找不到一个关于如何使用/导入什么来使用它全面示例。
我还在GitHub上找到了这个功能请求,
https://github.com/huggingface/transformers/pull/4020
但没有使用的例子,所以我很困惑...
任何见解都将受到赞赏

9o685dep

9o685dep1#

它非常简单。您可以在“Seq2SeqTrainingArguments”中提到它。没有必要在“Seq2SeqTrainer”函数中显式定义它。

model_arguments = Seq2SeqTrainingArguments(output_dir= "./best_model/",
                                        num_train_epochs = EPOCHS, 
                                        overwrite_output_dir= True, 
                                        do_train= True, 
                                        do_eval= True, 
                                        do_predict= True, 
                                        auto_find_batch_size= True, 
                                        evaluation_strategy = 'epoch',
                                        warmup_steps = 10000, 
                                        logging_dir = "./log_files/", 
                                        disable_tqdm = False, 
                                        load_best_model_at_end = True, 
                                        save_strategy= 'epoch', 
                                        save_total_limit = 1, 
                                        per_device_eval_batch_size= BATCH_SIZE, 
                                        per_device_train_batch_size= BATCH_SIZE, 
                                        predict_with_generate=True, 
                                        report_to='wandb',
                                        run_name="rober_based_encoder_decoder_text_summarisation"

                                        )

同时,您还可以进行其他回调:

early_stopping = EarlyStoppingCallback(early_stopping_patience= 5, 
                                    early_stopping_threshold= 0.001)

然后,通过trainer参数将参数和回调作为列表传递:

trainer = Seq2SeqTrainer(model = model, 
                        compute_metrics= compute_metrics,
                        args= model_arguments, 
                        train_dataset= Train, 
                        eval_dataset= Val, 
                        tokenizer=tokenizer, 
                        callbacks= [early_stopping, ]
                        )

训练模型。确保在训练前登录wandb

trainer.train()
tvmytwxo

tvmytwxo2#

Trainer类自动输出TensorBoard的事件,不需要使用回调函数。
正如@Junaid所提到的,日志记录可以由TrainingArguments类控制,例如,您可以在那里设置logging_dir
如需详细信息,请参阅TrainingArguments。

相关问题