Yet Another Lightning Hydra Template
An efficient workflow and reproducibility are essential components of every machine learning project. They enable you to:
- Rapidly iterate over new models and compare different approaches faster.
- Promote confidence in the results and transparency.
- Save time and resources.
PyTorch Lightning and Hydra serve as the foundation of this template. This technology stack for deep learning prototyping offers a comprehensive and seamless solution, allowing you to effortlessly explore different tasks across a variety of hardware accelerators such as CPUs, multi-GPUs, and TPUs. Furthermore, it includes a curated collection of best practices and extensive documentation for greater clarity and comprehension.
This template can be used as is for basic tasks like Classification, Segmentation or Metric Learning, and can be easily extended for any other task thanks to its high-level modularity and scalable structure.
As a baseline, I used the gorgeous Lightning Hydra Template, reshaped and polished it, and implemented additional features to improve overall workflow efficiency and reproducibility.
Quick start
# clone template
git clone https://github.com/gorodnitskiy/yet-another-lightning-hydra-template
cd yet-another-lightning-hydra-template
# install requirements
pip install -r requirements.txt
Or run the project in docker. See more in Docker section.
Table of contents
- Main technologies
- Project structure
- Workflow - how it works
- Hydra configs
- Logs
- Data
- Notebooks
- Hyperparameters search
- Docker
- Tests
- Continuous integration
Main technologies
PyTorch Lightning - a lightweight deep learning framework / PyTorch wrapper for professional AI researchers and machine learning engineers who need maximal flexibility without sacrificing performance at scale.
Hydra - a framework that simplifies configuring complex applications. The key feature is the ability to dynamically create a hierarchical configuration by composition and override it through config files and the command line.
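As an illustration of how this composition works, a main config pulls in one file per config group via a `defaults` list (the group and file names below follow this template's layout, but the exact contents are a hypothetical sketch):

```yaml
# train.yaml - hypothetical main config; Hydra composes it from config groups
defaults:
  - datamodule: mnist   # would pick configs/datamodule/mnist.yaml
  - module: single
  - trainer: default
  - logger: csv

seed: 42
```

Any composed value can then be overridden from the command line, e.g. `python src/train.py trainer.max_epochs=20 logger=wandb`.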
Project structure
The structure of a machine learning project can vary depending on the specific requirements and goals of the project, as well as the tools and frameworks being used. However, here is a general outline of a common directory structure for a machine learning project:
src/
data/
logs/
tests/
- some additional directories, like: notebooks/, docs/, etc.
In this particular case, the directory structure looks like:
├── configs <- Hydra configuration files
│ ├── callbacks <- Callbacks configs
│ ├── datamodule <- Datamodule configs
│ ├── debug <- Debugging configs
│ ├── experiment <- Experiment configs
│ ├── extras <- Extra utilities configs
│ ├── hparams_search <- Hyperparameter search configs
│ ├── hydra <- Hydra settings configs
│ ├── local <- Local configs
│ ├── logger <- Logger configs
│ ├── module <- Module configs
│ ├── paths <- Project paths configs
│ ├── trainer <- Trainer configs
│ │
│ ├── eval.yaml <- Main config for evaluation
│ └── train.yaml <- Main config for training
│
├── data <- Project data
├── logs <- Logs generated by hydra, lightning loggers, etc.
├── notebooks <- Jupyter notebooks.
├── scripts <- Shell scripts
│
├── src <- Source code
│ ├── callbacks <- Additional callbacks
│ ├── datamodules <- Lightning datamodules
│ ├── modules <- Lightning modules
│ ├── utils <- Utility scripts
│ │
│ ├── eval.py <- Run evaluation
│ └── train.py <- Run training
│
├── tests <- Tests of any kind
│
├── .dockerignore <- List of files ignored by docker
├── .gitattributes <- List of additional attributes to pathnames
├── .gitignore <- List of files ignored by git
├── .pre-commit-config.yaml <- Configuration of pre-commit hooks for code formatting
├── Dockerfile <- Dockerfile
├── Makefile <- Makefile with commands like `make train` or `make test`
├── pyproject.toml <- Configuration options for testing and linting
├── requirements.txt <- File for installing python dependencies
├── setup.py <- File for installing project as a package
└── README.md
Workflow - how it works
Before starting a project, you need to think about the following things to ensure reproducibility of your results:
- Setting up a Docker image
- Freezing python package versions
- Code Version Control
- Data Version Control. Many of these tools currently provide not just data version control, but also very useful extra features like a Model Registry or Experiment Tracking.
- Experiment Tracking tools:
- Weights & Biases
- Neptune
- DVC
- Comet
- MLFlow
- TensorBoard
- Or just CSV files...
Basic workflow
This template can be used as is for basic tasks like Classification, Segmentation or Metric Learning, but if you need to do something more complex, here is the general workflow:
- Write your PyTorch Lightning DataModule (see examples in datamodules/datamodules.py)
- Write your PyTorch Lightning Module (see examples in modules/single_module.py)
- Fill up your configs; in particular, create experiment configs
- Run experiments:
- Run training with chosen experiment config:
python src/train.py experiment=experiment_name.yaml
- Use hyperparameter search, for example by Optuna Sweeper via Hydra:
# using Hydra multirun mode
python src/train.py -m hparams_search=mnist_optuna
- Execute the runs with some config parameter manually:
python src/train.py -m logger=csv module.optimizer.weight_decay=0.,0.00001,0.0001
- Run evaluation with different checkpoints or run prediction on a custom dataset for additional analysis
The template contains an MNIST classification example, which, by the way, is also used for tests.
If you run python src/train.py, you will get something like this:
Show terminal screen when running pipeline
LightningDataModule
At the start, you need to create a PyTorch Dataset for your task. It has to include __getitem__ and __len__ methods.
You may be able to use the Datasets already implemented in the template as is, or easily modify them for your needs.
See more details in the PyTorch documentation.
Also, it could be useful to see the section about how data can be stored for training and evaluation.
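For reference, a minimal map-style Dataset might look like this (a sketch with illustrative names, not one of the template's Datasets):

```python
from typing import Tuple

import torch
from torch.utils.data import Dataset


class YourDataset(Dataset):
    """Minimal map-style Dataset: only __getitem__ and __len__ are required."""

    def __init__(self, features: torch.Tensor, labels: torch.Tensor) -> None:
        self.features = features
        self.labels = labels

    def __len__(self) -> int:
        # Number of samples in the dataset
        return len(self.features)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        # Return a single (sample, label) pair by index
        return self.features[idx], self.labels[idx]
```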
Then, you need to create a DataModule using the PyTorch Lightning DataModule API. By default, the API has the following methods:
- prepare_data (optional): perform data operations on CPU via a single process, like loading and preprocessing data, etc.
- setup (optional): perform data operations on every GPU, like train/val/test splits, creating datasets, etc.
- train_dataloader: used to generate the training dataloader(s)
- val_dataloader: used to generate the validation dataloader(s)
- test_dataloader: used to generate the test dataloader(s)
- predict_dataloader (optional): used to generate the prediction dataloader(s)
Show LightningDataModule API
from typing import Any, Dict, List, Optional, Union
from torch.utils.data import DataLoader, Dataset
from pytorch_lightning import LightningDataModule
class YourDataModule(LightningDataModule):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__()
self.train_set: Optional[Dataset] = None
self.valid_set: Optional[Dataset] = None
self.test_set: Optional[Dataset] = None
self.predict_set: Optional[Dataset] = None
...
def prepare_data(self) -> None:
# (Optional) Perform data operations on CPU via a single process
# - load data
# - preprocess data
# - etc.
...
def setup(self, stage: str) -> None:
# (Optional) Perform data operations on every GPU:
# - count number of classes
# - build vocabulary
# - perform train/val/test splits
# - create datasets
# - apply transforms (which are defined explicitly in your datamodule)
# - etc.
if not self.train_set and not self.valid_set and not self.test_set:
self.train_set = ...
self.valid_set = ...
self.test_set = ...
if (stage == "predict") and not self.predict_set:
self.predict_set = ...
def train_dataloader(self) -> Union[DataLoader, List[DataLoader], Dict[str, DataLoader]]:
# Used to generate the training dataloader(s)
# This is the dataloader that the Trainer `fit()` method uses
return DataLoader(self.train_set, ...)
def val_dataloader(self) -> Union[DataLoader, List[DataLoader]]:
# Used to generate the validation dataloader(s)
# This is the dataloader that the Trainer `fit()` and `validate()` methods use
return DataLoader(self.valid_set, ...)
def test_dataloader(self) -> Union[DataLoader, List[DataLoader]]:
# Used to generate the test dataloader(s)
# This is the dataloader that the Trainer `test()` method uses
return DataLoader(self.test_set, ...)
def predict_dataloader(self) -> Union[DataLoader, List[DataLoader]]:
# Used to generate the prediction dataloader(s)
# This is the dataloader that the Trainer `predict()` method uses
return DataLoader(self.predict_set, ...)
def teardown(self, stage: str) -> None:
# Used to clean-up when the run is finished
...
See examples of datamodule configs in the configs/datamodule folder.
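A datamodule config typically just points Hydra at the DataModule class, with the remaining keys passed to its __init__ (a hypothetical sketch; the actual keys depend on your DataModule):

```yaml
_target_: src.datamodules.YourDataModule
batch_size: 64
num_workers: 4
pin_memory: true
```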
By default, the template contains the following DataModules:
- SingleDataModule, in which train_dataloader, val_dataloader and test_dataloader return a single DataLoader, and predict_dataloader returns a list of DataLoaders
- MultipleDataModule, in which train_dataloader returns a dict of DataLoaders, and val_dataloader, test_dataloader and predict_dataloader return lists of DataLoaders
In the template, DataModules have a _get_dataset_ method to simplify Dataset instantiation.
LightningModule
LightningModule API
Next, you need to create a LightningModule using the PyTorch Lightning LightningModule API. The minimal API has the following methods:
- forward: use for inference only (separate from training_step)
- training_step: the complete training loop
- validation_step: the complete validation loop
- test_step: the complete test loop
- predict_step: the complete prediction loop
- configure_optimizers: define optimizers and LR schedulers
Also, you can override optional methods for each step to perform additional logic:
- training_step_end: training step end operations
- training_epoch_end: training epoch end operations
- validation_step_end: validation step end operations
- validation_epoch_end: validation epoch end operations
- test_step_end: test step end operations
- test_epoch_end: test epoch end operations
Show LightningModule API methods and appropriate order
from typing import Any
from pytorch_lightning import LightningModule
class LitModel(LightningModule):
def __init__(self, *args: Any, **kwargs: Any):
super().__init__()
...
def forward(self, *args: Any, **kwargs: Any):
...
def training_step(self, *args: Any, **kwargs: Any):
...
def training_step_end(self, step_output: Any):
...
def training_epoch_end(self, outputs: Any):
...
def validation_step(self, *args: Any, **kwargs: Any):
...
def validation_step_end(self, step_output: Any):
...
def validation_epoch_end(self, outputs: Any):
...
def test_step(self, *args: Any, **kwargs: Any):
...
def test_step_end(self, step_output: Any):
...
def test_epoch_end(self, outputs: Any):
...
def configure_optimizers(self):
...
def any_extra_hook(self, *args: Any, **kwargs: Any):
...
In the template, the LightningModule has a model_step method to factor out repeated operations, like the forward pass or loss calculation, which are required in training_step, validation_step and test_step.
Metrics
The template offers the following Metrics API:
- main metric: the main metric, which is also used by all callbacks and trackers like model_checkpoint, early_stopping or scheduler.monitor.
- valid_best metric: used to track the best validation metric. Usually it can be MaxMetric or MinMetric.
- additional metrics: any additional metrics.
Each metric config should contain a _target_ key with the metric class name and any other parameters required by the metric. The template allows you to use any metrics, for example from torchmetrics or implemented by yourself (see examples in modules/metrics/components/ or the torchmetrics API).
See more details about the implemented Metrics API and the metrics config as part of the network configs in the configs/module/network folder.
Metric config example:
metrics:
  main:
    _target_: "torchmetrics.Accuracy"
    task: "binary"
  valid_best:
    _target_: "torchmetrics.MaxMetric"
  additional:
    AUROC:
      _target_: "torchmetrics.AUROC"
      task: "binary"
Also, the template includes a few manually implemented metrics.
Loss
The template offers the following Losses API:
- A loss config should contain a _target_ key with the loss class name and any other parameters required by the loss.
- Any parameter whose name contains the string weight will be wrapped in torch.tensor and cast to torch.float before being passed to the loss, since most losses require weights as float tensors.
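The weight-wrapping rule described above can be sketched as follows (a simplified illustration with a hypothetical helper name, not the template's exact code):

```python
from typing import Any, Dict

import torch


def wrap_weight_params(loss_params: Dict[str, Any]) -> Dict[str, Any]:
    # Cast any parameter whose name contains "weight" to a float tensor,
    # since losses like BCEWithLogitsLoss expect e.g. `pos_weight` as a tensor.
    return {
        name: torch.tensor(value).float() if "weight" in name else value
        for name, value in loss_params.items()
    }
```

For example, `wrap_weight_params({"pos_weight": [0.25], "reduction": "mean"})` leaves `reduction` untouched but turns `pos_weight` into a float tensor.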
The template allows you to use any losses, for example from PyTorch or implemented by yourself (see examples in modules/losses/components/).
See more details about the implemented Losses API and the loss config as part of the network configs in the configs/module/network folder.
Loss config examples:
loss:
  _target_: "torch.nn.CrossEntropyLoss"

loss:
  _target_: "torch.nn.BCEWithLogitsLoss"
  pos_weight: [0.25]
loss:
_target_: