Modal Finetune SQL: 使用LlamaIndex微调Llama 2实现高效文本到SQL转换

Ray

modal_finetune_sql

Modal Finetune SQL: 使用LlamaIndex微调Llama 2实现高效文本到SQL转换

在当今数据驱动的时代,能够快速准确地将自然语言查询转换为SQL语句是一项极其有价值的技能。然而,即使是像Llama 2这样强大的大语言模型,在这方面的表现也并非尽如人意。本文将为您详细介绍如何利用Modal和LlamaIndex对Llama 2进行微调,显著提升其文本到SQL的转换能力。

项目背景与意义

Llama 2作为开源大语言模型的佼佼者,在多个基准测试中展现出了接近甚至超越GPT-3.5的性能。这使得它成为构建复杂语言模型应用的理想选择。然而,在文本到SQL这一特定任务上,Llama 2 7B参数版本的表现却不尽如人意。

为了说明这一问题,我们可以看一个简单的例子。给定以下提示:

您是一个强大的文本到SQL模型。您的任务是回答有关数据库的问题。您会得到一个问题和一个或多个表的上下文信息。

您必须输出能回答该问题的SQL查询。

### 输入:
在1981年哪个队伍总体排名第148?

### 上下文:
CREATE TABLE table_name_8 (team VARCHAR, year VARCHAR, overall_pick VARCHAR)

### 回应:

Llama 2生成的输出是:

SELECT * FROM `table_name_8` WHERE '1980' = YEAR AND TEAM = "Boston Celtics" ORDER BY OVERALL_PICK DESC LIMIT 1;

而正确的输出应该是:

SELECT team FROM table_name_8 WHERE year = 1981 AND overall_pick = "148"

这个例子清楚地展示了Llama 2在生成格式正确且准确的SQL输出方面的不足。

正是在这样的背景下,微调(fine-tuning)技术显示出了其重要性。通过使用适当的文本到SQL数据集对Llama 2进行微调,我们可以显著提高其在这一特定任务上的表现。微调过程实质上是修改模型的权重,可以针对全部参数,部分参数,甚至仅针对额外参数(如LoRA技术所示)。

本教程旨在展示如何对Llama 2进行微调,并将其无缝集成到下游的LLM应用中。相比于仅关注"上下文学习"和"检索增强"的教程,本教程更进一步,涉及了模型本身的修改。虽然微调过程可能存在较高的学习曲线和计算需求,但本教程将尽可能简化这一过程,使您能够轻松上手。

技术栈概览

本教程使用的主要技术栈包括:

特别感谢Anyscale的Llama 2教程,它为本项目提供了重要灵感。

所有相关材料都可以在我们的GitHub仓库中找到: https://github.com/run-llama/modal_finetune_sql。完整的教程可以在Jupyter笔记本指南中查看。

微调过程详解

步骤1: 加载训练数据

首先,我们使用Modal加载b-mc2/sql-create-context数据集。这个任务相对简单,主要是将数据集加载并格式化为.jsonl文件。

modal run src.load_data_sql --data-dir "data_sql"

在后台,这个任务的核心逻辑如下:

@stub.function(
    retries=Retries(
        max_retries=3,
        initial_delay=5.0,
        backoff_coefficient=2.0,
    ),
    timeout=60 * 60 * 2,
    network_file_systems={VOL_MOUNT_PATH.as_posix(): output_vol},
    cloud="gcp",
)
def load_data_sql(data_dir: str = "data_sql"):
    from datasets import load_dataset

    dataset = load_dataset("b-mc2/sql-create-context")

    dataset_splits = {"train": dataset["train"]}
    out_path = get_data_path(data_dir)

    out_path.parent.mkdir(parents=True, exist_ok=True)

    for key, ds in dataset_splits.items():
        with open(out_path, "w") as f:
            for item in ds:
                newitem = {
                    "input": item["question"],
                    "context": item["context"],
                    "output": item["answer"],
                }
                f.write(json.dumps(newitem) + "\n")

步骤2: 运行微调脚本

接下来,我们在解析后的数据集上运行微调脚本:

modal run src.finetune_sql --data-dir "data_sql" --model-dir "model_sql"

微调脚本执行以下关键步骤:

  1. 将数据集分割为训练集和验证集:
train_val = data["train"].train_test_split(test_size=val_set_size, shuffle=True, seed=42)
train_data = train_val["train"].shuffle().map(generate_and_tokenize_prompt)
val_data = train_val["test"].shuffle().map(generate_and_tokenize_prompt)
  1. 将每个分割格式化为(输入提示, 标签)的元组:
