Ludwig: error when using repetition_penalty


Describe the bug

When a value is defined for repetition_penalty and the model is fine-tuned, the following error occurs:

Prediction:   0%|                                                                                                                                                                            | 0/1 [00:00<?, ?it/s]Traceback (most recent call last):
  File "/projects/llmodels/karli_iqa/train.py", line 93, in <module>
    main()
  File "/projects/llmodels/karli_iqa/train.py", line 73, in main
    predictions = model.predict(test_df, generation_config={'temperature': 0.1, 'max_new_tokens': 26, 'repetition_penalty': 1.1})[0]
  File "/.pyenv/versions/karli_iqa/lib/python3.10/site-packages/ludwig/api.py", line 1083, in predict
    predictions = predictor.batch_predict(
  File "/.pyenv/versions/karli_iqa/lib/python3.10/site-packages/ludwig/models/predictor.py", line 143, in batch_predict
    preds = self._predict(batch)
  File "/.pyenv/versions/karli_iqa/lib/python3.10/site-packages/ludwig/models/predictor.py", line 189, in _predict
    outputs = self._predict_on_inputs(inputs)
  File "/.pyenv/versions/karli_iqa/lib/python3.10/site-packages/ludwig/models/predictor.py", line 345, in _predict_on_inputs
    return self.dist_model.generate(inputs)
  File "/.pyenv/versions/karli_iqa/lib/python3.10/site-packages/ludwig/models/llm.py", line 334, in generate
    model_outputs = self.model.generate(
  File "/.pyenv/versions/karli_iqa/lib/python3.10/site-packages/peft/peft_model.py", line 1130, in generate
    outputs = self.base_model.generate(**kwargs)
  File "/.pyenv/versions/karli_iqa/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/.pyenv/versions/karli_iqa/lib/python3.10/site-packages/transformers/generation/utils.py", line 1764, in generate
    return self.sample(
  File "/.pyenv/versions/karli_iqa/lib/python3.10/site-packages/transformers/generation/utils.py", line 2874, in sample
    next_token_scores = logits_processor(input_ids, next_token_logits)
  File "/.pyenv/versions/karli_iqa/lib/python3.10/site-packages/transformers/generation/logits_process.py", line 97, in __call__
    scores = processor(input_ids, scores)
  File "/.pyenv/versions/karli_iqa/lib/python3.10/site-packages/transformers/generation/logits_process.py", line 332, in __call__
    score = torch.gather(scores, 1, input_ids)
RuntimeError: gather(): Expected dtype int64 for index
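
The exception comes out of transformers' RepetitionPenaltyLogitsProcessor, which calls torch.gather(scores, 1, input_ids) and requires the index tensor to be int64. A minimal standalone sketch of that PyTorch constraint (not Ludwig code, just the underlying behaviour):

import torch

scores = torch.randn(1, 32000)                              # logits, shape (batch, vocab)
ids_int32 = torch.tensor([[1, 2, 3]], dtype=torch.int32)    # token ids with the wrong dtype

# An int32 index raises the same error as in the traceback:
#   torch.gather(scores, 1, ids_int32)
#   RuntimeError: gather(): Expected dtype int64 for index

# Casting the ids to int64 makes the identical call succeed:
gathered = torch.gather(scores, 1, ids_int32.long())
print(gathered.shape)                                       # torch.Size([1, 3])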

To Reproduce

Steps to reproduce the behavior:
Configuration file:

model_type: llm
base_model: huggyllama/llama-7b
# base_model: meta-llama/Llama-2-7b-hf
# base_model: meta-llama/Llama-2-13b-hf

model_parameters:
  trust_remote_code: true

backend:
  type: local
  cache_dir: ./ludwig_cache

input_features:
  - name: input
    type: text
    preprocessing:
      max_sequence_length: 128

output_features:
  - name: output
    type: text
    preprocessing:
      max_sequence_length: 64

prompt:
  template: >-
    ### User: {input}

    ### Assistant:

generation:
  temperature: 0.1
  max_new_tokens: 32
  repetition_penalty: 1.1
  # remove_invalid_values: true

adapter:
  type: lora
  dropout: 0.05
  r: 8

quantization:
  bits: 4

preprocessing:
  global_max_sequence_length: 256
  split:
    type: fixed

trainer:
  type: finetune
  epochs: 1
  batch_size: 3
  eval_batch_size: 2
  gradient_accumulation_steps: 16
  learning_rate: 0.0004
  learning_rate_scheduler:
    warmup_fraction: 0.03

LoRA fine-tuning works fine, but when inference is attempted, the error shown above is thrown.
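
For context, this is roughly how the config and the failing call fit together in Python (file names and DataFrames below are placeholders for illustration, not taken from the original report):

import pandas as pd
import yaml
from ludwig.api import LudwigModel

with open("config.yaml") as f:                # placeholder path for the config shown above
    config = yaml.safe_load(f)

train_df = pd.read_csv("train.csv")           # placeholder training data
test_df = pd.read_csv("test.csv")             # placeholder test data

model = LudwigModel(config)
model.train(dataset=train_df)                 # LoRA fine-tuning completes without issue

# Inference is where the gather() dtype error above is raised,
# once repetition_penalty is passed in generation_config.
predictions = model.predict(
    test_df,
    generation_config={"temperature": 0.1, "max_new_tokens": 26, "repetition_penalty": 1.1},
)[0]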

Expected behavior

Prediction should proceed normally.

Environment (please complete the following information):

  • OS: Ubuntu 22.04
  • Python version: 3.10.0
  • Ludwig version: 0.9.1
sy5wg1nm 1#

@arnavgarg1 Kind ping, since it's been a couple of months now.
