OpenPrompt 如何解决LLaMA输出的logits相等性问题

xoshrz7s  于 6个月前  发布在  Git
关注(0)|答案(3)|浏览(124)

This is my code.

from datasets import load_dataset
from transformers import set_seed
from openprompt.data_utils import InputExample
import os
from tqdm import tqdm
device = "cuda"
classes = ["negative", "positive"]
set_seed(1024)
from accelerate import Accelerator
accelerator = Accelerator()
data_path = 'data'
test_path = os.path.join(data_path, 'test.json')
test_dataset = load_dataset('json', data_files=test_path)['train'] # 1 positive 0 negative
y_true = test_dataset['label']
dataset = []
import copy
data = []
copy_test_dataset = copy.deepcopy(test_dataset)
for example in copy_test_dataset:
temp_data = {"guid": example["label"], "text_a": example["sentence"]}
data.append(temp_data)
for item in data:
dataset.append(InputExample(guid=item["guid"], text_a=item["text_a"]))
from openprompt import plms
from openprompt.plms import *
from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
plms._MODEL_CLASSES["llama"]= ModelClass(**{"config": LlamaConfig, "tokenizer": LlamaTokenizer, "model": LlamaForCausalLM, "wrapper": LMTokenizerWrapper})
from openprompt.plms import load_plm
plm, tokenizer, model_config, WrapperClass = load_plm("llama", "huggyllama/llama-7b")
tokenizer.pad_token_id = 0
from openprompt.prompts import ManualTemplate
promptTemplate = ManualTemplate(
text=' {"placeholder":"text_a"} This sentence was {"mask"}',
tokenizer=tokenizer,
)
from openprompt.prompts import ManualVerbalizer
promptVerbalizer = ManualVerbalizer(classes=classes,
label_words={"negative": ["bad"], "positive": ["good", "wonderful", "great"], },
tokenizer=tokenizer, )
from openprompt import PromptForClassification
promptModel = PromptForClassification(template=promptTemplate, plm=plm, verbalizer=promptVerbalizer, )
from openprompt import PromptDataLoader
data_loader = PromptDataLoader(dataset=dataset, tokenizer=tokenizer, template=promptTemplate,
tokenizer_wrapper_class=WrapperClass, batch_size=1)
import torch
promptModel.eval()
print(promptModel)
promptModel, data_loader = accelerator.prepare(promptModel, data_loader)
promptModel.to(device)
predictions = []
with torch.no_grad():
for batch in tqdm(data_loader, desc="Processing batches"):
batch = {k: v.to(device) for k, v in batch.items()}
print(batch)
logits = promptModel(batch)
print(logits)
exit()
preds = torch.argmax(logits, dim=-1)
for i in preds:
predictions.append(i.item())
from sklearn.metrics import accuracy_score
accuracy = accuracy_score(y_true, predictions)
print('Accuracy: %.2f' % (accuracy * 100))

The output logits is :

tensor(-1.3863, -1.3863)

a7qyws3x

a7qyws3x1#

我也面临这个问题。@shuaizhao95,你能解决这个问题吗?

kzmpq1sx

kzmpq1sx2#

我也遇到了这个问题。@shuaizhao95 你解决了这个问题吗?
抱歉,我还没有解决这个问题。

0h4hbjxa

0h4hbjxa3#

normalizer.cc(51) LOG(INFO) precompiled_charsmap为空,使用身份规范化。我也遇到了这个问题,你是否能解决这个问题?

相关问题