gpt2-client(已存档)
易于使用的GPT-2 124M、345M、774M和1.5B Transformer模型封装
由Rishabh Anand创建 • https://rish-16.github.io
概述
GPT-2是一个由OpenAI开发的用于文本生成的自然语言处理模型。它是GPT(生成性预训练Transformer)模型的继承者,训练数据包括来自互联网的40GB文本。它采用了2017年Attention Is All You Need 论文中提出的Transformer模型。该模型有四个版本 - 124M
、345M
、774M
和1558M
,它们在训练数据量和参数数量上有所不同。
目前,1.5B模型是OpenAI发布的最大的可用模型。
最后,gpt2-client
是原始gpt-2
仓库的封装,它具备相同的功能但有更高的可访问性、可理解性和实用性。您可以在不到五行代码中使用全部四个GPT-2模型。
*注意:此客户封装不对直接或间接造成的任何损害负责。模型引用的任何姓名、地点和对象均为虚构,与实际实体或组织无关。样本未经过滤,可能包含冒犯性内容。请用户酌情使用。*
安装
通过pip
安装客户端。理想情况下,gpt2-client
很好地支持Python >= 3.5和TensorFlow >= 1.X。如果使用_Python 2.X_,可能需要使用--upgrade
标志通过pip
重新安装或升级一些库。
pip install gpt2-client
注意:
gpt2-client
不兼容TensorFlow 2.0,请尝试TensorFlow 1.14.0
快速上手
1. 下载模型权重和检查点
from gpt2_client import GPT2Client
gpt2 = GPT2Client('124M') # 这也可以是`355M`、`774M`或`1558M`。重命名`save_dir`为任何名称。
gpt2.load_model(force_download=False) # 如果可用,使用缓存版本。
这将在当前工作目录中创建一个名为models
的目录,并下载模型所需的权重、检查点、模型JSON和超参数。调用load_model()
函数后,如果models
目录中的文件已下载完毕,则无需再次调用。
注意: 设置
force_download=True
以覆盖现有的缓存模型权重和检查点
2. 开始生成文本!
from gpt2_client import GPT2Client
gpt2 = GPT2Client('124M') # 这也可以是`355M`、`774M`或`1558M`
gpt2.load_model()
gpt2.generate(interactive=True) # 提示用户输入
gpt2.generate(n_samples=4) # 生成4段文本
text = gpt2.generate(return_text=True) # 生成文本并以数组形式返回
gpt2.generate(interactive=True, n_samples=3) # 每次不同的提示
从上述示例可以看出,生成选项非常灵活。您可以根据需要生成多段文本或根据提示一次生成一段。
3. 从提示批量生成文本
from gpt2_client import GPT2Client
gpt2 = GPT2Client('124M') # 这也可以是`355M`、`774M`或`1558M`
gpt2.load_model()
prompts = [
"这是一个提示1",
"这是一个提示2",
"这是一个提示3",
"这是一个提示4"
]
text = gpt2.generate_batch_from_prompts(prompts) # 返回一个生成文本数组
4. 微调GPT-2到自定义数据集
from gpt2_client import GPT2Client
gpt2 = GPT2Client('124M') # 这也可以是`355M`、`774M`或`1558M`
gpt2.load_model()
my_corpus = './data/shakespeare.txt' # 语料库路径
custom_text = gpt2.finetune(my_corpus, return_text=True) # 加载您的自定义数据集
为了将GPT-2微调到您的自定义语料库或数据集,最好手头有GPU或TPU。Google Colab是一个可以用来重新训练/微调模型的工具。
5. 编码和解码文本序列
from gpt2_client import GPT2Client
gpt2 = GPT2Client('124M') # 这也可以是`355M`、`774M`或`1558M`
gpt2.load_model()
# 编码一个句子
encs = gpt2.encode_seq("Hello world, this is a sentence")
# [15496, 995, 11, 428, 318, 257, 6827]
# 解码一个编码序列
decs = gpt2.decode_seq(encs)
# Hello world, this is a sentence
贡献
欢迎提出建议、改进和增强!如果您有任何问题,请在问题部分提出。如果您有改进建议,请在创建PR之前先提一个议题讨论。
所有的想法——无论多么离谱——都欢迎!
捐赠
开源真的很有趣。您的捐赠激励我带来新的想法。如果您愿意支持我的开源工作,请捐赠——这对我意义重大!