llama_index [Question]: How to fine-tune text-to-SQL for a Claude model?

enyaitl3 · posted 2 months ago in: Other

Issue validation

  • I have searched the documentation and Discord for an answer.

Question

I tried searching LlamaIndex for information on fine-tuning a Claude model for text-to-SQL, but I couldn't find any.
Could you please share some information or examples (e.g. prompts, implementations)? I am using claude-3-haiku-20240307 with LlamaIndex, and it often generates invalid SQL for me.

k2arahey #1

To fine-tune a Claude model for text-to-SQL with LlamaIndex, specifically the claude-3-haiku-20240307 version, follow these steps:

  1. Install the required packages:

```shell
%pip install llama-index-llms-gradient
%pip install llama-index-finetuning
```
  2. Import the necessary libraries:

```python
import os
from llama_index.llms.gradient import GradientBaseModelLLM
from llama_index.finetuning import GradientFinetuneEngine
```
  3. Set the environment variables:

```python
os.environ["GRADIENT_ACCESS_TOKEN"] = os.getenv("GRADIENT_API_KEY")
os.environ["GRADIENT_WORKSPACE_ID"] = ""
```
  4. Prepare the data:
    Load the dataset and save it in the required format.

```python
from datasets import load_dataset
from pathlib import Path
import json

def load_jsonl(data_dir):
    data_path = Path(data_dir).as_posix()
    data = load_dataset("json", data_files=data_path)
    return data

def save_jsonl(data_dicts, out_path):
    with open(out_path, "w") as fp:
        for data_dict in data_dicts:
            fp.write(json.dumps(data_dict) + "\n")

def load_data_sql(data_dir: str = "data_sql"):
    dataset = load_dataset("b-mc2/sql-create-context")
    dataset_splits = {"train": dataset["train"]}
    # Note: data_dir is used as the output *file* path here, not a directory;
    # with the single "train" split, the loop writes exactly one file.
    out_path = Path(data_dir)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    for key, ds in dataset_splits.items():
        with open(out_path, "w") as f:
            for item in ds:
                # Remap the dataset's question/context/answer fields
                newitem = {
                    "input": item["question"],
                    "context": item["context"],
                    "output": item["answer"],
                }
                f.write(json.dumps(newitem) + "\n")

load_data_sql(data_dir="data_sql")
```
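The record format written above can be checked with a self-contained JSONL round trip. The sample record below is hypothetical, mirroring the question/context/answer fields that load_data_sql remaps to input/context/output:

```python
import json
from pathlib import Path

# Hypothetical sample record, mirroring the input/context/output
# schema produced by load_data_sql above.
sample = {
    "input": "How many heads of the departments are older than 56?",
    "context": "CREATE TABLE head (age INTEGER)",
    "output": "SELECT COUNT(*) FROM head WHERE age > 56",
}

path = Path("data_sql_sample")
# One JSON object per line, as in save_jsonl / load_data_sql
with open(path, "w") as f:
    f.write(json.dumps(sample) + "\n")

# Read it back to confirm the JSONL round trip
with open(path) as f:
    loaded = json.loads(f.readline())

print(loaded["output"])
```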
  5. Split the data into train/validation sets:

```python
from math import ceil

def get_train_val_splits(data_dir: str = "data_sql", val_ratio: float = 0.1, seed: int = 42, shuffle: bool = True):
    data = load_jsonl(data_dir)
    num_samples = len(data["train"])
    val_set_size = ceil(val_ratio * num_samples)
    train_val = data["train"].train_test_split(test_size=val_set_size, shuffle=shuffle, seed=seed)
    return train_val["train"].shuffle(), train_val["test"].shuffle()

raw_train_data, raw_val_data = get_train_val_splits(data_dir="data_sql")
save_jsonl(raw_train_data, "train_data_raw.jsonl")
save_jsonl(raw_val_data, "val_data_raw.jsonl")
```
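The ceil-based sizing in get_train_val_splits guarantees a non-empty validation set even for tiny datasets; a minimal sketch of just that arithmetic:

```python
from math import ceil

def val_set_size(num_samples: int, val_ratio: float = 0.1) -> int:
    # Mirrors get_train_val_splits: ceil rounds up, so at least one
    # sample lands in validation whenever num_samples > 0.
    return ceil(val_ratio * num_samples)

print(val_set_size(17))  # 2
print(val_set_size(5))   # 1
```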
  6. Map the training data to prompts:

```python
text_to_sql_tmpl_str = """<s>### Instruction:\n{system_message}{user_message}\n\n### Response:\n{response}</s>"""
text_to_sql_inference_tmpl_str = """<s>### Instruction:\n{system_message}{user_message}\n\n### Response:\n"""

def _generate_prompt_sql(input, context, dialect="sqlite", output=""):
    system_message = """You are a powerful text-to-SQL model. Your job is to answer questions about a database. You are given a question and context regarding one or more tables. You must output the SQL query that answers the question."""
    user_message = f"""### Dialect:\n{dialect}\n### Input:\n{input}\n### Context:\n{context}\n### Response:\n"""
    if output:
        return text_to_sql_tmpl_str.format(system_message=system_message, user_message=user_message, response=output)
    else:
        return text_to_sql_inference_tmpl_str.format(system_message=system_message, user_message=user_message)

def generate_prompt(data_point):
    full_prompt = _generate_prompt_sql(data_point["input"], data_point["context"], dialect="sqlite", output=data_point["output"])
    return {"inputs": full_prompt}

# List comprehensions of dicts (the original snippet had the brace
# misplaced, producing a single dict instead of a list of records)
train_data = [{"inputs": d["inputs"]} for d in raw_train_data.map(generate_prompt)]
save_jsonl(train_data, "train_data.jsonl")
val_data = [{"inputs": d["inputs"]} for d in raw_val_data.map(generate_prompt)]
save_jsonl(val_data, "val_data.jsonl")
```
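To see what one rendered training example looks like, the training template can be exercised in isolation. The snippet below is self-contained, and the question and schema in it are made up for illustration:

```python
# Same training template as above, reproduced so this snippet stands alone
TRAIN_TMPL = "<s>### Instruction:\n{system_message}{user_message}\n\n### Response:\n{response}</s>"

system_message = (
    "You are a powerful text-to-SQL model. Your job is to answer "
    "questions about a database. You are given a question and context "
    "regarding one or more tables. You must output the SQL query that "
    "answers the question."
)
# Hypothetical datapoint for illustration
user_message = (
    "### Dialect:\nsqlite\n"
    "### Input:\nHow many singers do we have?\n"
    "### Context:\nCREATE TABLE singer (singer_id INTEGER)\n"
    "### Response:\n"
)

prompt = TRAIN_TMPL.format(
    system_message=system_message,
    user_message=user_message,
    response="SELECT COUNT(*) FROM singer",
)
print(prompt)
```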
  7. Fine-tune with Gradient:

```python
base_model_slug = "claude-3-haiku-20240307"
base_llm = GradientBaseModelLLM(base_model_slug=base_model_slug, max_tokens=300)

finetune_engine = GradientFinetuneEngine(
    base_model_slug=base_model_slug,
    name="text_to_sql",
    data_path="train_data.jsonl",
    verbose=True,
    max_steps=200,
    batch_size=4,
)

epochs = 1
for i in range(epochs):
    print(f"** EPOCH {i} **")
    finetune_engine.finetune()

ft_llm = finetune_engine.get_finetuned_model(max_tokens=300)
```
  8. Evaluate:
    Evaluate the fine-tuned model on the validation data and a toy dataset.

```python
def get_text2sql_completion(llm, raw_datapoint):
    text2sql_tmpl_str = _generate_prompt_sql(raw_datapoint["input"], raw_datapoint["context"], dialect="sqlite", output=None)
    response = llm.complete(text2sql_tmpl_str)
    return str(response)

test_datapoint = raw_val_data[2]
get_text2sql_completion(ft_llm, test_datapoint)
```
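Since the original complaint was invalid generated SQL, it can also help to sanity-check each completion by executing it against a scratch SQLite database built from the datapoint's context. is_valid_sql below is a hypothetical helper, not part of LlamaIndex:

```python
import sqlite3

def is_valid_sql(sql: str, schema: str) -> bool:
    """Return True if `sql` executes against an in-memory database
    created from `schema` (hypothetical helper, not part of LlamaIndex)."""
    conn = sqlite3.connect(":memory:")
    try:
        conn.executescript(schema)
        conn.execute(sql)
        return True
    except sqlite3.Error:
        return False
    finally:
        conn.close()

schema = "CREATE TABLE head (age INTEGER);"
print(is_valid_sql("SELECT COUNT(*) FROM head WHERE age > 56", schema))  # True
print(is_valid_sql("SELEC COUNT(*) FROM head", schema))                  # False
```

Running the check over the validation set gives a simple "fraction of syntactically valid completions" metric to compare the base and fine-tuned models.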

This process outlines how to fine-tune the claude-3-haiku-20240307 model for text-to-SQL tasks using LlamaIndex and Gradient. For more detailed guidance, see the LlamaIndex documentation on fine-tuning: 1 (https://github.com/run-llama/llama_index/blob/main/docs/docs/examples/finetuning/gradient/gradient_text2sql.ipynb), 2 (https://github.com/run-llama/llama_index/blob/main/docs/docs/use_cases/fine_tuning.md), 3 (https://github.com/run-llama/llama_index/blob/main/docs/docs/optimizing/fine-tuning/fine-tuning.md).


About Dosu

This response is meant to be helpful and save you time. It is not an exact solution, but a starting point for your own research.
Help me be more useful! Leave a 👍 if this helped, or a 👎 if it was irrelevant. To continue the conversation, mention @dosu.
