vllm [Bug]:为什么0.4.1和0.4.2之间的logits不同?

z9zf31ra  于 2个月前  发布在  Git
关注(0)|答案(1)|浏览(74)

当前环境

The output of `python collect_env.py`

🐛 描述bug

from vllm import LLM, SamplingParams
sampling_params = SamplingParams(temperature=0, max_tokens=2048)
llm = LLM(model="Llama-3-8B",tensor_parallel_size=4, trust_remote_code=True)
outputs = llm.generate(prompts=prompts, sampling_params=sampling_params)
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(prompt + generated_text)

我使用vllm 0.4.1和0.4.2运行这段代码,并在sampler.py中打印logits。但是我发现当tensor_parallel_size为4时,0.4.1和0.4.2的logits不同,而当tensor_parallel_size为2或1时,它们的logits是相同的。

TP4第一个token的logits:

0.4.2

0.4.1

TP2第一个token的logits:

0.4.2

0.4.1

s1ag04yj

s1ag04yj1#

我使用FlashAttention后端,版本为flash-attn==2.5.2。

相关问题