llama_index [Question]: Custom callback handler

Posted by xienkqul 2 months ago in Other

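Question Validation

  • I have searched both the documentation and Discord for an answer.

Question

Hello,
I'm trying to build a custom callback handler in a hierarchical agent setup, where one agent is passed to another agent as a tool.
How can I get the event_id of the current request inside the custom_query method of the lowest-level agent?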

8qgya5xd1#

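In a hierarchical agent setup, one way to get the current request's event_id inside the lowest-level agent's custom_query method is to hook into on_event_start, as the UpTrainCallbackHandler class does. The CallbackManager calls this method on every registered handler when an event starts, passing in the event_id (and the parent_id of the enclosing event), and the method returns the event_id.
Here is the relevant code from UpTrainCallbackHandler.on_event_start: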

# Excerpt: a method of UpTrainCallbackHandler, not standalone code
def on_event_start(
    self,
    event_type: CBEventType,
    payload: Any = None,
    event_id: str = "",
    parent_id: str = "",
    **kwargs: Any,
) -> str:
    """Run when an event starts and return id of event."""
    event = CBEvent(event_type, payload=payload, id_=event_id)
    self._event_pairs_by_id[event.id_].append(event)

    if event_type is CBEventType.QUERY:
        self.schema.question = payload["query_str"]
    if event_type is CBEventType.TEMPLATING and "template_vars" in payload:
        template_vars = payload["template_vars"]
        self.schema.context = template_vars.get("context_str", "")
    elif event_type is CBEventType.RERANKING and "nodes" in payload:
        self.schema.eval_types.add("reranking")
        # Store old context data
        self.schema.old_context = [node.text for node in payload["nodes"]]
    elif event_type is CBEventType.SUB_QUESTION:
        # For the first sub question, store parent question and parent id
        if "sub_question" not in self.schema.eval_types:
            self.schema.parent_question = self.schema.question
            self.schema.eval_types.add("sub_question")
        # Store sub question data - question and parent id
        self.schema.sub_question_parent_id = parent_id
    return event_id

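This method returns event_id at the end, but that return value goes back to the CallbackManager that invoked the handler, not to your query engine. To use the id inside custom_query, the handler has to store it somewhere custom_query can read, such as an attribute or a ContextVar.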
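As far as I can tell from llama_index's callbacks/base.py, the CallbackManager creates the event id, calls each handler's on_event_start, and hands the id back only to whoever started the event. The stdlib-only sketch below (no llama_index needed) models that flow in simplified form; ToyCallbackManager and RecordingHandler are hypothetical stand-ins, not library classes. The point it shows: the handler's return value is discarded, so the id only reaches other code through state the handler keeps.

```python
import uuid
from typing import Any, List, Optional


class RecordingHandler:
    """Hypothetical stand-in for a callback handler that remembers the ids it is given."""

    def __init__(self) -> None:
        self.started_ids: List[str] = []

    def on_event_start(
        self, event_type: str, payload: Any = None, event_id: str = "", **kwargs: Any
    ) -> str:
        self.started_ids.append(event_id)  # stored state is what other code can read later
        return event_id  # the toy manager below discards this value


class ToyCallbackManager:
    """Hypothetical, simplified model of a manager fanning one event out to handlers."""

    def __init__(self, handlers: List[RecordingHandler]) -> None:
        self.handlers = handlers

    def on_event_start(
        self, event_type: str, payload: Any = None, event_id: Optional[str] = None
    ) -> str:
        event_id = event_id or str(uuid.uuid4())  # the manager, not the handler, picks the id
        for handler in self.handlers:
            handler.on_event_start(event_type, payload, event_id=event_id)  # result ignored
        return event_id  # only the code that started the event receives this value


handler = RecordingHandler()
manager = ToyCallbackManager([handler])
returned_id = manager.on_event_start("query", {"query_str": "hello"})
print(returned_id == handler.started_ids[-1])  # True
```

Your query engine would therefore read the id from the handler (here, handler.started_ids), which raises the shared-state question discussed further down the thread.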
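Alternatively, you can write a custom event handler with the instrumentation module. Its events carry their own id_ and the span_id of the span they were emitted in; these ids are separate from the callback event_id. Here is an example of a custom event handler: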

from typing import List

from llama_index.core.instrumentation.events import BaseEvent
from llama_index.core.instrumentation.event_handlers import BaseEventHandler

class ExampleEventHandler(BaseEventHandler):
    """Example event handler."""

    events: List[BaseEvent] = []

    @classmethod
    def class_name(cls) -> str:
        """Class name."""
        return "ExampleEventHandler"

    def handle(self, event: BaseEvent) -> None:
        """Logic for handling event."""
        print("-----------------------")
        print(event.id_)
        print(event.timestamp)
        print(event.span_id)
        print(f"Event type: {event.class_name()}")
        self.events.append(event)
        print("-----------------------")

Then attach this event handler to the dispatcher to observe events and spans as they arrive:

from llama_index.core.instrumentation import get_dispatcher

# root dispatcher
root_dispatcher = get_dispatcher()

# register event handler
event_handler = ExampleEventHandler()
root_dispatcher.add_event_handler(event_handler)

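Together, these let you observe event ids (callbacks) and span ids (instrumentation) as requests flow through a hierarchical agent setup.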

  • To continue talking to Dosu, mention @dosu.

dzhpxtsq2#

@dosu So, if the event_id is returned at the end, how do I retrieve it in my custom_query method to get the event_id for the current request? Please provide code for that.

zpf6vheq3#

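To get the event_id of the current request, you can register a custom callback handler with a CallbackManager and have it record event ids as events start. Here is an example: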

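  • Define a custom callback handler:
  • Create a custom callback handler that records the event_id when each event starts.
from collections import defaultdict
from typing import Any, Dict, List, Optional

from llama_index.core.callbacks.base_handler import BaseCallbackHandler
from llama_index.core.callbacks.schema import CBEvent, CBEventType


class CustomCallbackHandler(BaseCallbackHandler):
    def __init__(self) -> None:
        # BaseCallbackHandler requires the lists of event types to ignore
        super().__init__(event_starts_to_ignore=[], event_ends_to_ignore=[])
        self._event_pairs_by_id: Dict[str, List[CBEvent]] = defaultdict(list)

    def on_event_start(
        self,
        event_type: CBEventType,
        payload: Any = None,
        event_id: str = "",
        parent_id: str = "",
        **kwargs: Any,
    ) -> str:
        event = CBEvent(event_type, payload=payload, id_=event_id)
        self._event_pairs_by_id[event.id_].append(event)
        return event_id

    # The remaining abstract methods must be implemented, even as no-ops
    def on_event_end(
        self, event_type: CBEventType, payload: Any = None, event_id: str = "", **kwargs: Any
    ) -> None:
        pass

    def start_trace(self, trace_id: Optional[str] = None) -> None:
        pass

    def end_trace(
        self,
        trace_id: Optional[str] = None,
        trace_map: Optional[Dict[str, List[str]]] = None,
    ) -> None:
        pass

    def get_current_event_id(self) -> str:
        # Treats the most recently started event as "current"; this is shared
        # state, so it is only reliable while one request runs at a time
        if self._event_pairs_by_id:
            return list(self._event_pairs_by_id.keys())[-1]
        return ""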
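  • Register the custom callback handler with a CallbackManager:
  • Callback handlers are attached through a CallbackManager; the instrumentation dispatcher only accepts BaseEventHandler subclasses.
from llama_index.core import Settings
from llama_index.core.callbacks import CallbackManager

# Components created after this line use Settings.callback_manager by default
custom_callback_handler = CustomCallbackHandler()
Settings.callback_manager = CallbackManager([custom_callback_handler])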
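  • Retrieve the event_id in your custom_query method:
  • Use the custom callback handler to read the current event_id.
from llama_index.core import PromptTemplate
from llama_index.core.query_engine import CustomQueryEngine
from llama_index.core.response_synthesizers import BaseSynthesizer
from llama_index.core.retrievers import BaseRetriever
from llama_index.llms.openai import OpenAI


class RAGStringQueryEngine(CustomQueryEngine):
    """RAG String Query Engine."""

    retriever: BaseRetriever
    response_synthesizer: BaseSynthesizer
    llm: OpenAI
    qa_prompt: PromptTemplate

    def custom_query(self, query_str: str):
        # Retrieve the current event_id recorded by the handler
        current_event_id = custom_callback_handler.get_current_event_id()
        print(f"Current Event ID: {current_event_id}")

        nodes = self.retriever.retrieve(query_str)
        context_str = "\n\n".join([n.node.get_content() for n in nodes])
        response = self.llm.complete(
            self.qa_prompt.format(context_str=context_str, query_str=query_str)
        )

        return str(response)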

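In this setup, CustomCallbackHandler records each event_id as its event starts, and get_current_event_id() returns the most recently started one, which you can read inside custom_query. Because a single handler instance is shared by every caller, this is only reliable while one request is processed at a time.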
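To see why the "last dictionary key" rule breaks once requests overlap, here is a stdlib-only simulation (no llama_index needed). LastKeyTracker is a hypothetical stand-in that copies only the bookkeeping of the handler's get_current_event_id; the event ids are made up.

```python
from collections import defaultdict
from typing import Dict, List


class LastKeyTracker:
    """Stdlib stand-in: the newest dictionary key is treated as the 'current' event."""

    def __init__(self) -> None:
        self._event_pairs_by_id: Dict[str, List[str]] = defaultdict(list)

    def on_event_start(self, event_id: str) -> None:
        self._event_pairs_by_id[event_id].append("start")

    def get_current_event_id(self) -> str:
        if self._event_pairs_by_id:
            return list(self._event_pairs_by_id.keys())[-1]
        return ""


tracker = LastKeyTracker()  # one shared instance, like a module-level handler
tracker.on_event_start("query-A")  # request A opens its query event
tracker.on_event_start("query-B")  # request B starts before A reaches custom_query
seen_by_request_a = tracker.get_current_event_id()
print(seen_by_request_a)  # query-B: request A is handed request B's id
```

Nested events inside a single request (for example an LLM call that starts after the query) move the "last key" in the same way, and since nothing removes finished events, the dictionary keeps growing for the life of the process.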

mbjcgjjk4#

When we use the agent behind a web API, multiple requests run in parallel. Will this approach still work?

gg0vcinb5#

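To make event_id retrieval thread-safe and tied to the correct request in a web API that handles many requests in parallel, you can use context-local storage. In Python this is provided by the contextvars module. Here is an example: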

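  1. Define a custom callback handler that uses context-local storage:
from collections import defaultdict
from contextvars import ContextVar
from typing import Any, Dict, List, Optional

from llama_index.core.callbacks.base_handler import BaseCallbackHandler
from llama_index.core.callbacks.schema import CBEvent, CBEventType

# Context-local storage: each thread / asyncio task sees its own value
event_id_var: ContextVar[str] = ContextVar("event_id", default="")


class CustomCallbackHandler(BaseCallbackHandler):
    def __init__(self) -> None:
        super().__init__(event_starts_to_ignore=[], event_ends_to_ignore=[])
        self._event_pairs_by_id: Dict[str, List[CBEvent]] = defaultdict(list)

    def on_event_start(
        self,
        event_type: CBEventType,
        payload: Any = None,
        event_id: str = "",
        parent_id: str = "",
        **kwargs: Any,
    ) -> str:
        event = CBEvent(event_type, payload=payload, id_=event_id)
        self._event_pairs_by_id[event.id_].append(event)
        event_id_var.set(event_id)  # Set the event_id in context-local storage
        return event_id

    # The remaining abstract methods must be implemented, even as no-ops
    def on_event_end(
        self, event_type: CBEventType, payload: Any = None, event_id: str = "", **kwargs: Any
    ) -> None:
        pass

    def start_trace(self, trace_id: Optional[str] = None) -> None:
        pass

    def end_trace(
        self,
        trace_id: Optional[str] = None,
        trace_map: Optional[Dict[str, List[str]]] = None,
    ) -> None:
        pass

    def get_current_event_id(self) -> str:
        return event_id_var.get()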
  2. Register the custom callback handler with a CallbackManager:
from llama_index.core import Settings
from llama_index.core.callbacks import CallbackManager

# Components created after this line use Settings.callback_manager by default
custom_callback_handler = CustomCallbackHandler()
Settings.callback_manager = CallbackManager([custom_callback_handler])
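  3. Retrieve the event_id in your custom_query method:
from llama_index.core import PromptTemplate
from llama_index.core.query_engine import CustomQueryEngine
from llama_index.core.response_synthesizers import BaseSynthesizer
from llama_index.core.retrievers import BaseRetriever
from llama_index.llms.openai import OpenAI


class RAGStringQueryEngine(CustomQueryEngine):
    """RAG String Query Engine."""

    retriever: BaseRetriever
    response_synthesizer: BaseSynthesizer
    llm: OpenAI
    qa_prompt: PromptTemplate

    def custom_query(self, query_str: str):
        # Retrieve the event_id stored in this request's context
        current_event_id = custom_callback_handler.get_current_event_id()
        print(f"Current Event ID: {current_event_id}")

        nodes = self.retriever.retrieve(query_str)
        context_str = "\n\n".join([n.node.get_content() for n in nodes])
        response = self.llm.complete(
            self.qa_prompt.format(context_str=context_str, query_str=query_str)
        )

        return str(response)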

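With ContextVar, each request's event_id is isolated: every thread and every asyncio task sees only the value set in its own context, so retrieval keeps working when multiple requests are processed in parallel.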
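A quick stdlib-only check of that isolation, with a hypothetical ContextTracker standing in for the handler's ContextVar logic (no llama_index needed). Two threads forced to overlap with a threading.Barrier each read back their own id, and so do two asyncio tasks, because each task runs in a copy of the context it was created in.

```python
import asyncio
import threading
from contextvars import ContextVar
from typing import Dict, List

event_id_var: ContextVar[str] = ContextVar("event_id", default="")


class ContextTracker:
    """Stdlib stand-in for the ContextVar-based handler."""

    def on_event_start(self, event_id: str) -> None:
        event_id_var.set(event_id)  # only visible in the current context

    def get_current_event_id(self) -> str:
        return event_id_var.get()


tracker = ContextTracker()  # one shared instance, like a module-level handler

# 1) Two threads whose "requests" are forced to overlap
results: Dict[str, str] = {}
both_started = threading.Barrier(2)


def handle_request(request_id: str) -> None:
    tracker.on_event_start(f"query-{request_id}")
    both_started.wait()  # by now the other thread has set its own value too
    results[request_id] = tracker.get_current_event_id()


threads = [threading.Thread(target=handle_request, args=(r,)) for r in ("A", "B")]
for t in threads:
    t.start()
for t in threads:
    t.join()


# 2) Two asyncio tasks; gather() wraps each coroutine in a Task with a copied context
async def handle_async(request_id: str) -> str:
    tracker.on_event_start(f"async-{request_id}")
    await asyncio.sleep(0)  # yield so the other task runs and sets its own value
    return tracker.get_current_event_id()


async def main() -> List[str]:
    return await asyncio.gather(handle_async("A"), handle_async("B"))


async_results = asyncio.run(main())
print(results, async_results)
```

One caveat: the variable holds the most recently started event in that context and is not cleared in on_event_end, so read it at the top of custom_query, before the method starts nested events of its own (retrieval, LLM calls).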

  • Source: 1 (https://github.com/run-llama/llama_index/blob/main/docs/docs/examples/discover_llamaindex/document_management/discord_dumps/help_channel_dump_06_02_23.json)

vaqhlq816#

How can I get the event_id of the top-level parent agent inside the child agent's custom query tool method?

dfty9e197#

To get the top-level parent agent's event_id inside the child agent's custom_query method, you can use a ContextVar that tracks the event hierarchy and stores the root event ID. Here's how:

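  • Define a custom callback handler with context-local storage:
  • Track the root event ID as events start.
from collections import defaultdict
from contextvars import ContextVar
from typing import Any, Dict, List, Optional

from llama_index.core.callbacks.base_handler import BaseCallbackHandler
from llama_index.core.callbacks.schema import BASE_TRACE_EVENT, CBEvent, CBEventType

# Create context-local storage objects
event_id_var: ContextVar[str] = ContextVar("event_id", default="")
root_event_id_var: ContextVar[str] = ContextVar("root_event_id", default="")


class CustomCallbackHandler(BaseCallbackHandler):
    def __init__(self) -> None:
        super().__init__(event_starts_to_ignore=[], event_ends_to_ignore=[])
        self._event_pairs_by_id: Dict[str, List[CBEvent]] = defaultdict(list)

    def on_event_start(
        self,
        event_type: CBEventType,
        payload: Any = None,
        event_id: str = "",
        parent_id: str = "",
        **kwargs: Any,
    ) -> str:
        event = CBEvent(event_type, payload=payload, id_=event_id)
        self._event_pairs_by_id[event.id_].append(event)
        event_id_var.set(event_id)  # Set the event_id in context-local storage

        # A top-level event has the trace root as its parent. Overwriting the
        # root on every top-level event keeps a reused thread/context from
        # returning the root of an earlier request.
        if parent_id in ("", BASE_TRACE_EVENT):
            root_event_id_var.set(event_id)

        return event_id

    # The remaining abstract methods must be implemented, even as no-ops
    def on_event_end(
        self, event_type: CBEventType, payload: Any = None, event_id: str = "", **kwargs: Any
    ) -> None:
        pass

    def start_trace(self, trace_id: Optional[str] = None) -> None:
        pass

    def end_trace(
        self,
        trace_id: Optional[str] = None,
        trace_map: Optional[Dict[str, List[str]]] = None,
    ) -> None:
        pass

    def get_current_event_id(self) -> str:
        return event_id_var.get()

    def get_root_event_id(self) -> str:
        return root_event_id_var.get()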
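  • Register the custom callback handler with a CallbackManager:
  • Callback handlers are attached through a CallbackManager; the instrumentation dispatcher only accepts BaseEventHandler subclasses.
from llama_index.core import Settings
from llama_index.core.callbacks import CallbackManager

# Components created after this line use Settings.callback_manager by default
custom_callback_handler = CustomCallbackHandler()
Settings.callback_manager = CallbackManager([custom_callback_handler])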
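  • Retrieve the root event_id in your custom_query method:
  • Use the custom callback handler to read the root event_id.
from llama_index.core import PromptTemplate
from llama_index.core.query_engine import CustomQueryEngine
from llama_index.core.response_synthesizers import BaseSynthesizer
from llama_index.core.retrievers import BaseRetriever
from llama_index.llms.openai import OpenAI


class RAGStringQueryEngine(CustomQueryEngine):
    """RAG String Query Engine."""

    retriever: BaseRetriever
    response_synthesizer: BaseSynthesizer
    llm: OpenAI
    qa_prompt: PromptTemplate

    def custom_query(self, query_str: str):
        # Retrieve the root event_id for this request's context
        root_event_id = custom_callback_handler.get_root_event_id()
        print(f"Root Event ID: {root_event_id}")

        nodes = self.retriever.retrieve(query_str)
        context_str = "\n\n".join([n.node.get_content() for n in nodes])
        response = self.llm.complete(
            self.qa_prompt.format(context_str=context_str, query_str=query_str)
        )

        return str(response)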

By storing the root event ID in a ContextVar, your custom_query method can read the root event_id reliably, even when handling hierarchical agents and parallel requests (sources: 1 https://github.com/run-llama/llama_index/blob/main/llama-index-core/llama_index/core/callbacks/base.py, 2 https://github.com/run-llama/llama_index/blob/main/llama-index-legacy/llama_index/legacy/callbacks/base.py).
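The parent_id argument is what makes a root-tracking rule possible. Below is a stdlib-only sketch of such a rule, assuming (as I understand llama_index's CallbackManager) that top-level events receive the parent id "root" (BASE_TRACE_EVENT in llama_index.core.callbacks.schema) and nested events receive the id of the enclosing event. RootTracker and the event ids are hypothetical; the sketch replays a parent agent calling a child agent through a tool, then a second request handled later in the same context.

```python
from contextvars import ContextVar
from typing import List

# Assumption: mirrors llama_index.core.callbacks.schema.BASE_TRACE_EVENT,
# the parent id that top-level events receive from the CallbackManager
BASE_TRACE_EVENT = "root"

root_event_id_var: ContextVar[str] = ContextVar("root_event_id", default="")


class RootTracker:
    """Stdlib stand-in for the root-tracking part of a callback handler."""

    def on_event_start(self, event_id: str, parent_id: str) -> None:
        # An event without a real parent starts a new tree, so it becomes the root
        if parent_id in ("", BASE_TRACE_EVENT):
            root_event_id_var.set(event_id)

    def get_root_event_id(self) -> str:
        return root_event_id_var.get()


tracker = RootTracker()
seen: List[str] = []

# Request 1: parent agent step -> tool call -> child agent's query
tracker.on_event_start("agent-step-1", BASE_TRACE_EVENT)
tracker.on_event_start("function-call-1", "agent-step-1")
tracker.on_event_start("child-query-1", "function-call-1")
seen.append(tracker.get_root_event_id())  # what the child's custom_query would see

# Request 2, handled later in the same (reused) context, e.g. the same worker thread
tracker.on_event_start("agent-step-2", BASE_TRACE_EVENT)
tracker.on_event_start("child-query-2", "agent-step-2")
seen.append(tracker.get_root_event_id())

print(seen)  # ['agent-step-1', 'agent-step-2']
```

By contrast, a rule that sets the root only while the ContextVar is still empty would report agent-step-1 for the second request as well whenever the context is reused, as it can be for a plain thread-pool worker that serves several requests.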
