OML is a PyTorch-based framework to train and validate models that produce high-quality embeddings.
Trusted by
A number of people from Oxford and HSE universities have used OML in their theses. [1] [2] [3]
OML 3.0 has been released!
The update focuses on several components:
- We added "official" support for texts and the corresponding Python examples. (Note: texts are not supported in Pipelines yet.)
- We introduced the RetrievalResults (RR) class, a container to store gallery items retrieved for given queries. RR provides a unified way to visualize predictions and compute metrics (if ground truths are known). It also simplifies post-processing, where an RR object is taken as input and another object, RR_upd, is produced as output. Having these two objects allows comparing retrieval results visually or by metrics. Moreover, you can easily create a chain of such post-processors. RR is memory optimized because it uses batching: in other words, it doesn't store the full matrix of query-gallery distances. (It doesn't make the search approximate, though.)
- We made Model and Dataset the only classes responsible for processing modality-specific logic. Model is responsible for interpreting its input dimensions: for example, BxCxHxW for images or BxLxD for sequences like texts. Dataset is responsible for preparing an item: it may use Transforms for images or Tokenizer for texts. Functions computing metrics (like calc_retrieval_metrics_rr), RetrievalResults, PairwiseReranker, and other classes and functions are unified to work with any modality.
- We added IVisualizableDataset, which has the method .visualize() that shows a single item. If it is implemented, RetrievalResults is able to show the layout of retrieved results.
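The memory claim about RR can be illustrated with a generic, library-free sketch. It is not OML's implementation, and the helper names (l2, batched_topk) are hypothetical. Queries are processed in batches, and only the k nearest gallery indices per query are kept. As a result, the full query-gallery distance matrix never exists in memory at once, while the search itself remains exact.

```python
import heapq
import math
from typing import List, Sequence


def l2(a: Sequence[float], b: Sequence[float]) -> float:
    """Euclidean distance between two embeddings."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def batched_topk(
    queries: List[Sequence[float]],
    gallery: List[Sequence[float]],
    k: int,
    batch_size: int = 2,
) -> List[List[int]]:
    """Exact top-k gallery indices for every query (hypothetical helper).

    Distances are computed for one batch of queries at a time, so peak memory
    grows with batch_size * len(gallery) instead of len(queries) * len(gallery).
    """
    result: List[List[int]] = []
    for start in range(0, len(queries), batch_size):
        batch = queries[start:start + batch_size]
        # Distance block for this batch only: len(batch) x len(gallery).
        block = [[l2(q, g) for g in gallery] for q in batch]
        for row in block:
            # nsmallest ranks the whole row, so the search is not approximate.
            result.append(heapq.nsmallest(k, range(len(row)), key=row.__getitem__))
        # Only the k indices per query survive; the block is discarded.
    return result


queries = [[0.0, 0.0], [1.0, 1.0], [5.0, 5.0]]
gallery = [[0.1, 0.0], [4.9, 5.2], [1.0, 0.9], [3.0, 3.0]]
print(batched_topk(queries, gallery, k=2))  # [[0, 2], [2, 0], [1, 3]]
```

Each batch is ranked exhaustively, so the result does not depend on batch_size. Only the peak memory does.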
Migration from OML 2.* [Python API]:
The easiest way to catch up with changes is to re-read the examples!
- The recommended way of validation is to use RetrievalResults and functions like calc_retrieval_metrics_rr, calc_fnmr_at_fmr_rr, and others. The EmbeddingMetrics class is kept for use with PyTorch Lightning and inside Pipelines. Note that the signatures of EmbeddingMetrics methods have been slightly changed; see the Lightning examples for details.
- Since modality-specific logic is confined to Dataset, it doesn't output PATHS_KEY, X1_KEY, X2_KEY, Y1_KEY, and Y2_KEY anymore. Keys which are not modality-specific, like LABELS_KEY, IS_GALLERY, IS_QUERY_KEY, and CATEGORIES_KEY, are still in use.
- inference_on_images is now inference and works with any modality.
- Slightly changed interfaces of Datasets. For example, we have the IQueryGalleryDataset and IQueryGalleryLabeledDataset interfaces. The first has to be used for inference, the second one for validation. Also added the IVisualizableDataset interface.
- Removed some internals like IMetricDDP, EmbeddingMetricsDDP, calc_distance_matrix, calc_gt_mask, calc_mask_to_ignore, and apply_mask_to_ignore. These changes shouldn't affect you. Also removed code related to a pipeline with precomputed triplets.
Migration from OML 2.* [Pipelines]:
- Feature extraction: No changes, except for a new optional argument: mode_for_checkpointing = (min | max). It may be useful for switching between "the lower, the better" and "the greater, the better" types of metrics.
- Pairwise-postprocessing pipeline: Slightly changed the name and arguments of the postprocessor sub config: pairwise_images is now pairwise_reranker and doesn't need transforms.
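The choice that mode_for_checkpointing makes can be shown with a generic sketch. This is not the Pipelines code, and best_epoch and the metric values are made up for illustration. With max, the checkpoint with the greatest metric value is kept; with min, the one with the lowest value is kept.

```python
from typing import Dict, Literal


def best_epoch(metric_per_epoch: Dict[int, float], mode: Literal["min", "max"]) -> int:
    """Return the epoch whose checkpoint would be kept (hypothetical helper)."""
    if mode not in ("min", "max"):
        raise ValueError(f"mode must be 'min' or 'max', got {mode!r}")
    pick = min if mode == "min" else max
    # Iterating a dict yields its keys (epochs); the key function looks up the metric.
    return pick(metric_per_epoch, key=metric_per_epoch.__getitem__)


# "The greater, the better" metric, e.g. CMC@1 (made-up values).
cmc1 = {0: 0.61, 1: 0.74, 2: 0.72}
# "The lower, the better" metric, e.g. FNMR@FMR (made-up values).
fnmr = {0: 0.20, 1: 0.15, 2: 0.12}

print(best_epoch(cmc1, mode="max"))  # 1
print(best_epoch(fnmr, mode="min"))  # 2
```

Using the wrong mode for a lower-is-better metric would keep the worst checkpoint, which is why the switch matters.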
Documentation
FAQ
Why do I need OML?
You may think, "If I need image embeddings, I can simply train a vanilla classifier and take its penultimate layer." Well, it makes sense as a starting point. But there are several possible drawbacks:
- If you want to use embeddings to perform searching, you need to calculate some distance among them (for example, cosine or L2). Usually, you don't directly optimize these distances during training in the classification setup. So, you can only hope that the final embeddings will have the desired properties.
- The second problem is the validation process. In the searching setup, you usually care how related your top-N outputs are to the query. The natural way to evaluate the model is to simulate search requests to the reference set and apply one of the retrieval metrics. Classification accuracy measures something else, so there is no guarantee that it will correlate with these metrics.
- Finally, you may want to implement a metric learning pipeline by yourself. There is a lot of work: to use triplet loss, you need to form batches in a specific way, implement different kinds of triplet mining, track distances, etc. For validation, you also need to implement retrieval metrics, which includes efficient accumulation of embeddings during the epoch, covering corner cases, etc. It's even harder if you have several GPUs and use DDP. You may also want to visualize your search requests by highlighting good and bad search results. Instead of doing it yourself, you can simply use OML for your purposes.
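The first two points can be made concrete with a small, library-free sketch. It computes cosine distances between query and gallery embeddings, simulates a search by ranking the gallery, and evaluates the ranking with CMC@k, the share of queries with at least one correct gallery item among the top-k. The toy embeddings, labels, and function names are assumptions for illustration; OML provides its own metric functions.

```python
import math
from typing import List, Sequence


def cosine_distance(a: Sequence[float], b: Sequence[float]) -> float:
    """1 - cosine similarity; assumes non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return 1.0 - dot / norm


def cmc_at_k(
    query_emb: List[Sequence[float]],
    query_labels: List[int],
    gallery_emb: List[Sequence[float]],
    gallery_labels: List[int],
    k: int,
) -> float:
    """Fraction of queries with at least one same-label gallery item in the top-k."""
    hits = 0
    for q, q_label in zip(query_emb, query_labels):
        # Simulate a search request: rank the whole gallery by distance to the query.
        order = sorted(range(len(gallery_emb)), key=lambda i: cosine_distance(q, gallery_emb[i]))
        top_labels = [gallery_labels[i] for i in order[:k]]
        hits += int(q_label in top_labels)
    return hits / len(query_emb)


gallery_emb = [[1.0, 0.0], [0.0, 1.0], [0.7, 0.7]]
gallery_labels = [0, 1, 2]
query_emb = [[0.9, 0.1], [0.6, 0.8], [0.1, 0.9]]
query_labels = [0, 2, 2]

print(cmc_at_k(query_emb, query_labels, gallery_emb, gallery_labels, k=1))
print(cmc_at_k(query_emb, query_labels, gallery_emb, gallery_labels, k=2))
```

With these toy vectors, CMC@1 is 2/3 because the third query's nearest gallery item has a different label, and CMC@2 is 1.0. A classifier's accuracy would not reveal this kind of ranking error.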
What is the difference between Open Metric Learning and PyTorch Metric Learning?
PML is a popular library for Metric Learning, and it includes a rich collection of losses, miners, distances, and reducers; that is why we provide straightforward examples of using them with OML. Initially, we tried to use PML, but in the end, we came up with our own library, which is more pipeline/recipe oriented. This is how OML differs from PML:
- OML has Pipelines, which allow training models by preparing a config and your data in the required format (it's like converting data into COCO format to train a detector from mmdetection).
- OML focuses on end-to-end pipelines and practical use cases. It has config-based examples on popular benchmarks close to real life (like photos of products with thousands of ids). We found some good combinations of hyperparameters on these datasets, then trained and published models and their configs. Thus, OML is more recipe oriented than PML, and the author of PML confirms this, saying that his library is a set of tools rather than recipes; moreover, the examples in PML are mostly for the CIFAR and MNIST datasets.
- OML has a Zoo of pretrained models that can be easily accessed from the code in the same way as in torchvision (when you type resnet50(pretrained=True)).
- OML is integrated with PyTorch Lightning, so we can use the power of its Trainer. This is especially helpful when we work with DDP: compare our DDP example with the PML one. By the way, PML also has Trainers, but they are not widely used in the examples, and custom train/test functions are used instead.
We believe that having Pipelines, laconic examples, and a Zoo of pretrained models makes the entry threshold really low.
What is Metric Learning?
The Metric Learning problem (also known as the extreme classification problem) refers to a situation in which we have thousands of ids of some entities, but only a few samples for every entity. Often we assume that during the test stage (or in production) we will deal with unseen entities, which makes it impossible to apply the vanilla classification pipeline directly. In many cases, the obtained embeddings are used to perform search or matching procedures.
Here are a few examples of such tasks from the computer vision sphere:
- Person/Animal Re-Identification
- Face Recognition
- Landmark Recognition
- Search engines for online shops, and many others.
Glossary (Naming convention)
- embedding: model's output (also known as feature vector or descriptor).
- query: a sample which is used as a request in the retrieval procedure.
- gallery set: the set of entities to search for items similar to the query (also known as reference or index).
- Sampler: an argument for DataLoader which is used to form batches.
- Miner: the object to form pairs or triplets after the batch was formed by Sampler. It's not necessary to form the combinations of samples only inside the current batch; thus, a memory bank may be a part of Miner.
- Samples/Labels/Instances: as an example, let's consider the DeepFashion dataset. It includes thousands of fashion item ids (we name them labels) and several photos for each item id (we name an individual photo an instance or a sample). All of the fashion item ids have their groups like "skirts", "jackets", "shorts" and so on (we name them categories). Note that we avoid the term class to prevent misunderstanding.
- training epoch: batch samplers which we use for combination-based losses usually have a length equal to [number of labels in training dataset] / [number of labels in one batch]. It means that we don't observe all of the available training samples in one epoch (as opposed to vanilla classification); instead, we observe all of the available labels.
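The Sampler and training epoch entries can be illustrated with a simplified balanced batch sampler. It is a hypothetical sketch, not OML's sampler class. Each batch holds n_labels labels with n_instances samples per label, and the number of batches per epoch is the number of labels divided by n_labels. So every label is visited once per epoch, while some samples are never drawn.

```python
import random
from collections import defaultdict
from typing import Dict, Iterator, List


class SimpleBalanceSampler:
    """Hypothetical batch sampler: n_labels labels x n_instances samples per batch.

    Assumes every label has at least n_instances samples.
    """

    def __init__(self, labels: List[int], n_labels: int, n_instances: int, seed: int = 0):
        self.n_labels = n_labels
        self.n_instances = n_instances
        self.rng = random.Random(seed)
        self.label_to_ids: Dict[int, List[int]] = defaultdict(list)
        for idx, label in enumerate(labels):
            self.label_to_ids[label].append(idx)

    def __len__(self) -> int:
        # [number of labels in training dataset] / [number of labels in one batch]
        return len(self.label_to_ids) // self.n_labels

    def __iter__(self) -> Iterator[List[int]]:
        all_labels = list(self.label_to_ids)
        self.rng.shuffle(all_labels)
        for b in range(len(self)):
            batch_labels = all_labels[b * self.n_labels:(b + 1) * self.n_labels]
            batch: List[int] = []
            for label in batch_labels:
                # Draw n_instances distinct samples of this label.
                batch.extend(self.rng.sample(self.label_to_ids[label], self.n_instances))
            yield batch


labels = [0, 0, 0, 1, 1, 2, 2, 2, 2, 3, 3, 3]
sampler = SimpleBalanceSampler(labels, n_labels=2, n_instances=2)
print(len(sampler))  # 2
for batch in sampler:
    print(sorted({labels[i] for i in batch}))  # 2 distinct labels per batch
```

Over one epoch, all 4 labels are visited, but only 8 of the 12 samples are drawn. In PyTorch, an object like this could be passed to DataLoader through its batch_sampler argument.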
How good can a model trained with OML be?
It may be comparable with the current (as of 2022) SotA methods, for example, Hyp-ViT. (A few words about this approach: it's a ViT architecture trained with contrastive loss, but the embeddings were projected into some hyperbolic space. As the authors claimed, such a space is able to describe the nested structure of real-world data. So, the paper requires some heavy math to adapt the usual operations for the hyperbolic space.)
We trained the same architecture with triplet loss, fixing the rest of the parameters: training and test transformations, image size, and optimizer. See the configs in the Models Zoo. The trick was in the heuristics of our miner and sampler:
- Category Balance Sampler forms batches, limiting the number of categories C in each of them. For instance, when C = 1, it puts only jackets in one batch and only jeans in another one (just an example). It automatically makes the negative pairs harder: it's more meaningful for a model to realise why two jackets are different than to understand the same about a jacket and a t-shirt.
- Hard Triplets Miner makes the task even harder by keeping only the hardest triplets (with maximal positive and minimal negative distances).
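The miner heuristic can be sketched as batch-hard mining in plain Python. This illustration assumes Euclidean distances and uses hypothetical function names; it is not OML's miner. For every anchor, only one triplet is kept: the farthest positive and the closest negative. A standard triplet margin loss is then computed over these triplets. With a Category Balance Sampler, the closest negatives are more likely to come from the same category, which is what makes them hard.

```python
import math
from typing import List, Sequence, Tuple

Triplet = Tuple[int, int, int]  # (anchor, positive, negative) indices


def hardest_triplets(embeddings: List[Sequence[float]], labels: List[int]) -> List[Triplet]:
    """One triplet per anchor: farthest same-label item, closest other-label item.

    Anchors without a positive or a negative in the batch are skipped.
    """
    n = len(embeddings)
    dist = [[math.dist(embeddings[i], embeddings[j]) for j in range(n)] for i in range(n)]
    triplets: List[Triplet] = []
    for a in range(n):
        positives = [j for j in range(n) if j != a and labels[j] == labels[a]]
        negatives = [j for j in range(n) if labels[j] != labels[a]]
        if not positives or not negatives:
            continue
        p = max(positives, key=lambda j: dist[a][j])  # maximal positive distance
        neg = min(negatives, key=lambda j: dist[a][j])  # minimal negative distance
        triplets.append((a, p, neg))
    return triplets


def triplet_loss(embeddings: List[Sequence[float]], triplets: List[Triplet], margin: float = 0.2) -> float:
    """Mean of max(d(a, p) - d(a, n) + margin, 0) over the triplets."""
    if not triplets:
        return 0.0
    losses = [
        max(math.dist(embeddings[a], embeddings[p]) - math.dist(embeddings[a], embeddings[n]) + margin, 0.0)
        for a, p, n in triplets
    ]
    return sum(losses) / len(losses)


# Two labels (e.g. two different jackets), two photos each.
embeddings = [[0.0, 0.0], [0.0, 2.0], [1.0, 0.0], [5.0, 5.0]]
labels = [0, 0, 1, 1]
triplets = hardest_triplets(embeddings, labels)
print(triplets)  # [(0, 1, 2), (1, 0, 2), (2, 3, 0), (3, 2, 1)]
print(triplet_loss(embeddings, triplets, margin=0.5))
```

Keeping only the hardest triplet per anchor makes the loss focus on the pairs the model currently confuses most.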
Here are CMC@1 scores for 2 popular benchmarks. SOP dataset: Hyp-ViT 85.9, ours 86.6. DeepFashion dataset: Hyp-ViT 92.5, ours 92.1. Thus, by utilising simple heuristics and avoiding heavy math, we are able to perform at the SotA level.
What about Self-Supervised Learning?
Recent research