llama_index [Bug]: In some cases, creating a chat_engine or query_engine from a PropertyGraphIndex throws an error, breaking the index.as_chat_engine() and index.as_query_engine() methods

xt0899hw · posted 2 months ago · 2 answers

Bug Description

When creating a chat_engine or query_engine from a property graph index, passing certain parameters causes an error to be thrown. One way to fix this is to add a **kwargs parameter to the underlying VectorStoreQuery (llama_index/core/vector_stores/types.py) so that any extraneous parameters used elsewhere and forwarded to it are ignored.
Current implementation:

@dataclass
class VectorStoreQuery:
    """Vector store query."""

    query_embedding: Optional[List[float]] = None
    similarity_top_k: int = 1
    doc_ids: Optional[List[str]] = None
    node_ids: Optional[List[str]] = None
    query_str: Optional[str] = None
    output_fields: Optional[List[str]] = None
    embedding_field: Optional[str] = None

    mode: VectorStoreQueryMode = VectorStoreQueryMode.DEFAULT

    # NOTE: only for hybrid search (0 for bm25, 1 for vector search)
    alpha: Optional[float] = None

    # metadata filters
    filters: Optional[MetadataFilters] = None

    # only for mmr
    mmr_threshold: Optional[float] = None

    # NOTE: currently only used by postgres hybrid search
    sparse_top_k: Optional[int] = None

    # NOTE: return top k results from hybrid search. similarity_top_k is used for dense search top k
    hybrid_top_k: Optional[int] = None

Fixed implementation:

@dataclass
class VectorStoreQuery:
    """Vector store query."""

    query_embedding: Optional[List[float]] = None
    similarity_top_k: int = 1
    doc_ids: Optional[List[str]] = None
    node_ids: Optional[List[str]] = None
    query_str: Optional[str] = None
    output_fields: Optional[List[str]] = None
    embedding_field: Optional[str] = None
    mode: VectorStoreQueryMode = VectorStoreQueryMode.DEFAULT
    alpha: Optional[float] = None
    filters: Optional[MetadataFilters] = None
    mmr_threshold: Optional[float] = None
    sparse_top_k: Optional[int] = None
    hybrid_top_k: Optional[int] = None

    def __init__(self, **kwargs):
        # Replace the generated __init__ so unknown keyword arguments are
        # silently ignored instead of raising TypeError; missing keys fall
        # back to the class-level field defaults.
        self.query_embedding = kwargs.get('query_embedding', self.query_embedding)
        self.similarity_top_k = kwargs.get('similarity_top_k', self.similarity_top_k)
        self.doc_ids = kwargs.get('doc_ids', self.doc_ids)
        self.node_ids = kwargs.get('node_ids', self.node_ids)
        self.query_str = kwargs.get('query_str', self.query_str)
        self.output_fields = kwargs.get('output_fields', self.output_fields)
        self.embedding_field = kwargs.get('embedding_field', self.embedding_field)
        self.mode = kwargs.get('mode', self.mode)
        self.alpha = kwargs.get('alpha', self.alpha)
        self.filters = kwargs.get('filters', self.filters)
        self.mmr_threshold = kwargs.get('mmr_threshold', self.mmr_threshold)
        self.sparse_top_k = kwargs.get('sparse_top_k', self.sparse_top_k)
        self.hybrid_top_k = kwargs.get('hybrid_top_k', self.hybrid_top_k)
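With this override in place, stray keyword arguments forwarded from the engine layer simply fall back to the field defaults. A minimal usage sketch (hypothetical values, assuming the modified class above):

# Unknown kwargs such as response_mode are now ignored rather than raising.
query = VectorStoreQuery(similarity_top_k=4, response_mode="compact")
assert query.similarity_top_k == 4

Note that hand-writing __init__ on a dataclass replaces the generated constructor, so positional construction and other dataclass niceties are forfeited; that trade-off is the price of swallowing unknown kwargs here.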

Version

0.10.52

Steps to Reproduce

chat_engine = property_graph_index.as_chat_engine(
    chat_mode=chat_mode,
    llm=llm,
    similarity_top_k=similarity_top_k,
    # All the parameters below throw an error because of VectorStoreQuery.__init__()
    # use_async=True,  # This is already passed in by PGRetriever.
    service_context=service_context,
    response_mode=response_mode,
    verbose=verbose,
    max_function_calls=max_agent_iterations,
    max_iterations=max_agent_iterations,
    node_postprocessors=node_postprocessors,
)

Relevant Logs/Traceback

response: AgentChatResponse = chat_engine.chat(input_text)
  File "/home/gich2009/venv/lib/python3.10/site-packages/llama_index/core/instrumentation/dispatcher.py", line 230, in wrapper
    result = func(*args, **kwargs)
  File "/home/gich2009/venv/lib/python3.10/site-packages/llama_index/core/callbacks/utils.py", line 41, in wrapper
    return func(self, *args, **kwargs)
  File "/home/gich2009/venv/lib/python3.10/site-packages/llama_index/core/chat_engine/condense_plus_context.py", line 292, in chat
    chat_messages, context_source, context_nodes = self._run_c3(
  File "/home/gich2009/venv/lib/python3.10/site-packages/llama_index/core/chat_engine/condense_plus_context.py", line 208, in _run_c3
    context_str, context_nodes = self._retrieve_context(condensed_question)
  File "/home/gich2009/venv/lib/python3.10/site-packages/llama_index/core/chat_engine/condense_plus_context.py", line 169, in _retrieve_context
    nodes = self._retriever.retrieve(message)
  File "/home/gich2009/venv/lib/python3.10/site-packages/llama_index/core/instrumentation/dispatcher.py", line 230, in wrapper
    result = func(*args, **kwargs)
  File "/home/gich2009/venv/lib/python3.10/site-packages/llama_index/core/base/base_retriever.py", line 243, in retrieve
    nodes = self._retrieve(query_bundle)
  File "/home/gich2009/venv/lib/python3.10/site-packages/llama_index/core/instrumentation/dispatcher.py", line 230, in wrapper
    result = func(*args, **kwargs)
  File "/home/gich2009/venv/lib/python3.10/site-packages/llama_index/core/indices/property_graph/retriever.py", line 52, in _retrieve
    return asyncio_run(self._aretrieve(query_bundle))
  File "/home/gich2009/venv/lib/python3.10/site-packages/llama_index/core/async_utils.py", line 47, in asyncio_run
    raise type(e)(f"Error running coroutine: {e!s}") from e
TypeError: Error running coroutine: VectorStoreQuery.__init__() got an unexpected keyword argument 'response_mode'
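For illustration, here is a stripped-down sketch of the failure mode (an assumed simplification, not the actual llama_index call chain): a dataclass constructor rejects any keyword it does not declare, so kwargs forwarded blindly down the stack blow up at construction time.

from dataclasses import dataclass

@dataclass
class Query:
    similarity_top_k: int = 1

def build_query(**kwargs):
    # Forwards everything it received, including keys Query never declared.
    return Query(**kwargs)

build_query(similarity_top_k=4, response_mode="compact")
# TypeError: __init__() got an unexpected keyword argument 'response_mode'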

u2nhd7ah #1

To fix the error when creating a chat_engine or query_engine from a PropertyGraphIndex, you need to add a **kwargs parameter to the VectorStoreQuery class in llama_index/core/vector_stores/types.py. Here is the modified VectorStoreQuery class with the **kwargs parameter:

@dataclass
class VectorStoreQuery:
    """Vector store query."""

    query_embedding: Optional[List[float]] = None
    similarity_top_k: int = 1
    doc_ids: Optional[List[str]] = None
    node_ids: Optional[List[str]] = None
    query_str: Optional[str] = None
    output_fields: Optional[List[str]] = None
    embedding_field: Optional[str] = None
    mode: VectorStoreQueryMode = VectorStoreQueryMode.DEFAULT
    alpha: Optional[float] = None
    filters: Optional[MetadataFilters] = None
    mmr_threshold: Optional[float] = None
    sparse_top_k: Optional[int] = None
    hybrid_top_k: Optional[int] = None

    # NOTE: a dataclass cannot declare `**kwargs: Any` as a field; to accept
    # and ignore extra keyword arguments, override __init__ as shown in the
    # bug description above, or filter unknown keys before construction.
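A minimal sketch of a dataclass-safe alternative (a hypothetical helper, not part of the llama_index API), which drops unknown keys at the call site instead of modifying the class itself:

from dataclasses import fields

def make_vector_store_query(**kwargs) -> VectorStoreQuery:
    # Keep only keys that are actual VectorStoreQuery fields.
    allowed = {f.name for f in fields(VectorStoreQuery)}
    return VectorStoreQuery(**{k: v for k, v in kwargs.items() if k in allowed})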

Also, make sure the `VectorContextRetriever` class in `llama_index/core/indices/property_graph/sub_retrievers/vector.py` passes these `**kwargs` through correctly when creating the `VectorStoreQuery`:

class VectorContextRetriever(BasePGRetriever):
    """A retriever that uses a vector store to retrieve nodes based on a query."""

    def __init__(
        self,
        graph_store: PropertyGraphStore,
        include_text: bool = True,
        embed_model: Optional[BaseEmbedding] = None,
        vector_store: Optional[VectorStore] = None,
        similarity_top_k: int = 4,
        path_depth: int = 1,
        filters: Optional[MetadataFilters] = None,
        **kwargs: Any,
    ) -> None:
        self._retriever_kwargs = kwargs or {}
        self._embed_model = embed_model or Settings.embed_model
        self._similarity_top_k = similarity_top_k
        self._vector_store = vector_store
        self._path_depth = path_depth
        self._filters = filters

        super().__init__(graph_store=graph_store, include_text=include_text, **kwargs)

    def _get_vector_store_query(self, query_bundle: QueryBundle) -> VectorStoreQuery:
        if query_bundle.embedding is None:
            query_bundle.embedding = self._embed_model.get_agg_embedding_from_queries(
                query_bundle.embedding_strs
            )
        return VectorStoreQuery(
            query_embedding=query_bundle.embedding,
            similarity_top_k=self._similarity_top_k,
            filters=self._filters,
            **self._retriever_kwargs,
        )

    async def _aget_vector_store_query(
        self, query_bundle: QueryBundle
    ) -> VectorStoreQuery:
        if query_bundle.embedding is None:
            query_bundle.embedding = (
                await self._embed_model.aget_agg_embedding_from_queries(
                    query_bundle.embedding_strs
                )
            )

        return VectorStoreQuery(
            query_embedding=query_bundle.embedding,
            similarity_top_k=self._similarity_top_k,
            filters=self._filters,
            **self._retriever_kwargs,
        )

This ensures that any extra keyword arguments are passed correctly through the `VectorContextRetriever` to the `VectorStoreQuery` [[1]](https://github.com/run-llama/llama_index/blob/main/llama-index-core/llama_index/core/vector_stores/types.py)[[2]](https://github.com/run-llama/llama_index/blob/main/llama-index-core/llama_index/core/indices/property_graph/sub_retrievers/vector.py).
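As a hypothetical usage sketch (assuming the modified classes above, with `graph_store` being a pre-existing PropertyGraphStore), an extra parameter such as `sparse_top_k` would now travel from the retriever into the query:

retriever = VectorContextRetriever(
    graph_store=graph_store,
    similarity_top_k=4,
    sparse_top_k=10,  # stored in _retriever_kwargs and forwarded into VectorStoreQuery
)

One caveat of this design: as written, `**kwargs` is also forwarded to `super().__init__`, so `BasePGRetriever` must tolerate the same extra keys for this to work end to end.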

*To continue the conversation with [Dosu](https://dosu.dev), mention @dosu.*
dgenwo3n #2

Another suggestion: PGRetriever could also check whether use_async is being passed to it as a kwarg.
Current implementation:
return PGRetriever(sub_retrievers, use_async=self._use_async, **kwargs)
Suggested implementation:
use_async = kwargs.pop("use_async", self._use_async)
return PGRetriever(sub_retrievers, use_async=use_async, **kwargs)
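The pop matters because passing the same keyword both explicitly and via **kwargs raises a TypeError. A minimal sketch of the pattern (hypothetical function, not llama_index code):

def build(use_async=False, **rest):
    return use_async

kwargs = {"use_async": True}
# build(use_async=False, **kwargs)  # TypeError: got multiple values for 'use_async'
use_async = kwargs.pop("use_async", False)
build(use_async=use_async, **kwargs)  # OK: the key was removed before forwarding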
