pytorch: my `collate_fn` function receives empty data when passed to Trainer's `data_collator` parameter

iq0todco posted on 2023-04-21 in Other

I'm trying to fine-tune an existing Hugging Face model.
I put the code below together from several pieces of documentation:

from transformers import AutoTokenizer, AutoModelForQuestionAnswering, TrainingArguments, Trainer
import torch

# Load the Vietnamese model and tokenizer
model_name = "vinai/phobert-base"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForQuestionAnswering.from_pretrained(model_name)

# Define the training data
train_dataset = [
    {
        "question": "What is your name ?",
        "context": "My name is Peter",
        "answer": {
            "text": "Peter",
            "start": 11,  # character offset of "Peter" in the context
            "end": 15     # offset of its last character
        }
    }
]

# Define the validation data
val_dataset = [
    {
        "question": "What is your name ?",
        "context": "My name is Peter",
        "answer": {
            "text": "Peter",
            "start": 11,  # character offset of "Peter" in the context
            "end": 15     # offset of its last character
        }
    }
]

# Define the training arguments
training_args = TrainingArguments(
    output_dir='./results',
    evaluation_strategy='epoch',
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    weight_decay=0.01,
)

# Define the data collator
def collate_fn(data):
    input_ids = torch.stack([item.get('input_ids', None) for item in data if 'input_ids' in item])
    attention_mask = torch.stack([item.get('attention_mask', None) for item in data if 'attention_mask' in item])
    start_positions = torch.stack([item.get('start_positions', None) for item in data if 'start_positions' in item])
    end_positions = torch.stack([item.get('end_positions', None) for item in data if 'end_positions' in item])
    return {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'start_positions': start_positions,
        'end_positions': end_positions
    }

# Define the trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=collate_fn
)

# Fine-tune the model
trainer.train()

I keep getting this error:

input_ids = torch.stack([item.get('input_ids', None) for item in data if 'input_ids' in item])
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: stack expects a non-empty TensorList
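The error is consistent with how `collate_fn` builds its lists: each comprehension keeps only items that already contain a given key (for example `input_ids`), while the raw examples in `train_dataset` only carry `question`, `context` and `answer`. A minimal pure-Python sketch of that filtering step (no torch; the example is copied from the question with the answer offsets left out, since they don't affect the filter) shows the list handed to `torch.stack` comes out empty:

```python
# One raw, untokenized example, as in train_dataset from the question
raw_batch = [
    {
        "question": "What is your name ?",
        "context": "My name is Peter",
        "answer": {"text": "Peter"},  # offsets omitted; irrelevant to the filter
    }
]

# Same filter as in collate_fn: keep only items that already have 'input_ids'
input_ids_list = [item.get("input_ids", None) for item in raw_batch if "input_ids" in item]

# Nothing survives the filter, so torch.stack(input_ids_list) would receive []
print(input_ids_list)  # -> []
```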

I tried printing what the collator receives:

def collate_fn(data):
    print(data)

but all I got was []
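When a collator receives unexpected input, listing which expected keys each item lacks can be more informative than printing the whole batch. Below is a minimal sketch with a hypothetical helper, `missing_keys` (not part of transformers or torch), assuming the four keys that `collate_fn` in the question reads:

```python
from typing import Any, Dict, List, Sequence

# Keys that collate_fn in the question expects every item to carry
REQUIRED_KEYS = ("input_ids", "attention_mask", "start_positions", "end_positions")


def missing_keys(
    batch: Sequence[Dict[str, Any]], required: Sequence[str] = REQUIRED_KEYS
) -> List[List[str]]:
    """Hypothetical debugging helper: for each item, list the required keys it lacks."""
    return [[key for key in required if key not in item] for item in batch]


# An empty batch gives [], while items lacking tokenized fields report every missing key
print(missing_keys([]))
print(missing_keys([{"question": "What is your name ?", "context": "My name is Peter"}]))
```

Calling something like `print(missing_keys(data))` at the top of `collate_fn` distinguishes a genuinely empty batch (`[]`) from one whose items lack the tokenized fields. If the items arrive as empty dicts, note that `Trainer` may drop keys that are not arguments of the model's `forward` method when the `TrainingArguments` option `remove_unused_columns` is enabled (it is `True` by default).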

kb5ga3dv


There is only one example in `train_dataset`, so try setting the batch size to 1.
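The suggestion concerns how the one-example `train_dataset` is split into batches. The sketch below is a pure-Python model (a hypothetical helper, not a torch or transformers API) of how a sequential PyTorch `DataLoader` groups examples: full batches of `batch_size`, plus a final partial batch unless `drop_last` is set. In `TrainingArguments` these correspond to `per_device_train_batch_size` and `dataloader_drop_last` (default `False`).

```python
from typing import List


def loader_batch_sizes(num_examples: int, batch_size: int, drop_last: bool = False) -> List[int]:
    """Hypothetical helper: sizes of the batches a sequential loader yields."""
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")
    full, remainder = divmod(num_examples, batch_size)
    sizes = [batch_size] * full
    # The final partial batch is kept unless drop_last is set
    if remainder and not drop_last:
        sizes.append(remainder)
    return sizes


print(loader_batch_sizes(1, 16))                  # -> [1]
print(loader_batch_sizes(1, 1))                   # -> [1]
print(loader_batch_sizes(1, 16, drop_last=True))  # -> []
```

Applying the suggestion means passing `per_device_train_batch_size=1` (and `per_device_eval_batch_size=1`) in the `TrainingArguments` call from the question.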
