I'm following this tutorial to learn how to do hyperparameter tuning with Hugging Face and Weights & Biases. Most of it works, but I don't really understand what the `collate_fn` function is doing and how to adapt it to my use case. My dataset looks like this: a "text" column of type str containing the tweet contents, and an int-valued "majority_votes" column.
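For illustration, the data look roughly like this (hypothetical toy rows, not the real data):

import pandas as pd

# hypothetical stand-in for the dataset described above
example_df = pd.DataFrame({
    "text": ["first tweet goes here", "second tweet goes here"],
    "majority_votes": [0, 1],
})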
import wandb
wandb.login()
%env WANDB_PROJECT=vit_snacks_sweeps
%env WANDB_LOG_MODEL=true
%env WANDB_NOTEBOOK_NAME=Trainer_Huggingface
from transformers import BertTokenizer, BertForSequenceClassification
import pandas as pd
import torch
import numpy as np
from sklearn.model_selection import train_test_split
with s3.open(f"{bucket_name}/KFOLD1/{train_file_name}", 'r') as file:
    data = pd.read_csv(file)
with s3.open(f"{bucket_name}/KFOLD1/{test_file_name}", 'r') as file:
    test_data = pd.read_csv(file)
data = data[["Text", "majority_vote"]]
test_data = test_data[["Text", "majority_vote"]]
data.rename(columns={'Text': 'text', 'majority_vote': 'labels'}, inplace=True)
test_data.rename(columns={'Text': 'text', 'majority_vote': 'labels'}, inplace=True)
# Define pre-trained tokenizer and model
tokenizer = BertTokenizer.from_pretrained(model_name)
#model = BertForSequenceClassification.from_pretrained(model_name, num_labels=2)
# ----- 1. Preprocess data -----#
X = list(data["text"])
y = list(data["labels"])
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=11)
X_train_tokenized = tokenizer(X_train, padding=True, truncation=True, max_length=512)
X_val_tokenized = tokenizer(X_val, padding=True, truncation=True, max_length=512)
# Create torch dataset
class Dataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels=None):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        if self.labels:
            item["labels"] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.encodings["input_ids"])
train_dataset = Dataset(X_train_tokenized, y_train)
val_dataset = Dataset(X_val_tokenized, y_val)
def model_init():
    return BertForSequenceClassification.from_pretrained(model_name, num_labels=2)
# method
sweep_config = {
    'method': 'random'
}
# hyperparameters
parameters_dict = {
    'epochs': {
        'value': 1
    },
    'batch_size': {
        'values': [8, 16, 32]
    },
    'learning_rate': {
        'distribution': 'log_uniform_values',
        'min': 1e-5,
        'max': 1e-3
    },
    'weight_decay': {
        'values': [0.0, 0.2, 0.3, 0.4, 0.5]
    },
}
sweep_config['parameters'] = parameters_dict
sweep_id = wandb.sweep(sweep_config, project='hatespeech')
# define function to compute metrics
from datasets import load_metric

def compute_metrics_fn(eval_preds):
    metrics = dict()
    accuracy_metric = load_metric('accuracy')
    precision_metric = load_metric('precision')
    recall_metric = load_metric('recall')
    f1_metric = load_metric('f1')
    logits = eval_preds.predictions
    labels = eval_preds.label_ids
    preds = np.argmax(logits, axis=-1)
    metrics.update(accuracy_metric.compute(predictions=preds, references=labels))
    metrics.update(precision_metric.compute(predictions=preds, references=labels, average='weighted'))
    metrics.update(recall_metric.compute(predictions=preds, references=labels, average='weighted'))
    metrics.update(f1_metric.compute(predictions=preds, references=labels, average='weighted'))
    return metrics
def collate_fn(examples):
    text = torch.stack([example['text'] for example in examples])
    labels = torch.tensor([example['labels'] for example in examples])
    return {'text': text, 'labels': labels}
from transformers import TrainingArguments, Trainer
def train(config=None):
    with wandb.init(config=config):
        # set sweep configuration
        config = wandb.config

        # set training arguments
        training_args = TrainingArguments(
            output_dir='hyper',
            overwrite_output_dir=True,
            report_to='wandb',  # Turn on Weights & Biases logging
            num_train_epochs=config.epochs,
            learning_rate=config.learning_rate,
            weight_decay=config.weight_decay,
            per_device_train_batch_size=config.batch_size,
            per_device_eval_batch_size=16,
            save_strategy='epoch',
            evaluation_strategy='epoch',
            logging_strategy='epoch',
            load_best_model_at_end=True,
            # fp16=True
        )

        # define training loop
        trainer = Trainer(
            model_init=model_init,
            args=training_args,
            data_collator=collate_fn,
            train_dataset=train_dataset,
            eval_dataset=val_dataset,
            compute_metrics=compute_metrics_fn,
        )

        # start training loop
        trainer.train()

wandb.agent(sweep_id, train, count=20)
When I execute the script, I get the following error message:
wandb: ERROR Run jav32 errored: KeyError('text')
I tried adjusting the collate_fn function to match the columns of my dataframe, but it doesn't work. How do I have to modify the code to make it work?
Thanks in advance!
1 Answer
Could you post the full stack trace? I think it's best to debug the `train` function first. You can actually debug the `collate_fn` function by sampling from the dataset and collating the output:
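Something like the following sketch; it reuses the train_dataset and collate_fn defined in your script:

# pull a few items straight from the dataset
samples = [train_dataset[i] for i in range(4)]

# each item is what Dataset.__getitem__ returns: the tokenizer keys
# ('input_ids', 'token_type_ids', 'attention_mask') plus 'labels'.
# There is no 'text' key, so this call should reproduce the
# KeyError('text') outside of the sweep, where it is easy to inspect.
batch = collate_fn(samples)
print({k: v.shape for k, v in batch.items()})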