text-generation-inference: New HF Mamba models are not supported


System Info

TGI v2.2.0, using the official Docker image.

Information

  • Docker
  • The CLI directly

Tasks

  • An officially supported command
  • My own modifications

Reproduction

I launched TGI with the latest HF-compatible Mamba model: https://huggingface.co/state-spaces/mamba-2.8b-hf
I hit the following error:

2024-07-30T16:10:33.125286Z  WARN text_generation_router: router/src/main.rs:312: Could not parse config Error("unknown variant `mamba`, expected one of `llava_next`, `clip_vision_model`, `mistral`, `idefics`, `idefics2`, `ssm`, `gpt_bigcode`, `santacoder`, `bloom`, `mpt`, `gpt2`, `gpt_neox`, `phi`, `phi-msft`, `phi3`, `llama`, `baichuan`, `paligemma`, `gemma`, `gemma2`, `cohere`, `drbx`, `falcon`, `mixtral`, `starcoder2`, `qwen2`, `opt`, `t5`", line: 16, column: 23)
2024-07-30T16:10:38.763546Z ERROR text_generation_launcher: Method Warmup encountered an error.
Traceback (most recent call last):
  File "/opt/conda/bin/text-generation-server", line 8, in <module>
    sys.exit(app())
  File "/opt/conda/lib/python3.10/site-packages/typer/main.py", line 311, in __call__
    return get_command(self)(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/click/core.py", line 1157, in __call__
    return self.main(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/typer/core.py", line 778, in main
    return _main(
  File "/opt/conda/lib/python3.10/site-packages/typer/core.py", line 216, in _main
    rv = self.invoke(ctx)
  File "/opt/conda/lib/python3.10/site-packages/click/core.py", line 1688, in invoke
    return _process_result(sub_ctx.command.invoke(sub_ctx))
  File "/opt/conda/lib/python3.10/site-packages/click/core.py", line 1434, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/opt/conda/lib/python3.10/site-packages/click/core.py", line 783, in invoke
    return __callback(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/typer/main.py", line 683, in wrapper
    return callback(**use_params)  # type: ignore
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/cli.py", line 118, in serve
    server.serve(
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/server.py", line 297, in serve
    asyncio.run(
  File "/opt/conda/lib/python3.10/asyncio/runners.py", line 44, in run
    return loop.run_until_complete(main)
  File "/opt/conda/lib/python3.10/asyncio/base_events.py", line 636, in run_until_complete
    self.run_forever()
  File "/opt/conda/lib/python3.10/asyncio/base_events.py", line 603, in run_forever
    self._run_once()
  File "/opt/conda/lib/python3.10/asyncio/base_events.py", line 1909, in _run_once
    handle._run()
  File "/opt/conda/lib/python3.10/asyncio/events.py", line 80, in _run
    self._context.run(self._callback, *self._args)
  File "/opt/conda/lib/python3.10/site-packages/grpc_interceptor/server.py", line 165, in invoke_intercept_method
    return await self.intercept(
> File "/opt/conda/lib/python3.10/site-packages/text_generation_server/interceptor.py", line 21, in intercept
    return await response
  File "/opt/conda/lib/python3.10/site-packages/opentelemetry/instrumentation/grpc/_aio_server.py", line 120, in _unary_interceptor
    raise error
  File "/opt/conda/lib/python3.10/site-packages/opentelemetry/instrumentation/grpc/_aio_server.py", line 111, in _unary_interceptor
    return await behavior(request_or_iterator, context)
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/server.py", line 125, in Warmup
    max_supported_total_tokens = self.model.warmup(batch)
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/model.py", line 104, in warmup
    self.generate_token(batch)
  File "/opt/conda/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/causal_lm.py", line 691, in generate_token
    logits, speculative_logits, past = self.forward(
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/causal_lm.py", line 681, in forward
    return outputs.logits, speculative_logits, outputs.past_key_values
AttributeError: 'MambaCausalLMOutput' object has no attribute 'past_key_values'
2024-07-30T16:10:38.768796Z ERROR warmup{max_input_length=1024 max_prefill_tokens=1024 max_total_tokens=2024 max_batch_size=None}:warmup: text_generation_client: router/client/src/lib.rs:46: Server error: 'MambaCausalLMOutput' object has no attribute 'past_key_values'
Error: WebServer(Warmup(Generation("'MambaCausalLMOutput' object has no attribute 'past_key_values'")))
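For context, the AttributeError comes from TGI falling back to its generic CausalLM path (visible in the traceback above): that path reads outputs.past_key_values after the forward pass, but transformers' Mamba output class keeps its recurrent SSM state in cache_params and has no such attribute. A minimal illustration (my own sketch, assuming a transformers version that includes the HF Mamba port; this is not TGI code):

import torch
from transformers.models.mamba.modeling_mamba import MambaCausalLMOutput

# Build an empty Mamba output object and inspect its fields.
out = MambaCausalLMOutput(logits=torch.zeros(1, 1, 8))
print("cache_params:", out.cache_params)                         # None - the SSM state would live here
print("has past_key_values:", hasattr(out, "past_key_values"))   # False -> AttributeError during warmup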

Expected behavior

The latest Mamba models define model_type as mamba in the model config: https://huggingface.co/state-spaces/mamba-2.8b-hf/blob/main/config.json#L15
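This is easy to confirm from the hub (a quick check, assuming a recent transformers with Mamba support installed):

from transformers import AutoConfig

config = AutoConfig.from_pretrained("state-spaces/mamba-2.8b-hf")
print(config.model_type)  # prints "mamba" - the value TGI fails to recognize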
TGI, however, only accepts ssm for this architecture: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/__init__.py#L200
We should accept mamba as an additional valid model_type string for the Mamba models.
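A minimal sketch of what the suggested change amounts to (illustrative only; the names below are hypothetical, and the real dispatch is the much larger get_model in TGI's server-side models/__init__.py):

import json

# Hypothetical alias table: treat the new HF "mamba" model_type like the existing "ssm" branch.
MODEL_TYPE_ALIASES = {"mamba": "ssm"}

def resolve_model_type(config_path: str) -> str:
    """Read model_type from a checkpoint's config.json and normalize known aliases."""
    with open(config_path) as f:
        model_type = json.load(f)["model_type"]
    return MODEL_TYPE_ALIASES.get(model_type, model_type)

# resolve_model_type("config.json") would return "ssm" for the new HF Mamba checkpoints,
# so TGI's existing Mamba implementation would be selected instead of the generic
# CausalLM fallback that fails above.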

dvtswwa31#

Thanks for the report @jonnyli1125 👍
I'll take a look. It seems to differ slightly from the previous Mamba models.
