我正在尝试对现有的拥抱脸模型进行微调。
下面的代码是我从一些文档中收集的
from transformers import AutoTokenizer, AutoModelForQuestionAnswering, TrainingArguments, Trainer
import torch
# Load the Vietnamese model and tokenizer
model_name = "vinai/phobert-base"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForQuestionAnswering.from_pretrained(model_name)
# Define the training data
train_dataset = [
{
"question": "What is your name ?",
"context": "My name is Peter",
"answer": {
"text": "Peter",
"start": 7,
"end": 11
}
}
]
# Define the validation data
val_dataset = [
{
"question": "What is your name ?",
"context": "My name is Peter",
"answer": {
"text": "Peter",
"start": 7,
"end": 11
}
}
]
# Define the training arguments
training_args = TrainingArguments(
output_dir='./results',
evaluation_strategy='epoch',
learning_rate=2e-5,
per_device_train_batch_size=16,
per_device_eval_batch_size=16,
num_train_epochs=3,
weight_decay=0.01,
)
# Define the data collator
def collate_fn(data):
input_ids = torch.stack([item.get('input_ids', None) for item in data if 'input_ids' in item])
attention_mask = torch.stack([item.get('attention_mask', None) for item in data if 'attention_mask' in item])
start_positions = torch.stack([item.get('start_positions', None) for item in data if 'start_positions' in item])
end_positions = torch.stack([item.get('end_positions', None) for item in data if 'end_positions' in item])
return {
'input_ids': input_ids,
'attention_mask': attention_mask,
'start_positions': start_positions,
'end_positions': end_positions
}
# Define the trainer
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=val_dataset,
data_collator=collate_fn
)
# Fine-tune the model
trainer.train()
我一直收到错误的
input_ids = torch.stack([item.get('input_ids', None) for item in data if 'input_ids' in item])
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: stack expects a non-empty TensorList
我试着去做
def collate_fn(data):
print(data)
但我得到了[]
1条答案
按热度按时间kb5ga3dv1#
在
train_dataset
上只有一个示例,因此请尝试将批大小设置为1。