当前环境
The output of `python collect_env.py`
🐛 描述bug
from vllm import LLM, SamplingParams
sampling_params = SamplingParams(temperature=0, max_tokens=2048)
llm = LLM(model="Llama-3-8B",tensor_parallel_size=4, trust_remote_code=True)
outputs = llm.generate(prompts=prompts, sampling_params=sampling_params)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(prompt + generated_text)
我使用vllm 0.4.1和0.4.2运行这段代码,并在sampler.py中打印logits。但是我发现当tensor_parallel_size为4时,0.4.1和0.4.2的logits不同,而当tensor_parallel_size为2或1时,它们的logits是相同的。
TP4第一个token的logits:
0.4.2
0.4.1
TP2第一个token的logits:
0.4.2
0.4.1
1条答案
按热度按时间s1ag04yj1#
我使用FlashAttention后端,版本为flash-attn==2.5.2。