Tensorflow 项目模板
一个简单且设计良好的结构对于任何深度学习项目都是至关重要的,所以在经过大量练习和对tensorflow项目的贡献后,这里提供了一个tensorflow项目模板,它结合了简单性、最佳文件夹结构实践和良好的面向对象设计。 主要的想法是,每次启动tensorflow项目时,你都会做很多相同的事情,所以将所有这些共享的内容包装起来,将有助于你在每次启动新tensorflow项目时,只需要更改核心部分。
因此,这是一个简单的tensorflow模板,帮助你更快地进入主要项目,并专注于核心部分(模型,训练,等等)
目录
简述
简而言之,这里是如何使用此模板的示例,假设你想实现VGG模型,你应该执行以下操作:
- 在models文件夹中创建一个名为VGG的类,继承"base_model"类
class VGGModel(BaseModel):
def __init__(self, config):
super(VGGModel, self).__init__(config)
#调用build_model和init_saver函数。
self.build_model()
self.init_saver()
- 覆盖这两个函数"build_model",在其中实现vgg模型,以及"init_saver",在其中定义tensorflow saver,然后在初始化程序中调用它们。
def build_model(self):
# 这里你可以构建任何模型的tensorflow图,并定义损失。
pass
def init_saver(self):
# 这里你可以初始化用于保存检查点的tensorflow saver。
self.saver = tf.train.Saver(max_to_keep=self.config.max_to_keep)
- 在trainers文件夹中创建一个VGG训练器,继承"base_train"类
class VGGTrainer(BaseTrain):
def __init__(self, sess, model, data, config, logger):
super(VGGTrainer, self).__init__(sess, model, data, config, logger)
- 覆盖这两个函数"train_step","train_epoch",在其中编写训练过程的逻辑
def train_epoch(self):
"""
实现epoch的逻辑:
-循环配置中的迭代次数并调用训练步骤
-使用summary添加任何你想要的摘要
"""
pass
def train_step(self):
"""
实现训练步骤的逻辑
- 运行tensorflow会话
- 返回任何你需要总结的指标
"""
pass
- 在主文件中,创建会话和以下对象的实例:"Model"、"Logger"、"Data_Generator"、"Trainer"和config
sess = tf.Session()
# 创建你想要的模型实例
model = VGGModel(config)
# 创建数据生成器
data = DataGenerator(config)
# 创建tensorboard日志记录器
logger = Logger(sess, config)
- 将所有这些对象传递给训练器对象,并通过调用"trainer.train()"开始训练
trainer = VGGTrainer(sess, model, data, config, logger)
# 这里训练你的模型
trainer.train()
你会在模型和训练器文件夹中找到一个模板文件和一个简单示例,展示如何简单地尝试你的第一个模型。
详细说明
项目架构
<img align="center" 高度="600" 宽度="600" src="https://github.com/Mrgemy95/Tensorflow-Project-Templete/blob/master/figures/diagram.png?raw=true">
文件夹结构
├── base
│ ├── base_model.py - 此文件包含模型的抽象类。
│ └── base_train.py - 此文件包含训练器的抽象类。
│
│
├── model - 此文件夹包含项目的任何模型。
│ └── example_model.py
│
│
├── trainer - 此文件夹包含项目的训练器。
│ └── example_trainer.py
│
├── mains - 这里是项目的主程序(你可能需要多个主程序)。
│ └── example_main.py - 这是负责整个流程的主程序示例。
│
├── data _loader
│ └── data_generator.py - 这里是负责所有数据处理的数据生成器。
│
└── utils
├── logger.py
└── any_other_utils_you_need
主要组件
模型
-
基础模型
基础模型是一个抽象类,必须由你创建的任何模型继承,其背后的想法是,所有模型之间有很多共享的内容。 基础模型包含:
- 保存 - 这个函数用于将检查点保存到磁盘。
- 加载 - 这个函数用于从磁盘加载检查点。
- 当前epoch、全局步骤计数器 - 这些变量用于跟踪当前的epoch和全局步骤。
- 初始化Saver 一个用于初始化保存和加载检查点的saver的抽象函数,注意:在你要实现的模型中覆盖此函数。
- 构建模型 这是一个用于定义模型的抽象函数,注意:在你要实现的模型中覆盖此函数。
-
你的模型
这里是你实现模型的地方。 因此你应该:
- 创建你的模型类并继承基础模型类
- 覆盖"build_model",在其中编写你想要的tensorflow模型
- 覆盖"init_saver",在其中创建一个tensorflow saver,用于保存和加载检查点
- 在初始化程序中调用"build_model"和"init_saver"。
训练器
-
基础训练器
基础训练器是一个包装训练过程的抽象类。
-
你的训练器
这里是你在训练器中应该实现的内容。
- 创建你的训练器类并继承基础训练器类。
- 覆盖这两个函数"train_step","train_epoch",在其中实现每个步骤和每个epoch的训练过程。
数据加载器
此类负责所有数据处理和处理,并提供一个可以被训练器使用的简单接口。
日志记录器
此类负责tensorboard摘要,在你的训练器中创建一个所有你想要总结的tensorflow变量的字典,然后将此字典传递给logger.summarize()。
此类还支持报告给Comet.ml,它允许你查看所有超参数、指标、图表、依赖项等,包括实时指标。 在配置文件中添加你的API密钥:
例如:"comet_api_key": "your key here"
Comet.ml集成
此模板还支持报告给Comet.ml,它允许你查看所有超参数、指标、图表、依赖项等,包括实时指标。
在配置文件中添加你的API密钥:
例如: "comet_api_key": "your key here"
这是你开始训练后的效果:
<img align="center" 宽度="800" src="https://comet-ml.nyc3.digitaloceanspaces.com/CometDemo.gif">
你还可以将你的Github仓库链接到你的comet.ml项目以实现完整的版本控制。 这是一个显示此仓库示例的实时页面
配置
我使用Json作为配置方法,然后解析它,因此写下你想要的所有配置,然后使用"utils/config/process_config"解析它,并将此配置对象传递给所有其他对象。
主程序
这里是你将所有前面的部分组合在一起的地方。
- 解析配置文件。
- 创建一个tensorflow会话。
- 创建"Model"、"Data_Generator"和"Logger"的实例,并将配置传递给它们。
- 创建"Trainer"的实例,并将所有前面的对象传递给它。
- 现在你可以通过调用"Trainer.train()"训练你的模型。
未来工作
- 使用新的tensorflow数据集API替换数据加载器部分。
贡献
欢迎任何形式的增强或