Cadet-Tiny项目介绍
项目背景
Cadet-Tiny是受Allen AI的Cosmo-XL模型启发而开发的一款非常小型的对话模型。该模型利用SODA数据集进行训练,目的是为了在边缘设备上实现推理,比如只有2GB RAM的树莓派设备。
模型特性
Cadet-Tiny基于Google的t5-small预训练模型,其大小仅为Cosmo-3B模型的约2%。这样使得Cadet-Tiny在占用极少资源的情况下,能够执行基本的对话生成任务。这个项目是开发者首次构建的SEQ2SEQ自然语言处理模型,并分享到HuggingFace平台供大家使用和改进。
开发者的联系方式
如果对该项目有任何疑问或改进建议,可以通过以下邮箱联系开发者:tcgoldfarb@gmail.com。
Google Colab 练习链接
开发者提供了一个Google Colab文件链接,通过这个链接可以了解模型的训练过程及使用AI2提供的SODA数据集的方法。链接地址:查看Google Colab
如何开始使用Cadet-Tiny
以下是使用Cadet-Tiny的代码示例:
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import colorful as cf
cf.use_true_colors()
cf.use_style('monokai')
class CadetTinyAgent:
def __init__(self):
print(cf.bold | cf.purple("Waking up Cadet-Tiny..."))
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.tokenizer = AutoTokenizer.from_pretrained("t5-small", model_max_length=512)
self.model = AutoModelForSeq2SeqLM.from_pretrained("ToddGoldfarb/Cadet-Tiny", low_cpu_mem_usage=True).to(self.device)
self.conversation_history = ""
def observe(self, observation):
self.conversation_history = self.conversation_history + observation
if len(self.conversation_history) > 400:
self.conversation_history = self.conversation_history[112:]
def set_input(self, situation_narrative="", role_instruction=""):
input_text = "dialogue: "
if situation_narrative:
input_text += situation_narrative
if role_instruction:
input_text += " <SEP> " + role_instruction
input_text += " <TURN> " + self.conversation_history
return input_text
def generate(self, situation_narrative, role_instruction, user_response):
user_response += " <TURN> "
self.observe(user_response)
input_text = self.set_input(situation_narrative, role_instruction)
inputs = self.tokenizer([input_text], return_tensors="pt").to(self.device)
outputs = self.model.generate(inputs["input_ids"], max_new_tokens=512, temperature=0.75, top_p=.95, do_sample=True)
cadet_response = self.tokenizer.decode(outputs[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
self.observe(cadet_response + " <TURN> ")
return cadet_response
def reset_history(self):
self.conversation_history = []
def run(self):
def get_valid_input(prompt, default):
while True:
user_input = input(prompt)
if user_input in ["Y", "N", "y", "n"]:
return user_input
if user_input == "":
return default
while True:
continue_chat = ""
situation_narrative = "Imagine you are Cadet-Tiny talking to ???."
role_instruction = "You are Cadet-Tiny, and you are talking to ???."
self.chat(situation_narrative, role_instruction)
continue_chat = get_valid_input(cf.purple("Start a new conversation with new setup? [Y/N]:"), "Y")
if continue_chat in ["N", "n"]:
break
print(cf.blue("CT: See you!"))
def chat(self, situation_narrative, role_instruction):
print(cf.green("Cadet-Tiny is running! Input [RESET] to reset the conversation history and [END] to end the conversation."))
while True:
user_input = input("You: ")
if user_input == "[RESET]":
self.reset_history()
print(cf.green("[Conversation history cleared. Chat with Cadet-Tiny!]"))
continue
if user_input == "[END]":
break
response = self.generate(situation_narrative, role_instruction, user_input)
print(cf.blue("CT: " + response))
def main():
print(cf.bold | cf.blue("LOADING MODEL"))
CadetTiny = CadetTinyAgent()
CadetTiny.run()
if __name__ == '__main__':
main()
致谢与引用
特别感谢Hyunwoo Kim在使用SODA数据集的讨论中给予的帮助。建议大家阅读有关SODA、Prosocial-Dialog或COSMO的研究,同时查看SODA的相关论文。
@article{kim2022soda, title={SODA: Million-scale Dialogue Distillation with Social Commonsense Contextualization}, author={Hyunwoo Kim and Jack Hessel and Liwei Jiang and Peter West and Ximing Lu and Youngjae Yu and Pei Zhou and Ronan Le Bras and Malihe Alikhani and Gunhee Kim and Maarten Sap and Yejin Choi}, journal={ArXiv}, year={2022}, volume={abs/2212.10465} }