NEW! TFDS API for Meta-Dataset
To accompany the presentation of the VTAB+MD paper at NeurIPS 2021's Datasets and Benchmarks track, we are releasing a TensorFlow Datasets-based implementation of Meta-Dataset's input pipeline which is compatible with both the original Meta-Dataset protocol (MD-v1) and the updated protocol designed for VTAB+MD (MD-v2). See the documentation page for more information and example code snippets.
Meta-Dataset
This repository contains accompanying code for the article introducing Meta-Dataset (arxiv.org/abs/1903.03096) and for the follow-up paper that proposes the merged VTAB+MD benchmark (arxiv.org/abs/2104.02638). It also contains code and checkpoints for the CrossTransformers (https://arxiv.org/abs/2007.11498) and FLUTE (https://arxiv.org/abs/2105.07029) follow-up works, which improve performance.
This code is provided to give more details on the implementation of the data pipeline, our backbones and models, and the experimental setup.
See below for user instructions, including how to:
- install the software,
- download and convert the data, and
- train implemented models.
See this introduction notebook for a demonstration of how to sample data from the pipeline (episodes or batches).
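For orientation, here is a condensed sketch in the spirit of that notebook. It samples episodes from a single converted dataset; the records path is a placeholder, and exact function signatures may differ between versions of the repository.

```python
# Sketch of episodic sampling, following the introduction notebook
# (placeholder paths; signatures may differ across repository versions).
import gin
from meta_dataset.data import config
from meta_dataset.data import dataset_spec as dataset_spec_lib
from meta_dataset.data import learning_spec
from meta_dataset.data import pipeline

# Default data-related gin settings (image sizes, episode parameters, ...).
gin.parse_config_file('meta_dataset/learn/gin/setups/data_config.gin')

# Specification of one dataset that was converted to records, e.g. Omniglot.
dataset_spec = dataset_spec_lib.load_dataset_spec('/path/to/records/omniglot')

dataset = pipeline.make_one_source_episode_pipeline(
    dataset_spec=dataset_spec,
    use_dag_ontology=False,     # Only ImageNet uses its DAG ontology.
    use_bilevel_ontology=True,  # Omniglot samples alphabets, then characters.
    split=learning_spec.Split.TEST,
    episode_descr_config=config.EpisodeDescriptionConfig(
        num_ways=None, num_support=None, num_query=None),  # variable ways/shots
    image_size=84)

for episode, source_id in dataset.take(1):
  # An episode is a 6-tuple of support/query images, labels, and class ids.
  (support_images, support_labels, _,
   query_images, query_labels, _) = episode
```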
In order to run the experiments described in the first version of the arXiv article, arxiv.org/abs/1903.03096v1, please use the instructions, code, and configuration files at version arxiv_v1 of this repository.
We are currently working on updating the instructions, code, and configuration files to reproduce the results in the second version of the article, arxiv.org/abs/1903.03096v2. You can follow the progress in branch arxiv_v2_dev of this repository.
This is not an officially supported Google product.
Meta-Dataset: A Dataset of Datasets for Learning to Learn from Few Examples
Eleni Triantafillou, Tyler Zhu, Vincent Dumoulin, Pascal Lamblin, Utku Evci, Kelvin Xu, Ross Goroshin, Carles Gelada, Kevin Swersky, Pierre-Antoine Manzagol, Hugo Larochelle
Few-shot classification refers to learning a classifier for new classes given only a few examples. While a plethora of models have emerged to tackle it, we find the procedure and datasets that are used to assess their progress lacking. To address this limitation, we propose Meta-Dataset: a new benchmark for training and evaluating models that is large-scale, consists of diverse datasets, and presents more realistic tasks. We experiment with popular baselines and meta-learners on Meta-Dataset, along with a competitive method that we propose. We analyze performance as a function of various characteristics of test tasks and examine the models' ability to leverage diverse training sources for improving their generalization. We also propose a new set of baselines for quantifying the benefit of meta-learning in Meta-Dataset. Our extensive experimentation has uncovered important research challenges and we hope to inspire work in these directions.
CrossTransformers: spatially-aware few-shot transfer
Carl Doersch, Ankush Gupta, Andrew Zisserman
CrossTransformers is a Transformer-based neural network architecture that finds coarse spatial correspondence between the query and the support images, and then infers class membership by computing distances between spatially corresponding features. The paper also introduces SimCLR episodes: episodes that require SimCLR-style instance recognition, and therefore encourage features that capture more than just the training-set categories. This algorithm is SOTA on Meta-Dataset (train-on-ILSVRC) as of NeurIPS 2020.
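To make the distance computation concrete, here is a minimal numpy sketch of the core idea, not the repository's implementation: for each class, query spatial features attend over the support set's spatial locations, and the class score is the negative distance to the attention-aligned prototype. The projection heads here are random stand-ins, and details such as the value head and normalization are omitted.

```python
# Illustrative sketch of CrossTransformer-style scoring (simplified:
# random projection heads, no value head, single attention head).
import numpy as np

def ctx_logits(query_feat, support_feats, support_labels, num_classes, dim_k=64):
  """query_feat: [Lq, C] spatial features of one query image (Lq = H * W).

  support_feats: [N, Ls, C] spatial features of N support images.
  support_labels: [N] integer class labels of the support images.
  """
  c_dim = query_feat.shape[-1]
  rng = np.random.default_rng(0)
  w_q = rng.standard_normal((c_dim, dim_k)) / np.sqrt(c_dim)  # query head
  w_k = rng.standard_normal((c_dim, dim_k)) / np.sqrt(c_dim)  # key head
  logits = []
  for c in range(num_classes):
    # Pool all spatial locations of the class-c support images into one set.
    sc = support_feats[support_labels == c].reshape(-1, c_dim)  # [Lc, C]
    # Attention of each query location over all class-c support locations.
    attn = (query_feat @ w_q) @ (sc @ w_k).T / np.sqrt(dim_k)   # [Lq, Lc]
    attn = np.exp(attn - attn.max(axis=-1, keepdims=True))
    attn /= attn.sum(axis=-1, keepdims=True)                    # softmax
    aligned = attn @ sc  # spatially-aligned class prototype, [Lq, C]
    logits.append(-np.mean((query_feat - aligned) ** 2))
  return np.array(logits)
```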
Configuration files for CrossTransformers with and without SimCLR episodes (CTX and CTX+SimCLR Eps from the paper) can be found in learn/gin/default/crosstransformer*. We also have pretrained checkpoints for these two configurations: CTX and CTX+SimCLR Eps, as well as CTX+SimCLR Eps+BOHB Aug. Note that these were retrained from the versions reported in the paper, but their performance should be on par. The network structure is the same for all three models, so they can be loaded using either of the CrossTransformer config files.
Learning a Universal Template for Few-shot Dataset Generalization (FLUTE)
Eleni Triantafillou, Hugo Larochelle, Richard Zemel, Vincent Dumoulin
Few-shot Learning with a Universal TEmplate (FLUTE) is a model designed for the strong-generalization challenge of few-shot learning of classes from unseen datasets. At the time of publication (ICML 2021), it achieved SOTA on Meta-Dataset (train-on-all). It works by leveraging the training datasets to learn a 'universal template' that can be repurposed to solve diverse test tasks by 'filling in the blanks' of the template: for each test task, an appropriate set of FiLM parameters is learned with gradient descent.
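As a reminder of the mechanism, FiLM modulates a feature map with a per-channel scale and shift. The sketch below is illustrative only (FLUTE's actual FiLM layers live in this repository's backbone definitions); at test time the convolutional template stays frozen and only the per-layer (gamma, beta) parameters are fit on the task's support set.

```python
# Minimal illustration of FiLM modulation (not FLUTE's actual layers).
import tensorflow as tf

def film(features, gamma, beta):
  """FiLM(x) = gamma * x + beta, applied per channel.

  features: [batch, height, width, channels]; gamma, beta: [channels].
  """
  return features * gamma[None, None, None, :] + beta[None, None, None, :]

channels = 64
gamma = tf.Variable(tf.ones([channels]))   # identity scale at initialization
beta = tf.Variable(tf.zeros([channels]))   # zero shift at initialization
x = tf.random.normal([5, 21, 21, channels])
y = film(x, gamma, beta)  # only gamma/beta would be updated at test time
```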
Configuration files for training FLUTE, as well as the dataset classifier used in FLUTE's Blender network, can be found in learn/gin/default/flute.gin and learn/gin/default/flute_dataset_classifier.gin, respectively. Configuration files for testing different variants of FLUTE can be found in learn/gin/best/flute*. The results reported in the paper were obtained with learn/gin/best/flute.gin.

The training script for FLUTE is train_flute.py. We also have pre-trained checkpoints for FLUTE and its Blender network: https://console.cloud.google.com/storage/gresearch/flute
Leaderboard (in progress)
The tables below were generated by this notebook.
Adding a new model to the leaderboard
- Gather accuracy results and 95% confidence intervals (a sketch of this computation follows the list), as well as the number of episodes used for the CI (minimum 600).
- If you were affected by #54, make sure the evaluation on Traffic Signs is done on shuffled samples. We also encourage you to re-train your best model (or at least re-run validation).
- Create an issue with the name of the model, the results, the article to cite, and any other relevant information, and label it "leaderboard". Alternatively, submit a PR with an update to the notebook.
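For reference, the reported numbers are a mean episode accuracy with a 95% confidence interval over the test episodes. A minimal numpy sketch, assuming one accuracy value per episode:

```python
# 95% confidence interval for mean episode accuracy (normal approximation).
import numpy as np

def mean_and_ci95(episode_accuracies):
  acc = np.asarray(episode_accuracies)
  mean = acc.mean()
  ci95 = 1.96 * acc.std(ddof=1) / np.sqrt(len(acc))
  return mean, ci95

accs = np.random.uniform(0.4, 0.8, size=600)  # placeholder accuracies
mean, ci = mean_and_ci95(accs)
print(f'{100 * mean:.2f}±{100 * ci:.2f} (over {len(accs)} episodes)')
```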
Training on ImageNet only
Each cell reports average accuracy ± a 95% confidence interval, followed in parentheses by the method's rank on that dataset (tied methods share the average of the tied ranks); the Avg rank column averages these per-dataset ranks.

Method | Avg rank | ILSVRC (test) | Omniglot | Aircraft | Birds | Textures | QuickDraw | Fungi | VGG Flower | Traffic signs | MSCOCO |
---|---|---|---|---|---|---|---|---|---|---|---|
k-NN [[1]] | 14.6 | 41.03±1.01 (15) | 37.07±1.15 (16) | 46.81±0.89 (15) | 50.13±1.00 (15.5) | 66.36±0.75 (13) | 32.06±1.08 (16) | 36.16±1.02 (13) | 83.10±0.68 (12) | 44.59±1.19 (15) | 30.38±0.99 (15.5) |
Finetune [[1]] | 10.45 | 45.78±1.10 (13) | 60.85±1.58 (11.5) | 68.69±1.26 (5) | 57.31±1.26 (14) | 69.05±0.90 (9.5) | 42.60±1.17 (13.5) | 38.20±1.02 (11) | 85.51±0.68 (9) | 66.79±1.31 (5) | 34.86±0.97 (13) |
MatchingNet [[1]] | 13.55 | 45.00±1.10 (13) | 52.27±1.28 (14) | 48.97±0.93 (13) | 62.21±0.95 (12.5) | 64.15±0.85 (15) | 42.87±1.09 (13.5) | 33.97±1.00 (14) | 80.13±0.71 (15) | 47.80±1.14 (12.5) | 34.99±1.00 (13) |
ProtoNet [[1]] | 10.75 | 50.50±1.08 (10.5) | 59.98±1.35 (11.5) | 53.10±1.00 (10.5) | 68.79±1.01 (8.5) | 66.56±0.83 (13) | 48.96±1.08 (11) | 39.71±1.11 (9) | 85.27±0.77 (9) | 47.12±1.10 (14) | 41.00±1.10 (10.5) |
fo-MAML [[1]] | 12.25 | 45.51±1.11 (13) | 55.55±1.54 (13) | 56.24±1.11 (8.5) | 63.61±1.06 (12.5) | 68.04±0.81 (9.5) | 43.96±1.29 (13.5) | 32.10±1.10 (15) | 81.74±0.83 (14) | 50.93±1.51 (10.5) | 35.30±1.23 (13) |
RelationNet [[1]] | 15.55 | 34.69±1.01 (16) | 45.35±1.36 (15) | 40.73±0.83 (16) | 49.51±1.05 (15.5) | 52.97±0.69 (16) | 43.30±1.08 (13.5) | 30.55±1.04 (16) | 68.76±0.83 (16) | 33.67±1.05 (16) | 29.15±1.01 (15.5) |
fo-Proto-MAML [[1]] | 9.25 | 49.53±1.05 (10.5) | 63.37±1.33 (8.5) | 55.95±0.99 (8.5) | 68.66±0.96 (8.5) | 66.49±0.83 (13) | 51.52±1.00 (9.5) | 39.96±1.14 (6.5) | 87.15±0.69 (6) | 48.83±1.09 (12.5) | 43.74±1.12 (9) |
ALFA+fo-Proto-MAML [[3]] | 7.1 | 52.80±1.11 (8.5) | 61.87±1.51 (8.5) | 63.43±1.10 (6) | 69.75±1.05 (6.5) | 70.78±0.88 (7) | 59.17±1.16 (5.5) | 41.49±1.17 (6.5) | 85.96±0.77 (9) | 60.78±1.29 (8) | 48.11±1.14 (5.5) |
ProtoNet (large) [[4]] | 7.25 | 53.69±1.07 (6) | 68.50±1.27 (5.5) | 58.04±0.96 (7) | 74.07±0.92 (4.5) | 68.76±0.77 (9.5) | 53.30±1.06 (8) | 40.73±1.15 (6.5) | 86.96±0.73 (6) | 58.11±1.05 (9) | 41.70±1.08 (10.5) |
CTX [[4]] | 2.75 | 62.76±0.99 (2.5) | 82.21±1.00 (2.5) | 79.49±0.89 (2.5) | 80.63±0.88 (3) | 75.57±0.64 (4) | 72.68±0.82 (2) | 51.58±1.11 (2.5) | 95.34±0.37 (2) | 82.65±0.76 (3) | 59.90±1.02 (3.5) |
BOHB [[5]] | 7.85 | 51.92±1.05 (8.5) | 67.57±1.21 (5.5) | 54.12±0.90 (10.5) | 70.69±0.90 (6.5) | 68.34±0.76 (9.5) | 50.33±1.04 (9.5) | 41.38±1.12 (6.5) | 87.34±0.59 (6) | 51.80±1.04 (10.5) | 48.03±0.99 (5.5) |
SimpleCNAPS [[14],[7]] | 8.75 | 54.80±1.20 (6) | 62.00±1.30 (8.5) | 49.20±0.90 (13) | 66.50±1.00 (10.5) | 71.60±0.70 (5.5) | 56.60±1.00 (7) | 37.50±1.20 (11) | 82.10±0.90 (12) | 63.10±1.10 (6.5) | 45.80±1.00 (7.5) |
TransductiveCNAPS [[14],[8]] | 8.6 | 54.10±1.10 (6) | 62.90±1.30 (8.5) | 48.40±0.90 (13) | 67.30±0.90 (10.5) | 72.50±0.70 (5.5) | 58.00±1.00 (5.5) | 37.70±1.10 (11) | 82.80±0.80 (12) | 61.80±1.10 (6.5) | 45.80±1.00 (7.5) |
TSA_resnet18 [[12]] | 3.8 | 59.50±1.10 (4) | 78.20±1.20 (4) | 72.20±1.00 (4) | 74.90±0.90 (4.5) | 77.30±0.70 (3) | 67.60±0.90 (4) | 44.70±1.00 (4) | 90.90±0.60 (4) | 82.50±0.80 (3) | 59.00±1.00 (3.5) |
TSA_resnet34 [[12]] | 2.5 |