Modal Finetune SQL: 使用LlamaIndex微调Llama 2实现高效文本到SQL转换
在当今数据驱动的时代,能够快速准确地将自然语言查询转换为SQL语句是一项极其有价值的技能。然而,即使是像Llama 2这样强大的大语言模型,在这方面的表现也并非尽如人意。本文将为您详细介绍如何利用Modal和LlamaIndex对Llama 2进行微调,显著提升其文本到SQL的转换能力。
项目背景与意义
Llama 2作为开源大语言模型的佼佼者,在多个基准测试中展现出了接近甚至超越GPT-3.5的性能。这使得它成为构建复杂语言模型应用的理想选择。然而,在文本到SQL这一特定任务上,Llama 2 7B参数版本的表现却不尽如人意。
为了说明这一问题,我们可以看一个简单的例子。给定以下提示:
您是一个强大的文本到SQL模型。您的任务是回答有关数据库的问题。您会得到一个问题和一个或多个表的上下文信息。
您必须输出能回答该问题的SQL查询。
### 输入:
在1981年哪个队伍总体排名第148?
### 上下文:
CREATE TABLE table_name_8 (team VARCHAR, year VARCHAR, overall_pick VARCHAR)
### 回应:
Llama 2生成的输出是:
SELECT * FROM `table_name_8` WHERE '1980' = YEAR AND TEAM = "Boston Celtics" ORDER BY OVERALL_PICK DESC LIMIT 1;
而正确的输出应该是:
SELECT team FROM table_name_8 WHERE year = 1981 AND overall_pick = "148"
这个例子清楚地展示了Llama 2在生成格式正确且准确的SQL输出方面的不足。
正是在这样的背景下,微调(fine-tuning)技术显示出了其重要性。通过使用适当的文本到SQL数据集对Llama 2进行微调,我们可以显著提高其在这一特定任务上的表现。微调过程实质上是修改模型的权重,可以针对全部参数,部分参数,甚至仅针对额外参数(如LoRA技术所示)。
本教程旨在展示如何对Llama 2进行微调,并将其无缝集成到下游的LLM应用中。相比于仅关注"上下文学习"和"检索增强"的教程,本教程更进一步,涉及了模型本身的修改。虽然微调过程可能存在较高的学习曲线和计算需求,但本教程将尽可能简化这一过程,使您能够轻松上手。
技术栈概览
本教程使用的主要技术栈包括:
- b-mc2/sql-create-context: 来自Hugging Face的训练数据集
- OpenLLaMa
open_llama_7b_v2
: 作为基础模型 - PEFT: 用于高效微调
- Modal: 处理微调过程中的云计算和编排
- LlamaIndex: 用于对任意SQL数据库进行文本到SQL推理
特别感谢Anyscale的Llama 2教程,它为本项目提供了重要灵感。
所有相关材料都可以在我们的GitHub仓库中找到: https://github.com/run-llama/modal_finetune_sql。完整的教程可以在Jupyter笔记本指南中查看。
微调过程详解
步骤1: 加载训练数据
首先,我们使用Modal加载b-mc2/sql-create-context
数据集。这个任务相对简单,主要是将数据集加载并格式化为.jsonl
文件。
modal run src.load_data_sql --data-dir "data_sql"
在后台,这个任务的核心逻辑如下:
@stub.function(
retries=Retries(
max_retries=3,
initial_delay=5.0,
backoff_coefficient=2.0,
),
timeout=60 * 60 * 2,
network_file_systems={VOL_MOUNT_PATH.as_posix(): output_vol},
cloud="gcp",
)
def load_data_sql(data_dir: str = "data_sql"):
from datasets import load_dataset
dataset = load_dataset("b-mc2/sql-create-context")
dataset_splits = {"train": dataset["train"]}
out_path = get_data_path(data_dir)
out_path.parent.mkdir(parents=True, exist_ok=True)
for key, ds in dataset_splits.items():
with open(out_path, "w") as f:
for item in ds:
newitem = {
"input": item["question"],
"context": item["context"],
"output": item["answer"],
}
f.write(json.dumps(newitem) + "\n")
步骤2: 运行微调脚本
接下来,我们在解析后的数据集上运行微调脚本:
modal run src.finetune_sql --data-dir "data_sql" --model-dir "model_sql"
微调脚本执行以下关键步骤:
- 将数据集分割为训练集和验证集:
train_val = data["train"].train_test_split(test_size=val_set_size, shuffle=True, seed=42)
train_data = train_val["train"].shuffle().map(generate_and_tokenize_prompt)
val_data = train_val["test"].shuffle().map(generate_and_tokenize_prompt)
- 将每个分割格式化为(输入提示, 标签)的元组:
def generate_and_tokenize_prompt(data_point):
full_prompt = generate_prompt_sql(
data_point["input"],
data_point["context"],
data_point["output"],
)
tokenized_full_prompt = tokenize(full_prompt)
if not train_on_inputs:
raise NotImplementedError("not implemented yet")
return tokenized_full_prompt
输入提示的格式与本文开头给出的例子相同。
当微调脚本运行完成后,模型将被保存在指定的远程云目录中(由model_dir
参数指定,如果未指定则使用默认值)。
步骤3: 评估
现在,我们的模型已经完成微调,并可以从云端服务。我们可以使用sql-create-context
中的样本数据进行一些基本评估,比较微调后的模型与基线Llama 2模型的性能:
modal run src.eval_sql::main
结果显示,微调后的模型性能有了显著提升:
输入1: {'input': '哪个地区(年份)的Abigail排第7, Sophia排第1, Aaliyah排第5?', 'context': 'CREATE TABLE table_name_12 (region__year_ VARCHAR, no_5 VARCHAR, no_7 VARCHAR, no_1 VARCHAR)', 'output': 'SELECT region__year_ FROM table_name_12 WHERE no_7 = "abigail" AND no_1 = "sophia" AND no_5 = "aaliyah"'}
输出1 (微调后的模型): SELECT region__year_ FROM table_name_12 WHERE no_7 = "abigail" AND no_1 = "sophia" AND no_5 = "aaliyah"
输出1 (基础模型): SELECT * FROM table_name_12 WHERE region__year = '2018' AND no_5 = 'Abigail' AND no_7 = 'Sophia' AND no_1 = 'Aaliyah';
输入2: {'input': '列出54741的结果/比赛', 'context': 'CREATE TABLE table_21436373_11 (result_games VARCHAR, attendance VARCHAR)', 'output': 'SELECT result_games FROM table_21436373_11 WHERE attendance = 54741'}
输出2 (微调后的模型): SELECT result_games FROM table_21436373_11 WHERE attendance = "54741"
输出2 (基础模型): SELECT * FROM table_21436373_11 WHERE result_games = 'name' AND attendance > 0;
可以看到,基础模型产生了格式错误或不正确的SQL语句,而微调后的模型能够生成更接近预期输出的结果。
步骤4: 将微调后的模型与LlamaIndex集成
最后,我们可以在LlamaIndex中使用这个微调后的模型,对任意数据库进行文本到SQL的转换。
首先,我们定义一个测试用的SQL数据库:
db_file = "cities.db"
engine = create_engine(f"sqlite:///{db_file}")
metadata_obj = MetaData()
# 创建城市SQL表
table_name = "city_stats"
city_stats_table = Table(
table_name,
metadata_obj,
Column("city_name", String(16), primary_key=True),
Column("population", Integer),
Column("country", String(16), nullable=False),
)
metadata_obj.create_all(engine)
这将创建一个名为city_stats
的表,包含城市名称、人口和国家信息。我们将其存储在cities.db
文件中。
然后,我们可以使用Modal加载微调后的模型和这个数据库文件到LlamaIndex的NLSQLTableQueryEngine
中:
modal run src.inference_sql_llamaindex::main --query "哪个城市人口最多?" --sqlite-file-path "nbs/cities.db" --model-dir "model_sql" --use-finetuned-model True
我们会得到类似以下的响应:
SQL查询: SELECT MAX(population) FROM city_stats WHERE country = "United States"
响应: [(2679000,)]
结论
通过本教程,我们展示了如何使用Modal和LlamaIndex对Llama 2模型进行微调,以提高其文本到SQL的转换能力。这个过程虽然涉及多个步骤,但每个步骤都经过精心设计,使得即使对于初学者来说也相对容易上手。
微调后的模型不仅在评估阶段表现出色,还可以无缝集成到LlamaIndex中,用于实际的数据库查询任务。这为构建更高效、更准确的自然语言数据库接口开辟了新的可能性。
资源汇总
为方便读者进一步探索,我们再次列出本项目使用的主要资源:
我们希望这个教程能够帮助您更好地理解和应用大语言模型的微调技术,特别是在文本到SQL这样的特定任务上。随着技术的不断发展,我们期待看到更多创新的应用在这个领域涌现。
🚀 如果您对这个项目感兴趣,不妨亲自尝试一下,相信您会发现许多令人兴奋的可能性。祝您在探索AI和数据分析的道路上一切顺利!