from typing import Any, List, Optional, Tuple

# Import paths follow recent llama-index-core layouts; adjust for your version.
from llama_index.core.callbacks import CallbackManager
from llama_index.core.chat_engine.context import DEFAULT_CONTEXT_TEMPLATE
from llama_index.core.chat_engine.types import BaseChatEngine
from llama_index.core.llms import LLM, ChatMessage
from llama_index.core.memory import BaseMemory
from llama_index.core.postprocessor.types import BaseNodePostprocessor
from llama_index.core.retrievers import BaseRetriever
from llama_index.core.schema import MetadataMode, NodeWithScore, QueryBundle


class ContextChatEngine(BaseChatEngine):
    """Context Chat Engine.

    Uses a retriever to fetch context, sets that context in the system
    prompt, and then uses an LLM to generate a response, for a fluid chat
    experience.
    """

    def __init__(
        self,
        retriever: BaseRetriever,
        llm: LLM,
        memory: BaseMemory,
        prefix_messages: List[ChatMessage],
        node_postprocessors: Optional[List[BaseNodePostprocessor]] = None,
        context_template: Optional[str] = None,
        callback_manager: Optional[CallbackManager] = None,
        filters: Optional[List[Any]] = None,  # added: metadata filters
    ) -> None:
        self._retriever = retriever
        self._llm = llm
        self._memory = memory
        self._prefix_messages = prefix_messages
        self._node_postprocessors = node_postprocessors or []
        self._context_template = context_template or DEFAULT_CONTEXT_TEMPLATE
        self._filters = filters or []  # store filters for use at retrieval time
        self.callback_manager = callback_manager or CallbackManager([])
        for node_postprocessor in self._node_postprocessors:
            node_postprocessor.callback_manager = self.callback_manager

    def _generate_context(self, message: str) -> Tuple[str, List[NodeWithScore]]:
        """Generate context information from a message."""
        # NOTE: this assumes the retriever's retrieve() accepts a `filters`
        # kwarg. The stock BaseRetriever.retrieve() does not, so either use a
        # retriever subclass that supports it, or set the filters on the
        # retriever at construction time instead.
        nodes = self._retriever.retrieve(message, filters=self._filters)
        for postprocessor in self._node_postprocessors:
            nodes = postprocessor.postprocess_nodes(
                nodes, query_bundle=QueryBundle(message)
            )
        context_str = "\n\n".join(
            [n.node.get_content(metadata_mode=MetadataMode.LLM).strip() for n in nodes]
        )
        return self._context_template.format(context_str=context_str), nodes

    async def _agenerate_context(self, message: str) -> Tuple[str, List[NodeWithScore]]:
        """Async version: generate context information from a message."""
        nodes = await self._retriever.aretrieve(message, filters=self._filters)
        for postprocessor in self._node_postprocessors:
            nodes = postprocessor.postprocess_nodes(
                nodes, query_bundle=QueryBundle(message)
            )
        context_str = "\n\n".join(
            [n.node.get_content(metadata_mode=MetadataMode.LLM).strip() for n in nodes]
        )
        return self._context_template.format(context_str=context_str), nodes
5 answers

ar5n3qh51#
Yes, there is a way to add metadata filtering to a context chat engine based on user roles. The ContextChatEngine class accepts BaseNodePostprocessor instances, which can be customized to filter nodes by metadata and user role. Below is an example of a custom BaseNodePostprocessor that filters nodes by user role. In it, RoleBasedNodePostprocessor filters nodes based on a user_role attribute; you can customize the _filter_node method to implement whatever filtering logic your metadata and user roles require (see https://github.com/run-llama/llama_index/blob/main/llama-index-core/llama_index/core/chat_engine/context.py).
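The example code referenced above did not survive the page scrape; what follows is a minimal sketch of what such a postprocessor could look like. The allowed_roles metadata key is an assumption for illustration; adapt _filter_node to your own metadata schema.

    from typing import List, Optional

    from llama_index.core.postprocessor.types import BaseNodePostprocessor
    from llama_index.core.schema import NodeWithScore, QueryBundle


    class RoleBasedNodePostprocessor(BaseNodePostprocessor):
        """Keep only nodes whose metadata permits the given user role."""

        user_role: str  # role of the current user, e.g. "admin" or "viewer"

        def _filter_node(self, node_with_score: NodeWithScore) -> bool:
            # Hypothetical convention: each node lists the roles allowed to
            # see it under the "allowed_roles" metadata key.
            allowed = node_with_score.node.metadata.get("allowed_roles", [])
            return self.user_role in allowed

        def _postprocess_nodes(
            self,
            nodes: List[NodeWithScore],
            query_bundle: Optional[QueryBundle] = None,
        ) -> List[NodeWithScore]:
            return [n for n in nodes if self._filter_node(n)]

It can then be attached via node_postprocessors=[RoleBasedNodePostprocessor(user_role="admin")] when constructing the chat engine.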
c6ubokkw2#

Same way you filter a query engine:

    index.as_chat_engine(..., filters=filters)
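For instance, with an already-built VectorStoreIndex (the `index` variable here is a placeholder), the stock MetadataFilters/ExactMatchFilter classes can be passed straight through; a sketch:

    from llama_index.core.vector_stores import ExactMatchFilter, MetadataFilters

    filters = MetadataFilters(
        filters=[ExactMatchFilter(key="user_role", value="admin")]
    )

    # Extra kwargs to as_chat_engine are forwarded to as_retriever, so the
    # filters end up on the retriever that the chat engine wraps.
    chat_engine = index.as_chat_engine(chat_mode="context", filters=filters)
    response = chat_engine.chat("What documents can an admin see?")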
bqujaahr3#
The filters are just passed down to the underlying retriever.
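That is, the pass-through above is roughly equivalent to building the retriever yourself (a sketch, with `index` and `filters` as in the previous example):

    retriever = index.as_retriever(filters=filters)  # filters live on the retriever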
gzjq41n44#
In the code above, can I add the filters directly as a parameter?
g2ieeal75#
Yes, you can add filters as a parameter directly in the initialization code for ContextChatEngine. You need to modify the __init__ method to accept a filters argument and update the related methods to use it, as shown in the code at the top of this thread:

1. Add the filters parameter to the __init__ method.
2. Update the _generate_context and _agenerate_context methods to use the filters parameter.

This modification lets you pass filters directly when initializing ContextChatEngine and ensures they are applied while the context is being generated.
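A hypothetical construction of the modified class follows; the retriever and llm objects and the filter shape are placeholders, and the excerpt above omits the unchanged chat/stream_chat methods of the full class.

    from llama_index.core.memory import ChatMemoryBuffer

    chat_engine = ContextChatEngine(
        retriever=retriever,  # any retriever whose retrieve() accepts `filters`
        llm=llm,              # e.g. an OpenAI LLM instance
        memory=ChatMemoryBuffer.from_defaults(),
        prefix_messages=[],
        filters=[{"key": "user_role", "value": "admin"}],  # shape depends on your retriever
    )
    # The filters are now applied whenever context is generated:
    context_prompt, nodes = chat_engine._generate_context("What can an admin see?")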