text-generation-inference 当max_tokens为None时,输出被截断,

2sbarzqh  于 4个月前  发布在  其他
关注(0)|答案(2)|浏览(81)

系统信息
docker版本:sha-0b95693
正在使用的模型:/v1/chat/completions

信息

  • Docker
  • 直接使用CLI

任务

  • 一个官方支持的命令
  • 我自己的修改

重现

{
"messages": [
{
"role": "user",
"content": "你是谁?"
}
],
"model": "",
"seed": 42,
"max_tokens": null,
"temperature": null,
"stream": false
}

预期行为

如果我理解正确的话,当max_tokens=None时,会自动计算输出的最大长度,但有时候输出会被截断,并显示"finish_reason": "length",对吗?

b5lpy0ml

b5lpy0ml1#

你好,@paulcx!感谢你提出这个问题!我不是很确定我是否完全理解你的问题,但如果问题是“即使max_tokens为null,输出是否可以被截断”,那么答案是肯定的。这也取决于你如何设置TGI服务器,但我会提供验证max_tokens的代码路径。希望这能帮助你理解发生了什么:

text-generation-inference/router/src/validation.rs
第137行到第194行
| | let max_new_tokens:u32 = ifletSome(max_new_tokens) = max_new_tokens { |
| | max_new_tokens |
| | }else{ |
| | self.max_total_tokens.saturating_sub(input_length)asu32 |
| | }; |
| | let total_tokens = input_length + max_new_tokens asusize; |
| | |
| | // Validate MaxTotalTokens |
| | if total_tokens > self.max_total_tokens{ |
| | returnErr(ValidationError::MaxTotalTokens( |
| | self.max_total_tokens, |
| | input_length, |
| | max_new_tokens, |
| | )); |
| | } |
| | |
| | // Validate InputLength |
| | if input_length > self.max_input_length{ |
| | returnErr(ValidationError::InputLength( |
| | self.max_input_length, |
| | input_length, |
| | )); |
| | } |
| | |
| | let input_ids = encoding.get_ids()[..input_length].to_owned(); |
| | |
| | metrics::histogram!("tgi_request_input_length").record(input_length asf64); |
| | Ok((inputs,Some(input_ids), input_length, max_new_tokens)) |
| | } |
| | // Return inputs without validation |
| | else{ |
| | // In this case, we don't know the real length in tokens of the inputs |
| | // However, the inputs will be truncated by the python servers |
| | // We make sure that truncate + max_new_tokens <= self.max_total_tokens |
| | let max_new_tokens:u32 = ifletSome(max_new_tokens) = max_new_tokens { |
| | max_new_tokens |
| | }elseifletSome(truncate) = truncate { |
| | self.max_total_tokens.saturating_sub(truncate)asu32 |
| | }else{ |
| | returnErr(ValidationError::UnsetMaxNewTokens); |
| | }; |
| | letmut input_length = truncate.unwrap_or(self.max_input_length); |
| | |
| | // We don't have a tokenizer, therefore we have no idea how long is the query, let |
| | // them through and hope for the best. |
| | // Validate MaxNewTokens |
| | if(input_length asu32 + max_new_tokens) > self.max_total_tokensasu32{ |
| | input_length = input_length.saturating_sub(max_new_tokens asusize); |
| | } |
| | ok((vec![Chunk::Text(inputs)], None, input

rmbxnbpk

rmbxnbpk2#

你好,@paulcx!感谢你提出这个问题!
我不太确定我是否完全理解你的问题,但如果问题是“即使max_tokens为null,输出是否仍然可以截断”,那么答案是肯定的。这也取决于你如何设置TGI服务器,但我会提供验证max_tokens的地方的代码路径。希望这能帮助你理解发生了什么:
text-generation-inference/router/src/validation.rs
第137行到第194行 3877345
| | let max_new_tokens:u32 = ifletSome(max_new_tokens) = max_new_tokens { |
| | max_new_tokens |
| | }else{ |
| | self.max_total_tokens.saturating_sub(input_length)asu32 |
| | }; |
| | let total_tokens = input_length + max_new_tokens asusize; |
| | |
| | // Validate MaxTotalTokens |
| | if total_tokens > self.max_total_tokens{ |
| | returnErr(ValidationError::MaxTotalTokens( |
| | self.max_total_tokens, |
| | input_length, |
| | max_new_tokens, |
| | )); |
| | } |
| | |
| | // Validate InputLength |
| | if input_length > self.max_input_length{ |
| | returnErr(ValidationError::InputLength( |
| | self.max_input_length, |
| | input_length, |
| | )); |
| | } |
| | |
| | let input_ids = encoding.get_ids()[..input_length].to_owned(); |
| | |
| | metrics::histogram!("tgi_request_input_length").record(input_length asf64); |
| | Ok((inputs,Some(input_ids), input_length, max_new_tokens)) |
| | } |
| | // Return inputs without validation |
| | else{ |
| | // In this case, we don't know the real length in tokens of the inputs |
| | // However, the inputs will be truncated by the python servers |
| | // We make sure that truncate + max_new_tokens <= self.max_total_tokens |
| | let max_new_tokens:u32 = ifletSome(max_new_tokens) = max_new_tokens { |
| | max_new_tokens |
| | }elseifletSome(truncate) = truncate { |
| | self.max_total_tokens.saturating_sub(truncate)asu32 |
| | }else{ |
| | returnErr(ValidationError::UnsetMaxNewTokens); |
| | }; |
| | letmut input_length = truncate.unwrap_or(self.max_input_length); |
| | // We don't have a tokenizer, therefore we have no idea how long is the query, let them through and hope for the best. //
| // Validate MaxNewTokens //
| if(input_length asu32 + max_new_tokens) > self.max_total_tokensasu32{ |
| input_length = input_length.saturating_sub(max_new_tokens asusize); |
| } |
| | // Ok(( //
| // ...省略部分代码...//
| }) //

相关问题