def generate_and_tokenize_prompt(data_point):
  full_prompt = generate_prompt_sql(
      data_point["input"],
      data_point["context"],
      data_point["output"],
  )
  tokenized_full_prompt = tokenize(full_prompt)
  if not train_on_inputs:
      raise NotImplementedError("not implemented yet")
  return tokenized_full_prompt

输入提示的格式与本文开头给出的例子相同。

当微调脚本运行完成后,模型将被保存在指定的远程云目录中(由model_dir参数指定,如果未指定则使用默认值)。

步骤3: 评估

现在,我们的模型已经完成微调,并可以从云端服务。我们可以使用sql-create-context中的样本数据进行一些基本评估,比较微调后的模型与基线Llama 2模型的性能:

modal run src.eval_sql::main

结果显示,微调后的模型性能有了显著提升:

输入1: {'input': '哪个地区(年份)的Abigail排第7, Sophia排第1, Aaliyah排第5?', 'context': 'CREATE TABLE table_name_12 (region__year_ VARCHAR, no_5 VARCHAR, no_7 VARCHAR, no_1 VARCHAR)', 'output': 'SELECT region__year_ FROM table_name_12 WHERE no_7 = "abigail" AND no_1 = "sophia" AND no_5 = "aaliyah"'}
输出1 (微调后的模型): SELECT region__year_ FROM table_name_12 WHERE no_7 = "abigail" AND no_1 = "sophia" AND no_5 = "aaliyah"
输出1 (基础模型): SELECT * FROM table_name_12 WHERE region__year = '2018' AND no_5 = 'Abigail' AND no_7 = 'Sophia' AND no_1 = 'Aaliyah';

输入2: {'input': '列出54741的结果/比赛', 'context': 'CREATE TABLE table_21436373_11 (result_games VARCHAR, attendance VARCHAR)', 'output': 'SELECT result_games FROM table_21436373_11 WHERE attendance = 54741'}
输出2 (微调后的模型): SELECT result_games FROM table_21436373_11 WHERE attendance = "54741"
输出2 (基础模型): SELECT * FROM table_21436373_11 WHERE result_games = 'name' AND attendance > 0;

可以看到,基础模型产生了格式错误或不正确的SQL语句,而微调后的模型能够生成更接近预期输出的结果。

步骤4: 将微调后的模型与LlamaIndex集成

最后,我们可以在LlamaIndex中使用这个微调后的模型,对任意数据库进行文本到SQL的转换。

首先,我们定义一个测试用的SQL数据库:

db_file = "cities.db"
engine = create_engine(f"sqlite:///{db_file}")
metadata_obj = MetaData()
# 创建城市SQL表
table_name = "city_stats"
city_stats_table = Table(
    table_name,
    metadata_obj,
    Column("city_name", String(16), primary_key=True),
    Column("population", Integer),
    Column("country", String(16), nullable=False),
)
metadata_obj.create_all(engine)

这将创建一个名为city_stats的表,包含城市名称、人口和国家信息。我们将其存储在cities.db文件中。

然后,我们可以使用Modal加载微调后的模型和这个数据库文件到LlamaIndex的NLSQLTableQueryEngine中:

modal run src.inference_sql_llamaindex::main --query "哪个城市人口最多?" --sqlite-file-path "nbs/cities.db" --model-dir "model_sql" --use-finetuned-model True

我们会得到类似以下的响应:

SQL查询: SELECT MAX(population) FROM city_stats WHERE country = "United States"
响应: [(2679000,)]

结论

通过本教程,我们展示了如何使用Modal和LlamaIndex对Llama 2模型进行微调,以提高其文本到SQL的转换能力。这个过程虽然涉及多个步骤,但每个步骤都经过精心设计,使得即使对于初学者来说也相对容易上手。

微调后的模型不仅在评估阶段表现出色,还可以无缝集成到LlamaIndex中,用于实际的数据库查询任务。这为构建更高效、更准确的自然语言数据库接口开辟了新的可能性。

资源汇总

为方便读者进一步探索,我们再次列出本项目使用的主要资源:

我们希望这个教程能够帮助您更好地理解和应用大语言模型的微调技术,特别是在文本到SQL这样的特定任务上。随着技术的不断发展,我们期待看到更多创新的应用在这个领域涌现。

🚀 如果您对这个项目感兴趣,不妨亲自尝试一下,相信您会发现许多令人兴奋的可能性。祝您在探索AI和数据分析的道路上一切顺利!

avatar
0
0
0
相关项目
Project Cover

llm-chain

llm-chain是一组强大的Rust库,支持创建高级LLM应用,如聊天机器人和智能代理。平台支持云端和本地LLM,提供提示模板和多步骤链功能,以处理复杂任务。还支持向量存储集成,为模型提供长期记忆和专业知识。兼容ChatGPT、LLaMa和Alpaca模型,并通过llm.rs实现Rust语言的LLM支持,无需C++依赖。

