pytorch: "collate_fn" for Huggingface hyperparameter tuning

5fjcxozz · posted on 2023-03-12

I am following this tutorial to learn how to do hyperparameter tuning with Huggingface and Wandb. Most of it works fine, but I don't quite understand what the "collate_fn" function is doing and how to adapt it to my use case. My dataset looks like this: a "text" column of type str containing the content of a tweet, and an int-valued "majority_votes" column.

import wandb
wandb.login()

%env WANDB_PROJECT=vit_snacks_sweeps
%env WANDB_LOG_MODEL=true
%env WANDB_NOTEBOOK_NAME=Trainer_Huggingface

from transformers import BertTokenizer, BertForSequenceClassification
import numpy as np
import pandas as pd
import torch
from sklearn.model_selection import train_test_split

with s3.open(f"{bucket_name}/KFOLD1/{train_file_name}",'r') as file:
    data = pd.read_csv(file)
with s3.open(f"{bucket_name}/KFOLD1/{test_file_name}",'r') as file:
    test_data = pd.read_csv(file)
data = data[["Text", "majority_vote"]]
test_data = test_data[["Text", "majority_vote"]]
data.rename(columns={'Text': 'text', 'majority_vote': 'labels'}, inplace=True)
test_data.rename(columns={'Text': 'text', 'majority_vote': 'labels'}, inplace=True)

# Define pre-trained tokenizer and model
tokenizer = BertTokenizer.from_pretrained(model_name)
#model = BertForSequenceClassification.from_pretrained(model_name, num_labels=2)

# ----- 1. Preprocess data -----
X = list(data["text"])
y = list(data["labels"])
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=11)
X_train_tokenized = tokenizer(X_train, padding=True, truncation=True, max_length=512)
X_val_tokenized = tokenizer(X_val, padding=True, truncation=True, max_length=512)

# Create torch dataset
class Dataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels=None):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        if self.labels:
            item["labels"] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.encodings["input_ids"])

train_dataset = Dataset(X_train_tokenized, y_train)
val_dataset = Dataset(X_val_tokenized, y_val)
   

def model_init():
    return BertForSequenceClassification.from_pretrained(model_name, num_labels=2)

# method
sweep_config = {
    'method': 'random'
}

# hyperparameters
parameters_dict = {
    'epochs': {
        'value': 1
        },
    'batch_size': {
        'values': [8, 16, 32]
        },
    'learning_rate': {
        'distribution': 'log_uniform_values',
        'min': 1e-5,
        'max': 1e-3
    },
    'weight_decay': {
        'values': [0.0, 0.2, 0.3, 0.4, 0.5]
    },
}

sweep_config['parameters'] = parameters_dict

sweep_id = wandb.sweep(sweep_config, project='hatespeech')

# define function to compute metrics
from datasets import load_metric
import numpy as np

def compute_metrics_fn(eval_preds):
    metrics = dict()
  
    accuracy_metric = load_metric('accuracy')
    precision_metric = load_metric('precision')
    recall_metric = load_metric('recall')
    f1_metric = load_metric('f1')

    logits = eval_preds.predictions
    labels = eval_preds.label_ids
    preds = np.argmax(logits, axis=-1)  
  
    metrics.update(accuracy_metric.compute(predictions=preds, references=labels))
    metrics.update(precision_metric.compute(predictions=preds, references=labels, average='weighted'))
    metrics.update(recall_metric.compute(predictions=preds, references=labels, average='weighted'))
    metrics.update(f1_metric.compute(predictions=preds, references=labels, average='weighted'))

    return metrics

import torch

def collate_fn(examples):
    text = torch.stack([example['text'] for example in examples])
    labels = torch.tensor([example['labels'] for example in examples])
    return {'text': text, 'labels': labels}

from transformers import TrainingArguments, Trainer

def train(config=None):
    with wandb.init(config=config):
        # set sweep configuration
        config = wandb.config

        # set training arguments
        training_args = TrainingArguments(
            output_dir='hyper',
            overwrite_output_dir=True,
            report_to='wandb',  # Turn on Weights & Biases logging
            num_train_epochs=config.epochs,
            learning_rate=config.learning_rate,
            weight_decay=config.weight_decay,
            per_device_train_batch_size=config.batch_size,
            per_device_eval_batch_size=16,
            save_strategy='epoch',
            evaluation_strategy='epoch',
            logging_strategy='epoch',
            load_best_model_at_end=True,
            # fp16=True
        )

        # define training loop
        trainer = Trainer(
            model_init=model_init,
            args=training_args,
            data_collator=collate_fn,
            train_dataset=train_dataset,
            eval_dataset=val_dataset,
            compute_metrics=compute_metrics_fn
        )

        # start training loop
        trainer.train()

wandb.agent(sweep_id, train, count=20)

When I execute the script, I get the following error message:

wandb: ERROR Run jav32 errored: KeyError('text')

I have tried adapting the collate_fn function to the columns of my dataframe, but it doesn't work. How do I have to modify the code to make it work?
Thanks in advance!

plicqrtu #1

Could you post the full stack trace?
I think it is best to debug the train function first. You can actually debug the collate function on its own by sampling a few items from the dataset and collating them:

batch_list = [train_dataset[i] for i in range(8)]

collate_fn(batch_list)
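Without the full trace this is only a guess, but judging from the Dataset class in the question, each item returned by __getitem__ contains the tokenizer outputs (for BertTokenizer usually input_ids, token_type_ids and attention_mask) plus labels, and no 'text' key at all, which would explain the KeyError('text'). Here is a minimal sketch of a collator matched to those keys, assuming the tokenizer was called with padding=True as in the question so all sequences already have the same length:

import torch

def collate_fn(examples):
    # every value returned by Dataset.__getitem__ is already a tensor, and
    # padding=True in the tokenizer call padded all sequences to the same
    # length, so stacking along a new batch dimension is enough
    return {
        key: torch.stack([example[key] for example in examples])
        for key in examples[0]
    }

Since each item is a plain dict of equal-length tensors, transformers' built-in default_data_collator should also work here in place of a hand-written function.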
