langchain BaseChatModel.agenerate_prompt() does not correctly pass kwargs through to BaseChatModel.agenerate()


Checked other resources

  • I added a very descriptive title to this issue.
  • I searched the LangChain documentation with the integrated search.
  • I used the GitHub search to find a similar question and didn't find it.
  • I am sure that this is a bug in LangChain rather than my code.
  • The bug is not resolved by updating to the latest stable version of LangChain (or the specific integration package).

Example Code

The following example code reproduces the issue:

import json
from typing import Any, Dict

from fastapi import HTTPException, Request
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import ConfigurableField
from langchain_openai import ChatOpenAI
from langserve import add_routes

def fetch_config_from_header(config: Dict[str, Any], req: Request) -> Dict[str, Any]:
    """Build a per-request config from HTTP headers.

    All supported fields: 'name', 'cache', 'verbose', 'callbacks', 'tags',
    'metadata', 'custom_get_token_ids', 'callback_manager', 'client',
    'async_client', 'model_name', 'temperature', 'model_kwargs',
    'openai_api_key', 'openai_api_base', 'openai_organization',
    'openai_proxy', 'request_timeout', 'max_retries', 'streaming', 'n',
    'max_tokens', 'tiktoken_model_name', 'default_headers', 'default_query',
    'http_client', 'http_async_client'
    """
 
    config = config.copy()
    configurable = config.get("configurable", {})
 
    if "x-model-name" in req.headers:
        configurable["model_name"] = req.headers["x-model-name"]
    else:
        raise HTTPException(401, "No model name provided")
   
    if "x-api-key" in req.headers:
        configurable["default_headers"] = {
            "Content-Type":"application/json",
            "api-key": req.headers["x-api-key"]
        }
    else:
        raise HTTPException(401, "No API key provided")
   
    if "x-model-kwargs" in req.headers:
        configurable["model_kwargs"] = json.loads(req.headers["x-model-kwargs"])
    else:
        raise HTTPException(401, "No model arguments provided")
   
    configurable["openai_api_base"] = f"https://someendpoint.com/{req.headers['x-model-name']}"
    config["configurable"] = configurable
    return config
 
chat_model = ChatOpenAI(
    model_name = "some_model",
    model_kwargs = {},
    default_headers = {},
    openai_api_key = "placeholder",
    openai_api_base = "placeholder").configurable_fields(
        model_name = ConfigurableField(id="model_name"),
        model_kwargs = ConfigurableField(id="model_kwargs"),
        default_headers = ConfigurableField(id="default_headers"),
        openai_api_base = ConfigurableField(id="openai_api_base"),
    )

chain = prompt_template | chat_model | StrOutputParser()
add_routes(
    app,
    chain.with_types(input_type=InputChat),
    path="/some_chain",
    disabled_endpoints=["playground"],
    per_req_config_modifier=fetch_config_from_header,
)
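
For reference, a client request against this route would carry the model configuration in headers, roughly like the following sketch (the host, port, key, and kwargs values are hypothetical placeholders):

import json
import requests

resp = requests.post(
    "http://localhost:8000/some_chain/invoke",  # hypothetical host/port
    headers={
        "x-model-name": "some_model",
        "x-api-key": "placeholder-key",
        "x-model-kwargs": json.dumps({"max_tokens": 256}),
    },
    # The exact input shape depends on the InputChat type used above.
    json={"input": {"messages": []}},
)
print(resp.json())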

Error Message and Stack Trace (if applicable)

Only the relevant part of the traceback is included.

Traceback (most recent call last):
  File "/venv/lib/python3.12/site-packages/uvicorn/protocols/http/httptools_impl.py", line 399, in run_asgi
    result = await app(  # type: ignore[func-returns-value]
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/venv/lib/python3.12/site-packages/uvicorn/middleware/proxy_headers.py", line 70, in __call__
    return await self.app(scope, receive, send)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/venv/lib/python3.12/site-packages/fastapi/applications.py", line 1054, in __call__    await super().__call__(scope, receive, send)
  File "/venv/lib/python3.12/site-packages/starlette/applications.py", line 123, in __call__
    await self.middleware_stack(scope, receive, send)
  File "/venv/lib/python3.12/site-packages/starlette/middleware/errors.py", line 186, in __call__
    raise exc
  File "/venv/lib/python3.12/site-packages/starlette/middleware/errors.py", line 164, in __call__
    await self.app(scope, receive, _send)
  File "/venv/lib/python3.12/site-packages/starlette/middleware/exceptions.py", line 65, in __call__
    await wrap_app_handling_exceptions(self.app, conn)(scope, receive, send)
  File "/venv/lib/python3.12/site-packages/starlette/_exception_handler.py", line 64, in wrapped_app
    raise exc
  File "/venv/lib/python3.12/site-packages/starlette/_exception_handler.py", line 53, in wrapped_app
    await app(scope, receive, sender)
  File "/venv/lib/python3.12/site-packages/starlette/routing.py", line 756, in __call__
    await self.middleware_stack(scope, receive, send)
  File "/venv/lib/python3.12/site-packages/starlette/routing.py", line 776, in app
    await route.handle(scope, receive, send)
  File "/venv/lib/python3.12/site-packages/starlette/routing.py", line 297, in handle
    await self.app(scope, receive, send)
  File "/venv/lib/python3.12/site-packages/starlette/routing.py", line 77, in app
    await wrap_app_handling_exceptions(app, request)(scope, receive, send)
  File "/venv/lib/python3.12/site-packages/starlette/_exception_handler.py", line 64, in wrapped_app
    raise exc
  File "/venv/lib/python3.12/site-packages/starlette/_exception_handler.py", line 53, in wrapped_app
    await app(scope, receive, sender)
  File "/venv/lib/python3.12/site-packages/starlette/routing.py", line 72, in app
    response = await func(request)
               ^^^^^^^^^^^^^^^^^^^
  File "/venv/lib/python3.12/site-packages/fastapi/routing.py", line 278, in app
    raw_response = await run_endpoint_function(
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/venv/lib/python3.12/site-packages/fastapi/routing.py", line 191, in run_endpoint_function
    return await dependant.call(**values)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/venv/lib/python3.12/site-packages/langserve/server.py", line 530, in invoke
    return await api_handler.invoke(request)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/venv/lib/python3.12/site-packages/langserve/api_handler.py", line 835, in invoke
    output = await invoke_coro
             ^^^^^^^^^^^^^^^^^
  File "/venv/lib/python3.12/site-packages/langchain_core/runnables/base.py", line 4585, in ainvoke
    return await self.bound.ainvoke(
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/venv/lib/python3.12/site-packages/langchain_core/runnables/base.py", line 2541, in ainvoke
    input = await step.ainvoke(input, config)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/venv/lib/python3.12/site-packages/langchain_core/runnables/configurable.py", line 123, in ainvoke
    return await runnable.ainvoke(input, config, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/venv/lib/python3.12/site-packages/langchain_core/language_models/chat_models.py", line 191, in ainvoke
    llm_result = await self.agenerate_prompt(
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/venv/lib/python3.12/site-packages/langchain_core/language_models/chat_models.py", line 611, in agenerate_prompt
    return await self.agenerate(
           ^^^^^^^^^^^^^^^^^^^^^
  File "/venv/lib/python3.12/site-packages/langchain_core/language_models/chat_models.py", line 570, in agenerate
    raise exceptions[0]
  File "/venv/lib/python3.12/site-packages/langchain_core/language_models/chat_models.py", line 757, in _agenerate_with_cache
    result = await self._agenerate(
             ^^^^^^^^^^^^^^^^^^^^^^
  File "/venv/lib/python3.12/site-packages/langchain_openai/chat_models/base.py", line 667, in _agenerate
    response = await self.async_client.create(messages=message_dicts, **params)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Description

In https://github.com/langchain-ai/langchain/blob/master/libs/core/langchain_core/language_models/chat_models.py, kwargs still carries the forwarded information inside agenerate_prompt(), as shown below.

async def agenerate_prompt(
        self,
        prompts: List[PromptValue],
        stop: Optional[List[str]] = None,
        callbacks: Callbacks = None,
        **kwargs: Any,
    ) -> LLMResult:
        prompt_messages = [p.to_messages() for p in prompts]
        return await self.agenerate(
            prompt_messages, stop=stop, callbacks=callbacks, **kwargs
        )

Values of prompt_messages and kwargs inside agenerate_prompt():

langchain_core.language_models.chat_models.py BaseChatModel.agenerate_prompt
prompt_messages: [[SystemMessage(content='some messages')]]
kwargs: {'tags': [], 'metadata': {'__useragent': 'python-requests/2.32.3', '__langserve_version': '0.2.2', '__langserve_endpoint': 'invoke', 'model_name': 'some_model', 'openai_api_base': 'https://someendpoint.com/some_model'}, 'run_name': None, 'run_id': None}
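
One way to observe these values is to trace the call before it is forwarded to agenerate(); a hypothetical debugging patch, not part of the reproduction:

import functools

from langchain_core.language_models.chat_models import BaseChatModel

_original_agenerate_prompt = BaseChatModel.agenerate_prompt

@functools.wraps(_original_agenerate_prompt)
async def _traced_agenerate_prompt(self, prompts, stop=None, callbacks=None, **kwargs):
    # Print what agenerate_prompt() receives before it forwards to agenerate().
    print("agenerate_prompt kwargs:", kwargs)
    return await _original_agenerate_prompt(
        self, prompts, stop=stop, callbacks=callbacks, **kwargs
    )

BaseChatModel.agenerate_prompt = _traced_agenerate_prompt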

However, when agenerate() is called from agenerate_prompt(), kwargs is empty, as shown below.

async def agenerate(
        self,
        messages: List[List[BaseMessage]],
        stop: Optional[List[str]] = None,
        callbacks: Callbacks = None,
        *,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        run_name: Optional[str] = None,
        run_id: Optional[uuid.UUID] = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Asynchronously pass a sequence of prompts to a model and return generations.

        This method should make use of batched calls for models that expose a batched
        API.

        Use this method when you want to:
            1. take advantage of batched calls,
            2. need more output from the model than just the top generated value,
            3. are building chains that are agnostic to the underlying language model
                type (e.g., pure text completion models vs chat models).

        Args:
            messages: List of list of messages.
            stop: Stop words to use when generating. Model output is cut off at the
                first occurrence of any of these substrings.
            callbacks: Callbacks to pass through. Used for executing additional
                functionality, such as logging or streaming, throughout generation.
            **kwargs: Arbitrary additional keyword arguments. These are usually passed
                to the model provider API call.

        Returns:
            An LLMResult, which contains a list of candidate Generations for each input
                prompt and additional model provider-specific output.
        """
        params = self._get_invocation_params(stop=stop, **kwargs)

Values of params and kwargs inside agenerate():

langchain_core.language_models.chat_models.py BaseChatModel.agenerate 
params: {'model': 'some_model', 'model_name': 'some_model', 'stream': False, 'n': 1, 'temperature': 0.7, 'user': 'some_user', '_type': 'openai-chat', 'stop': None}
kwargs: {}
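
A plausible reading of why kwargs arrives empty (my interpretation, not a maintainer-confirmed diagnosis): agenerate() declares tags, metadata, run_name, and run_id as keyword-only parameters, so Python binds those keys out of the forwarded **kwargs into the named parameters, leaving **kwargs itself empty. The configurable values (model_name, openai_api_base) sit nested inside metadata, so they never reach _get_invocation_params(). A minimal, self-contained sketch of that binding behavior, independent of LangChain:

from typing import Any, Dict, List, Optional

def agenerate_prompt(**kwargs: Any) -> None:
    # Forwards everything, just like BaseChatModel.agenerate_prompt().
    agenerate(**kwargs)

def agenerate(
    *,
    tags: Optional[List[str]] = None,
    metadata: Optional[Dict[str, Any]] = None,
    run_name: Optional[str] = None,
    run_id: Optional[str] = None,
    **kwargs: Any,
) -> None:
    # tags/metadata/run_name/run_id are captured by the named parameters,
    # so **kwargs only receives keys NOT present in the signature.
    print("metadata:", metadata)
    print("kwargs:  ", kwargs)

# Mirrors the values logged above (abbreviated).
agenerate_prompt(
    tags=[],
    metadata={
        "model_name": "some_model",
        "openai_api_base": "https://someendpoint.com/some_model",
    },
    run_name=None,
    run_id=None,
)
# metadata: {'model_name': 'some_model', 'openai_api_base': 'https://someendpoint.com/some_model'}
# kwargs:   {}

If this is the mechanism, the information is not dropped outright but ends up in metadata, which _get_invocation_params() does not consult when building the provider call.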

System Info

langchain==0.2.5
langchain-community==0.2.5
langchain-core==0.2.9
langchain-experimental==0.0.60
langchain-openai==0.1.9
langchain-text-splitters==0.2.1
langgraph==0.1.1
langserve==0.2.2
langsmith==0.1.82
openai==1.35.3

platform = linux
python version = 3.12.4