Project Cover

llama.onnx

此项目提供LLaMa-7B和RWKV-400M的ONNX模型与独立演示,无需torch或transformers,适用于2GB内存设备。项目包括内存池支持、温度与topk logits调整,并提供导出混合精度和TVM转换的详细步骤,适用于嵌入式设备和分布式系统的大语言模型部署和推理。

Project Cover

modal_finetune_sql

此项目展示了在Text-to-SQL数据集上微调LLaMa 2 7B模型的过程。利用LlamaIndex、Modal和Hugging Face datasets等工具,项目提供了从数据加载到模型微调和推理的完整教程。开发者可以学习如何针对结构化数据库执行自然语言查询,并通过提供的模型权重下载选项,快速构建Text-to-SQL应用。项目涵盖了整个开发流程,为Text-to-SQL应用的实现提供了实用的参考。

Project Cover

Skywork-Reward-Gemma-2-27B

Skywork-Reward-Gemma-2-27B是基于gemma-2-27b-it架构开发的奖励模型。该模型仅使用80K高质量偏好对数据进行训练,在数学、编程和安全等多个领域的复杂场景偏好判断中表现优异。目前在RewardBench排行榜位居榜首,证明了利用相对小规模数据集和简单数据处理技术也能构建高性能奖励模型。

Project Cover

ruadapt_llama3_instruct_lep_saiga_kto_ablitirated

ruadapt_llama3_instruct_lep_saiga_kto_ablitirated是一个基于LLaMA 3和Learned Embedding Propagation (LEP)技术的大语言模型。它通过KTO和abliteration技术,在saiga_preferences数据集上训练,支持俄语和英语。模型运用先进的分词技术优化俄语适配,为自然语言处理提供新方案。这一创新模型特别适用于需要高质量俄语理解和生成的NLP任务,如机器翻译、文本分类和问答系统等。

最新项目
Project Cover

豆包MarsCode

豆包 MarsCode 是一款革命性的编程助手,通过AI技术提供代码补全、单测生成、代码解释和智能问答等功能,支持100+编程语言,与主流编辑器无缝集成,显著提升开发效率和代码质量。

Project Cover

AI写歌

Suno AI是一个革命性的AI音乐创作平台,能在短短30秒内帮助用户创作出一首完整的歌曲。无论是寻找创作灵感还是需要快速制作音乐,Suno AI都是音乐爱好者和专业人士的理想选择。

Project Cover

白日梦AI

白日梦AI提供专注于AI视频生成的多样化功能,包括文生视频、动态画面和形象生成等,帮助用户快速上手,创造专业级内容。

Project Cover

有言AI

有言平台提供一站式AIGC视频创作解决方案,通过智能技术简化视频制作流程。无论是企业宣传还是个人分享,有言都能帮助用户快速、轻松地制作出专业级别的视频内容。

Project Cover

Kimi

Kimi AI助手提供多语言对话支持,能够阅读和理解用户上传的文件内容,解析网页信息,并结合搜索结果为用户提供详尽的答案。无论是日常咨询还是专业问题,Kimi都能以友好、专业的方式提供帮助。

Project Cover

讯飞绘镜

讯飞绘镜是一个支持从创意到完整视频创作的智能平台,用户可以快速生成视频素材并创作独特的音乐视频和故事。平台提供多样化的主题和精选作品,帮助用户探索创意灵感。

Project Cover

讯飞文书

讯飞文书依托讯飞星火大模型,为文书写作者提供从素材筹备到稿件撰写及审稿的全程支持。通过录音智记和以稿写稿等功能,满足事务性工作的高频需求,帮助撰稿人节省精力,提高效率,优化工作与生活。

Project Cover

阿里绘蛙

绘蛙是阿里巴巴集团推出的革命性AI电商营销平台。利用尖端人工智能技术,为商家提供一键生成商品图和营销文案的服务,显著提升内容创作效率和营销效果。适用于淘宝、天猫等电商平台,让商品第一时间被种草。

Project Cover

AIWritePaper论文写作

AIWritePaper论文写作是一站式AI论文写作辅助工具,简化了选题、文献检索至论文撰写的整个过程。通过简单设定,平台可快速生成高质量论文大纲和全文,配合图表、参考文献等一应俱全,同时提供开题报告和答辩PPT等增值服务,保障数据安全,有效提升写作效率和论文质量。

投诉举报邮箱: service@vectorlightyear.com
@2024 懂AI·鲁ICP备2024100362号-6·鲁公网安备37021002001498号