llama_index [Question]: How do I use a local model with the RAG CLI?

Posted by luaexgnf 2 months ago in Other

Question validation

  • I have searched the documentation and Discord for an answer.

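Question

I want to use a local model with the RAG CLI, but I get an error. The RAG CLI uses the OpenAI LLM, so how should I configure it to use a local model instead?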

kb5ga3dv #1

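Hello, @812406210! 👋 I'm here to help you with any bugs, questions, or contributions. If you need help, feel free to ask!
To make the RAG CLI use a local model instead of the default OpenAI LLM, you need to initialize RagCLI with local model parameters. Here is the relevant code for that configuration: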

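# Import paths below assume the llama-index 0.10.x package layout
# (an assumption; adjust them to the version you have installed).
from typing import Optional

import chromadb
from llama_index.cli.rag import RagCLI, default_ragcli_persist_dir
from llama_index.core import ServiceContext, VectorStoreIndex
from llama_index.core.ingestion import IngestionCache, IngestionPipeline
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.query_pipeline import InputComponent, QueryPipeline
from llama_index.core.response_synthesizers import CompactAndRefine
from llama_index.core.storage.docstore import SimpleDocumentStore
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.llms.ollama import Ollama
from llama_index.vector_stores.chroma import ChromaVectorStore


def init_local_rag_cli(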
    persist_dir: Optional[str] = None,
    verbose: bool = False,
    llm_model_name: str = "mixtral:8x7b",
    embed_model_name: str = "BAAI/bge-m3",
) -> RagCLI:
    """Init local RAG CLI."""
    docstore = SimpleDocumentStore()
    persist_dir = persist_dir or default_ragcli_persist_dir()
    chroma_client = chromadb.PersistentClient(path=persist_dir)
    chroma_collection = chroma_client.create_collection("default", get_or_create=True)
    vector_store = ChromaVectorStore(
        chroma_collection=chroma_collection, persist_dir=persist_dir
    )
    print("> Chroma collection initialized")
    llm = Ollama(model=llm_model_name, request_timeout=120.0)
    print("> LLM initialized")
    embed_model = HuggingFaceEmbedding(model_name=embed_model_name, pooling="mean")
    print("> Embedding model initialized")

    ingestion_pipeline = IngestionPipeline(
        transformations=[SentenceSplitter(), embed_model],
        vector_store=vector_store,
        docstore=docstore,
        cache=IngestionCache(),
    )

    service_context = ServiceContext.from_defaults(llm=llm, embed_model=embed_model)
    retriever = VectorStoreIndex.from_vector_store(
        ingestion_pipeline.vector_store, service_context=service_context
    ).as_retriever(similarity_top_k=8)
    response_synthesizer = CompactAndRefine(
        service_context=service_context, streaming=True, verbose=True
    )
    # define query pipeline
    query_pipeline = QueryPipeline(verbose=verbose)
    query_pipeline.add_modules(
        {
            "input": InputComponent(),
            "retriever": retriever,
            "summarizer": response_synthesizer,
        }
    )
    query_pipeline.add_link("input", "retriever")
    query_pipeline.add_link("retriever", "summarizer", dest_key="nodes")
    query_pipeline.add_link("input", "summarizer", dest_key="query_str")

    return RagCLI(
        ingestion_pipeline=ingestion_pipeline,
        llm=llm,  # optional
        persist_dir=persist_dir,
        query_pipeline=query_pipeline,
        verbose=False,
    )
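The three add_link calls are the least obvious part of the function, so here is a toy illustration of the routing they describe. It is a minimal sketch in plain Python, not llama_index's QueryPipeline: ToyPipeline, toy_input, toy_retriever, toy_summarizer and DOCS are all hypothetical names made up for this example, and it assumes modules are registered in an order in which they can be run.

```python
from collections import defaultdict
from typing import Any, Callable, Dict, List, Tuple

# Hypothetical in-memory "index" used only for this illustration.
DOCS = [
    "ollama serves local models",
    "chroma stores vectors",
    "openai is the default llm",
]


def toy_input(query: str) -> str:
    """Stand-in for InputComponent: passes the query through unchanged."""
    return query


def toy_retriever(input: str) -> List[str]:
    """Stand-in for the retriever: keyword overlap instead of vector search."""
    words = set(input.lower().split())
    return [doc for doc in DOCS if words & set(doc.split())]


def toy_summarizer(nodes: List[str], query_str: str) -> str:
    """Stand-in for the response synthesizer: needs BOTH nodes and the query."""
    return f"{query_str!r} answered from {len(nodes)} node(s)"


class ToyPipeline:
    """Toy DAG runner; NOT the real QueryPipeline implementation."""

    def __init__(self) -> None:
        self.modules: Dict[str, Callable[..., Any]] = {}
        self.links: Dict[str, List[Tuple[str, str]]] = defaultdict(list)

    def add_modules(self, modules: Dict[str, Callable[..., Any]]) -> None:
        self.modules.update(modules)

    def add_link(self, src: str, dest: str, dest_key: str = "input") -> None:
        # dest_key names the keyword argument of dest that receives src's output.
        self.links[src].append((dest, dest_key))

    def run(self, query: str) -> Any:
        kwargs: Dict[str, Dict[str, Any]] = defaultdict(dict)
        output: Any = None
        # Assumes insertion order is a valid execution order (true here).
        for position, (name, fn) in enumerate(self.modules.items()):
            output = fn(query) if position == 0 else fn(**kwargs[name])
            for dest, dest_key in self.links[name]:
                kwargs[dest][dest_key] = output
        return output


pipeline = ToyPipeline()
pipeline.add_modules(
    {"input": toy_input, "retriever": toy_retriever, "summarizer": toy_summarizer}
)
pipeline.add_link("input", "retriever")
pipeline.add_link("retriever", "summarizer", dest_key="nodes")
pipeline.add_link("input", "summarizer", dest_key="query_str")

answer = pipeline.run("local models")
print(answer)
```

The point of the sketch is that the summarizer has two inputs, so each incoming link must say which argument it fills; that is what dest_key="nodes" and dest_key="query_str" do in the real function. None of this changes init_local_rag_cli itself; it only illustrates its three add_link calls.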

To use a local model, call init_local_rag_cli with the local model names you want for llm_model_name and embed_model_name. For example:

rag_cli_instance = init_local_rag_cli(
    llm_model_name="mixtral:8x7b",
    embed_model_name="your_local_embedding_model_name"
)
rag_cli_instance.cli()

This configures the RAG CLI to use the specified local models instead of the default OpenAI LLM (sources: 1 https://github.com/run-llama/llama_index/blob/main/llama-index-packs/llama-index-packs-rag-cli-local/llama_index/packs/rag_cli_local/base.py, 2 https://github.com/run-llama/llama_index/blob/main/docs/docs/getting_started/starter_tools/rag_cli.md, 3 https://github.com/run-llama/llama_index/blob/main/docs/docs/understanding/agent/local_models.md).
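The question does not say which error occurred, so the following is only one thing worth ruling out: since the function above talks to Ollama, the model named in llm_model_name has to be available on the local Ollama server. This sketch checks that using only the standard library. It assumes Ollama is listening on its default address (http://localhost:11434) and uses its GET /api/tags endpoint; the helper names installed_models, fetch_tags and check_model are hypothetical.

```python
import json
import urllib.request
from typing import List

# Assumption: Ollama's default address; change it if your server runs elsewhere.
OLLAMA_URL = "http://localhost:11434"


def installed_models(tags_payload: str) -> List[str]:
    """Extract model names from the JSON body returned by GET /api/tags."""
    return [model["name"] for model in json.loads(tags_payload).get("models", [])]


def fetch_tags(base_url: str = OLLAMA_URL, timeout: float = 5.0) -> str:
    """Fetch the raw /api/tags response; raises URLError if the server is down."""
    with urllib.request.urlopen(f"{base_url}/api/tags", timeout=timeout) as resp:
        return resp.read().decode("utf-8")


def check_model(name: str, tags_payload: str) -> bool:
    """Return True if the exact model tag is listed in the payload."""
    return name in installed_models(tags_payload)


# Offline demonstration with a hand-written payload, so no server is required.
sample = '{"models": [{"name": "mixtral:8x7b"}, {"name": "llama3:latest"}]}'
has_mixtral = check_model("mixtral:8x7b", sample)
has_phi = check_model("phi3:mini", sample)
print(has_mixtral, has_phi)
```

Against a running server you would call check_model("mixtral:8x7b", fetch_tags()) before init_local_rag_cli; if it returns False, pulling the model with Ollama first is the likely fix.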

  • To continue talking to Dosu, mention @dosu.
