DOC: 通过详细大纲控制提高长篇故事连贯性
2023年10月26日更新:请参阅 https://github.com/facebookresearch/doc-storygen-v2 获取支持使用较新聊天模型(如 LLaMA-2、ChatGPT)的代码版本。它遵循相同的高层结构,但并非在所有地方都完全相同(例如,为简化起见删除了一些部分,一些启发式检查也不再必要);我们重写的主要目标是让代码更易于使用和修改。
本仓库包含 DOC: 通过详细大纲控制提高长篇故事连贯性 (https://arxiv.org/abs/2212.10077, ACL 2023) 的代码,作者为 Kevin Yang、Dan Klein、Nanyun Peng 和 Yuandong Tian。在这个代码库中,我们提供了自动生成较长故事的说明(在我们的论文实验中平均 3500 多字)。根据人类评估者的判断,DOC 的故事在连贯性、相关性和趣味性方面都明显优于我们之前的系统 Re3 (https://github.com/yangkevin2/emnlp22-re3-story-generation)所写的故事。
安装 / 数据
(1) 安装 Python 3.8.15 和 PyTorch 1.13.1 (稍旧或稍新的版本可能也可以)。
(2) 通过 pip install -r requirements.txt
安装其余依赖。如果之后遇到与 huggingface_hub.snapshot_download
相关的崩溃,可能还需要运行 pip install -U sentence-transformers
。如果遇到 numpy 版本问题,请尝试 1.22.4 版本。
(3) 通过 pip install -e .
安装此仓库。
另外,在终端中运行 export OPENAI_API_KEY=$YOUR_API_KEY
,以便代码可以使用你的密钥调用 GPT3 API。
同时,运行 wget https://doc-story-generation-data.s3.amazonaws.com/doc_data.zip
并解压到此仓库的顶层目录。该文件夹包含预训练的控制器/重排器检查点和用于训练控制器/重排器的数据。
要获取我们主要实验的最终生成故事和 Surge AI 标注结果,请运行 wget https://doc-story-generation-data.s3.amazonaws.com/doc_outputs.zip
(注意:一些生成的故事可能包含敏感/不适内容,因为我们没有尝试过滤这些内容)。
计划 + 大纲生成
我们首先生成计划/大纲,然后再进行主要故事的创作。
计划 + 大纲生成命令
与我们主要论文实验中使用的设置相匹配的计划生成命令示例:
mkdir output
CUDA_VISIBLE_DEVICES=0 python -u scripts/main.py --controller none none none longformer_classifier --loader none none none order --controller-load-dir none none none doc_data/ckpt/outline_order_reranker --controller-model-string none none none roberta-large --no-editor --setup-only --outline-levels 3 --save-outline-file output/plan.pkl --log-file output/plan.log
代码假设大纲顺序重排器位于参数的第 4 个位置,所以不要更改命令的这些部分。 如果看到一些错误被打印出来不用担心,只要程序没有提前终止就行;一些部分可能需要多次尝试。
此命令使用我们下载中包含的现有重排器检查点。如果你想使用自己的检查点,请参阅下面的训练说明,并更改此命令中的路径以指向正确的检查点。
使用这些设置生成计划在 GPT3 上花费几美元。
其他参数
计划生成参数在 scripts/main.py
中编译;按照那里的链接可以看到完整列表。一些特别感兴趣的参数:
- 指定
--premise
参数来指定你自己的故事前提,而不是让 GPT3 自动生成一个。 - 更改
--outline-levels
以改变大纲的最大深度。 - 将
--outline-char-model-string
设置为不同的 InstructGPT3 模型(例如text-curie-001
)可以节省相当大一部分(如果不是大部分)GPT3 成本,但会稍微降低为大纲检测角色时的性能。 - 使用
--outline-restart-pkl
从之前生成的较低深度 pkl 文件继续生成。(我们在人机交互实验中使用此功能。) - 将
--log-level
设置为 21 到 25 之间的值以改变日志的详细程度(数字越大越简洁;默认为 25)。
主要故事生成
按照之前的说明生成计划后,我们可以生成故事。
OPT-175B 设置
我们的主要故事生成使用 Alpa (https://alpa.ai/) 提供服务的 OPT-175B,因为它允许对 token 级别的 logit 进行修改,以运行像论文中描述的 DOC 详细控制器这样的受控生成方法。 你有几个选择。
1. 免费公共 Alpa OPT-175B API (易于上手;高质量,可能较慢)
你可以在 https://opt.alpa.ai/ 向 Alpa 团队索要一个密钥来调用他们的免费公共 API(页面底部有 Slack 链接)。他们非常友好。
根据你的物理位置,这个选项可能会较慢(在运行时间上),因为他们的服务器在中东。
(我们需要访问 logprobs
端点,而不是默认的 completions
端点。)
一旦你有了密钥,你可以在下面的主要故事命令中指定 --alpa-url https://opt.alpa.ai --alpa-key YOUR_KEY
。
2. 自助服务 (更快 + 高质量;需要大量计算资源)
如果你有足够的计算资源,可以向 Meta 申请权重 (https://forms.gle/BDB2i44QwCr2mCJN6) 并使用 Alpa 自行提供服务。如果你能做到,这是最好的选择(高质量且速度合理)。
按照 https://alpa.ai/install.html 和 https://alpa.ai/tutorials/opt_serving.html 的说明进行安装和服务。 最新版本的 Alpa 应该可以工作,但我们也在 https://github.com/yangkevin2/doc-alpa 冻结了我们使用的版本,以防有用。
设置完成后,在下面的主要故事命令中指定 --alpa-url YOUR_SERVER_URL
(例如,格式为 http://0.0.0.0:8001
)。
或者你也可以使用较小的 OPT 模型,尽管这会导致明显较差的质量。
3. GPT3-175B (最易上手且最快;质量较差)
直接使用 GPT3-175B,这意味着关闭我们的详细控制器。你在平均情况下会得到明显较差的计划/大纲忠实度,但速度会快很多。
要做到这一点,在下面的主要故事命令中设置 --extension-method gpt3
。这将使用基本的 davinci
模型(即不是指令微调的 GPT3.5/GPT4 模型,它们使用不同的提示接口,目前不受支持;这些指令微调模型也经常以somewhat不同的风格写作)。
就 GPT3 API 而言,这并不太昂贵;在整个故事过程中,你可能花费不到一美元。
主要故事生成命令
设置好 OPT-175B (或其他) 服务器后,运行以下命令以使用与我们主要论文实验相同的设置来起草故事,确保附加上面描述的额外 Alpa 相关(或其他)参数。
CUDA_VISIBLE_DEVICES=0 python -u scripts/main.py {{{ALPA_ARGS}}} --controller longformer_classifier longformer_classifier fudge_controller --loader alignment coherence fine_coherence --controller-load-dir doc_data/ckpt/relevance_reranker doc_data/ckpt/coherence_reranker doc_data/ckpt/detailed_controller --controller-model-string allenai/longformer-base-4096 allenai/longformer-base-4096 facebook/opt-350m --load-outline-file output/plan.pkl --no-editor --include-future-context --control-strength 1 1 0 --control-strength-substep-increment 3 --save-complete-file output/story.pkl --log-file output/story.log
该命令假设所有 3 个重排器/控制器都按指定顺序存在,所以不要更改这些参数。
此命令使用我们下载中包含的现有重排器检查点。如果你想使用自己的检查点,请参阅下面的训练说明,并更改此命令中的路径以指向正确的检查点。
虽然此命令仍使用 GPT3 编写一些摘要用于提示,但成本仅为几美分。
其他参数
主要故事生成参数也在 scripts/main.py
中编译;按照那里的链接可以看到完整列表。一些特别感兴趣的参数:
- 更改
--max-continuation-substeps
(默认为8)和--max-tokens
(默认为64)可以改变为大纲中每个编号项目写入的最大故事文本量。使用默认设置,它将为每个项目写入最多八个64个token的段落。 - 更改
--early-stop-threshold
和--skip-threshold
可以调整提前停止草稿并转到下一个大纲项目的启发式方法。--early-stop-threshold
的值越小(更负),提前停止会更激进。--skip-threshold
的值越大(更正),当所有生成的段落候选都不太好时,直接跳到下一个大纲项目的频率会更高。 --control-strength
包含三个数字,分别对应相关性重排器、连贯性重排器和详细控制器。详细控制器的控制强度会根据--control-strength-substep-increment
逐步增加,最高到--max-control-strength
,在为给定大纲项目草拟时重置。我们认为当前设置在控制和让模型发挥创意之间取得了合理平衡,但您可以自由调整。要关闭详细控制器,只需使用--control-strength-substep-increment 0
。- 生成过程中的频率和提示重复惩罚设置为1(每个token有0.98的指数衰减)。您可以分别更改
--summarizer-frequency-penalty
、--summarizer-prompt-penalty
和--summarizer-frequency-penalty-decay
。与基础生成器相关的其他参数在story-generation/common/summarizer/summarizer_util.py
中。 - 如果您有高深度的大纲,但想使用较低深度生成(例如,将深度为3的大纲转换为深度为2的大纲),请指定
--generation-outline-levels
。 - 增加
--max-beam-size
(默认为1)以开启基于重排器的段落级可变大小束搜索程序。这在论文实验中是关闭的(会使系统速度变慢几倍)。 - 如果您的GPU内存不足,可以尝试减小
--fudge-batch-size
至32(或更小),或按照README底部的说明重新训练较小的重排器/控制器。 - 移除
--no-editor
以开启从Re3继承的编辑模块(未经大量测试;DOC在我们的主要实验中未使用)。 - 将
--log-level
设置为21到25之间的值以调整日志的详细程度(数值越高越简洁;默认为24)。
关于崩溃的说明
使用非常小的OPT模型可能导致崩溃,因为我们没有广泛测试所有生成的续写被我们的过滤器拒绝的边缘情况(您可以设置 --skip-threshold -10000
来避免这种情况发生)。当详细控制器关闭时,使用GPT3也可能偶尔发生这种情况。在我们使用OPT-175B的主要实验中从未发生过这种崩溃。
基线
基线假设您已经使用我们的代码按照之前描述的命令生成了计划。
Re3 (OPT-175B,与DOC匹配长度)
使用我们的代码生成的计划(在以下命令中为 output/plan.pkl
),并仅保存设置/角色和顶级大纲以供Re3使用:
python scripts/data/save_re3_plan.py -i output/plan.pkl -o output/re3_plan.pkl
然后按照 https://github.com/yangkevin2/emnlp22-re3-story-generation 中的说明操作。使用OPT-175B进行公平比较,使用 --extension-method opt
;Alpa参数与本仓库中相同。使用 --load-outline-file
指定已生成的计划文件。您还需要设置 --max-candidates 8 --summarizer-frequency-penalty 1 --summarizer-prompt-penalty 1
以及 --max-continuation-substeps 5
以大致匹配我们在主要实验中生成的故事长度。
滚动窗口 OPT-175B
python -u scripts/rolling_baselines.py {{{ALPA_ARGS}}} --load-outline-file output/plan.pkl --extension-method opt --save-complete-file output/rolling_opt_story.pkl > output/rolling_opt_story.log
滚动窗口 GPT3-175B
python -u scripts/rolling_baselines.py --load-outline-file output/plan.pkl --extension-method gpt3 --save-complete-file output/rolling_gpt3_story.pkl > output/rolling_gpt3_story.log
控制器 / 重排器训练
详细控制器训练
设置检查点保存目录并运行以下命令。训练数据(源自InstructGPT-13B对WritingPrompts(Fan等人2018)段落的摘要)在数据下载中提供。
CUDA_VISIBLE_DEVICES=0 python scripts/training/train_controller.py --controller-save-dir {{{SAVE_DIRECTORY}}} --controller fudge_controller --controller-model-string facebook/opt-350m --data-dir doc_data/training_data/detailed_controller_training_data.csv --dataset alignment --loader fine_coherence --batch-size 2 --lower-length-limit 1000 --controller-epochs 20 --num-workers 8 --controller-num-negatives 3 --controller-lr 1e-6 --coherence-negative-categories other shuffle repeat --limit 100000
大纲顺序重排器训练
设置检查点保存目录并运行以下命令。训练数据(一些由InstructGPT3-175B生成的非常简短的、类似大纲的故事)在数据下载中提供。
CUDA_VISIBLE_DEVICES=0 python scripts/training/train_controller.py --controller-save-dir {{{SAVE_DIRECTORY}}} --controller longformer_classifier --controller-model-string roberta-large --data-dir doc_data/training_data/order_training_data.csv --dataset csv --csv-column story --loader order --batch-size 64 --controller-epochs 20 --controller-lr 1e-5 --limit 100000 --num-workers 8
相关性和连贯性重排器训练
如果您想自己重新训练相关性和连贯性重排器,请按照 https://github.com/yangkevin2/emnlp22-re3-story-generation 中的说明操作,因为我们的重排器与他们的保持不变。