textgenrnn
只需几行代码,便可轻松训练任何大小和复杂度的文本生成神经网络,或快速使用预训练模型进行文本训练。
textgenrnn 是一个基于 Keras/TensorFlow 的 Python 3 模块,用于创建 char-rnn ,具备许多酷炫功能:
- 现代神经网络架构,采用注意力加权和跳跃嵌入等新技术,加速训练并提高模型质量。
- 可以在字符级别或词级别进行文本训练和生成。
- 配置 RNN 大小、RNN 层数,以及是否使用双向 RNN。
- 在任意通用输入文本文件上进行训练,包括大文件。
- 在 GPU 上训练模型,然后使用 CPU 生成文本。
- 在 GPU 上训练时,利用强大的 CuDNN 实现 RNN,与典型的 LSTM 实现相比,大大加快训练时间。
- 使用上下文标签训练模型,在某些情况下可更快地学习并生成更好的结果。
你可以在这个 Colaboratory Notebook 中免费使用 GPU 训练任何文本文件!阅读这篇博客文章或观看此视频了解更多信息!
示例
from textgenrnn import textgenrnn
textgen = textgenrnn()
textgen.generate()
[Spoiler] 还有其他人发现这个帖子和他们的人比我真的更喜欢星球大战在火灾或健康和发布一个私人房子2016年游戏的报告在我的后院。
包含的模型可以轻松地在新文本上进行训练,并且即使在输入数据的单次通过后,也能生成适当的文本。
textgen.train_from_file('hacker_news_2000.txt', num_epochs=1)
textgen.generate()
项目状态 项目 Firefox
模型权重相对较小(磁盘上为 2 MB),可以轻松保存并加载到新的 textgenrnn 实例中。因此,你可以玩转已经在数据上训练过数百次的模型。(实际上,textgenrnn 学习得如此之好,以至于你需要显著增加温度以获得富有创造性的输出!)
textgen_2 = textgenrnn('/weights/hacker_news.hdf5')
textgen_2.generate(3, temperature=1.0)
我们为什么拿到钱“常规变更”
Urburg 到 Firefox 收购了 Nelf Multi Shamn
Kubernetes 由 Google 的 Bern
你还可以通过在任何训练函数中添加 new_model=True
,训练一个新模型,支持词级嵌入和双向 RNN 层。
交互模式
你还可以逐步参与输出展开的过程。交互模式会为你建议下一个字符/单词的前 N 个选项,并允许你选择一个。
在终端中运行 textgenrnn 时,传递 interactive=True
和 top=N
到 generate
。N 默认为 3。
from textgenrnn import textgenrnn
textgen = textgenrnn()
textgen.generate(interactive=True, top_n=5)
这可以为输出添加人性化的触感;感觉你就是作者!(参考)
用法
可以通过 pip
从 pypi 安装 textgenrnn:
pip3 install textgenrnn
对于最新的 textgenrnn,你必须拥有至少 2.1.0 版本的 TensorFlow。
你可以在 这个 Jupyter Notebook 中查看常见功能和模型配置选项的演示。
/datasets
包含使用 Hacker News/Reddit 数据训练 textgenrnn 的示例数据集。
/weights
包含上述数据集上进一步预训练的模型,这些模型可以加载到 textgenrnn 中。
/outputs
包含从上述预训练模型生成的文本示例。
神经网络架构和实现
textgenrnn 基于 Andrej Karpathy 的 char-rnn 项目,进行了几项现代优化,例如能够处理非常小的文本序列。
包含的预训练模型遵循一种受到 DeepMoji 启发的 神经网络架构。对于默认模型,textgenrnn 接受最多 40 个字符的输入,将每个字符转换为 100 维字符嵌入向量,并将这些向量输入到 128 单元长短期记忆 (LSTM) 循环层。然后将这些输出输入到另一个 128 单元 LSTM 中。所有三层都被输入到一个注意力层,以对最重要的时间特征进行加权并将其平均在一起(由于嵌入层和第一个 LSTM 层通过跳跃连接进入注意力层,模型更新可以更轻松地反向传播到它们并防止梯度消失)。该输出被映射为最多 394 个不同字符 的概率,包括大写字符、小写字符、标点符号和表情符号。(如果在新数据集上训练新模型,则上述所有数值参数都可以配置)
可选地,如果每个文本文件都提供上下文标签,则可以在上下文模式下训练模型,在这种模式下,模型在给定上下文的情况下学习文本,因此循环层学习去上下文化的语言。仅文本路径可以利用去上下文化的层;总的来说,这比单独训练模型只给定文本要快得多,并且在定量和定性模型性能上都表现得更好。
包中包含的模型权重是在数十万条来自Reddit提交的文本文档上训练的(通过BigQuery),这些文档来自各种多样化的子版块。该网络还使用了上述去上下文化方法进行训练,以提高训练性能并减轻作者偏见。
在使用textgenrnn对新文本数据集进行微调时,所有层都会重新训练。然而,由于原始预训练网络最初具有更强大的“知识”,新的textgenrnn最终可以更快、更准确地训练,并且有可能学习到原始数据集中不存在的新关系(例如,预训练字符嵌入包括了所有可能类型的现代互联网语法的字符上下文)。
此外,重新训练使用了基于动量的优化器和线性衰减的学习率,这两者都可以防止梯度爆炸,并使模型在长时间训练后发生发散的可能性大大降低。
备注
-
你不会100%获得高质量的生成文本,即使使用经过大量训练的神经网络。这也是病毒式博客文章/Twitter推文利用神经网络生成文本后通常会生成大量文本并在之后进行筛选/编辑的主要原因。
-
不同数据集之间的结果会有很大差异。由于预训练神经网络相对较小,它无法存储如博客文章中经常炫耀的RNNs那样多的数据。为了获得最佳效果,使用至少2000-5000个文档的数据集。如果数据集较小,则需要通过调用训练方法时将
num_epochs
设置得更高和/或从头开始训练新模型来进行更长时间的训练。即便如此,目前仍然没有一个好的启发式方法来确定一个“好”的模型。 -
重新训练textgenrnn不需要GPU,但在CPU上训练会花费更长的时间。如果使用GPU,建议增加
batch_size
参数以更好地利用硬件。
textgenrnn的未来计划
-
更正式的文档
-
使用tensorflow.js的基于Web的实现(由于网络规模较小,特别适合这种方式)
-
一种可视化注意力层输出的方法,以了解网络如何“学习”
-
一种模式,允许将模型架构用于聊天机器人对话(可能作为一个单独的项目发布)
-
更深入的上下文处理(位置上下文+允许多个上下文标签)
-
更大的预训练网络,可以容纳更长的字符序列并对语言有更深入的理解,从而生成更好的句子。
-
词级模型的分层softmax激活(当Keras有良好支持时)
-
FP16,用于在Volta/TPUs上进行超快训练(当Keras有良好支持时)
使用textgenrnn的文章/项目
文章
- Lifehacker: 如何训练自己的神经网络 作者:Beth Skwarecki
- New York Times: 让我们的算法为你挑选万圣节服装 作者:Janelle Shane
- CNN Business: 这个古怪的实验突显了AI的最大挑战 作者:Rachel Metz
项目
- 推文生成器 — 基于任意数量的Twitter用户,训练一个优化生成推文的神经网络
- Hacker News模拟器 — 基于textgenrnn训练了30多万条Hacker News提交内容的Twitter机器人。
- SubredditRNN — Reddit子版块,所有提交内容均来自textgenrnn机器人。
- 人类与AI合作的披萨 — 使用textgenrnn生成的披萨食谱,并在现实生活中制作。
- 桌游标题
- 视频游戏讨论论坛标题
- AI制作的蛋糕
- AI制作的饼干
- AI生成的歌曲
推文
维护者/创建者
Max Woolf (@minimaxir)
Max的开源项目由他的Patreon支持。如果你觉得这个项目对你有帮助,任何对Patreon的资金支持都会被用于创意用途。
致谢
感谢Andrej Karpathy通过博客文章循环神经网络的非凡效果提出char-rnn的最初建议。
感谢Daniel Grijalva为贡献交互模式。
许可协议
MIT
来自 DeepMoji 的注意力层代码(MIT 许可证)