License: MIT License

An official PyTorch implementation of “Multimodal Model-Agnostic Meta-Learning via Task-Aware Modulation” (NeurIPS 2019) by Risto Vuorio*, Shao-Hua Sun*, Hexiang Hu, and Joseph J. Lim

Multimodal Model-Agnostic Meta-Learning for Few-shot Classification

This project is an implementation of Multimodal Model-Agnostic Meta-Learning via Task-Aware Modulation, which was published at NeurIPS 2019. Please visit our project page for more information, and contact Shao-Hua Sun with any questions.

Model-agnostic meta-learners aim to acquire meta-prior parameters from a distribution of tasks and adapt to novel tasks with few gradient updates. Yet, seeking a common initialization shared across the entire task distribution substantially limits the diversity of the task distributions that they are able to learn from. We propose a multimodal MAML (MMAML) framework, which identifies the mode of each task and modulates its meta-learned prior accordingly, allowing more efficient fast adaptation. An illustration of the proposed framework is as follows.
[Figure: illustration of the proposed MMAML framework]
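The idea of modulating a shared prior according to the identified mode, and only then running gradient-based adaptation, can be sketched on a toy problem. This is a minimal, stdlib-only illustration under several assumptions, not the repository's PyTorch implementation: the model is a 1-D linear regressor y = w * x + b, the hypothetical helper identify_mode replaces the learned task encoder with the sign of the support-set correlation, and the modulation is FiLM-style (per-parameter scale and shift) with hand-picked coefficients instead of meta-learned ones. For simplicity the inner loop adapts the modulated parameters directly.

```python
# Toy sketch (hypothetical, not the repository's code): modulate a shared prior
# according to an identified task mode, then adapt with gradient descent.
# Model: y = w * x + b, parameters stored as [w, b].


def mse_and_grad(params, data):
    """Mean squared error of y = w * x + b on `data` and its gradient."""
    w, b = params
    n = len(data)
    loss = sum((w * x + b - y) ** 2 for x, y in data) / n
    grad_w = sum(2 * (w * x + b - y) * x for x, y in data) / n
    grad_b = sum(2 * (w * x + b - y) for x, y in data) / n
    return loss, [grad_w, grad_b]


def identify_mode(support):
    # Hypothetical task encoder: +1 if the support set trends upward, else -1.
    # MMAML learns a task embedding with a network instead.
    return 1.0 if sum(x * y for x, y in support) >= 0 else -1.0


def film_modulate(prior, embedding, scale, shift):
    # FiLM-style modulation: theta_i -> gamma_i * theta_i + beta_i with
    # gamma_i = 1 + scale_i * e and beta_i = shift_i * e (hand-picked here,
    # meta-learned in practice).
    return [(1.0 + s * embedding) * p + t * embedding
            for p, s, t in zip(prior, scale, shift)]


def adapt(params, support, fast_lr, num_updates):
    # Inner loop: plain gradient descent on the support-set loss.
    params = list(params)
    for _ in range(num_updates):
        _, grad = mse_and_grad(params, support)
        params = [p - fast_lr * g for p, g in zip(params, grad)]
    return params


prior = [0.5, 0.0]                          # shared prior (w0, b0), assumed
scale, shift = [0.0, 0.0], [2.5, 0.0]       # assumed modulation coefficients
support = [(1.0, 3.1), (2.0, 6.0), (-1.0, -2.9)]   # task from the "upward" mode
query = [(0.5, 1.5), (3.0, 9.0)]

e = identify_mode(support)
modulated = film_modulate(prior, e, scale, shift)
plain_loss, _ = mse_and_grad(adapt(prior, support, 0.05, 1), query)
mmaml_loss, _ = mse_and_grad(adapt(modulated, support, 0.05, 1), query)
print(f"query loss after one step: plain prior {plain_loss:.3f}, modulated {mmaml_loss:.3f}")
```

With a downward-trending support set, identify_mode returns -1 and the same coefficients move w to -2.0 instead, so one shared prior can start close to either mode before the gradient steps refine it.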

We evaluate our model and the baselines (MAML and Multi-MAML) in multiple multimodal settings built from the following five datasets: (a) Omniglot, (b) Mini-ImageNet, (c) FC100 (derived from CIFAR-100), (d) CUB-200-2011, and (e) FGVC-Aircraft.
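In such a multimodal setting every few-shot task is drawn from one of the datasets (modes). The sketch below assumes, which the text above does not state, that each task first picks one dataset uniformly at random and then draws an N-way K-shot episode from it. The helper sample_task and the toy datasets are hypothetical; its arguments mirror --num-classes-per-batch, --num-samples-per-class, and --num-val-samples from the Usage section, and the values used are illustrative.

```python
import random


def sample_task(modes, num_classes, num_samples, num_val, rng):
    """Sample one N-way K-shot task whose classes all come from a single mode.

    modes maps a dataset (mode) name to {class_name: [examples]}. num_classes,
    num_samples and num_val mirror --num-classes-per-batch,
    --num-samples-per-class and --num-val-samples.
    """
    mode = rng.choice(sorted(modes))             # assumption: modes picked uniformly
    classes = rng.sample(sorted(modes[mode]), num_classes)
    train, val = [], []
    for label, name in enumerate(classes):       # labels 0..N-1 are reassigned per task
        examples = rng.sample(modes[mode][name], num_samples + num_val)
        train += [(x, label) for x in examples[:num_samples]]
        val += [(x, label) for x in examples[num_samples:]]
    return mode, train, val


# Toy stand-ins for two datasets: strings instead of images.
toy_modes = {
    "omniglot": {f"char{i}": [f"char{i}_img{j}" for j in range(20)] for i in range(10)},
    "miniimagenet": {f"cls{i}": [f"cls{i}_img{j}" for j in range(20)] for i in range(10)},
}
mode, train, val = sample_task(toy_modes, num_classes=5, num_samples=1, num_val=15,
                               rng=random.Random(0))
print(mode, len(train), len(val))  # 5-way 1-shot: 5 training and 75 validation examples
```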

Datasets

Run the following command to download and preprocess the datasets:

python download.py --dataset aircraft bird cifar miniimagenet

Getting started

Please first install the following prerequisites: wget, unzip.

To avoid conflicts with your existing Python setup and to keep this project self-contained, we suggest working in a virtual environment created with virtualenv. To install virtualenv:

pip install --upgrade virtualenv

Create a virtual environment, activate it, and install the requirements listed in requirements.txt:

virtualenv mmaml_venv
source mmaml_venv/bin/activate
pip install -r requirements.txt

Usage

After downloading the datasets, you can train a model with the following command.

Training command

python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot miniimagenet cifar bird aircraft --mmaml-model True --num-batches 600000 --output-folder mmaml_5mode_5w1s
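  • Selected arguments (see trainer.py for more details)
    • --output-folder: a nickname for the training run
    • --dataset: choose among omniglot, miniimagenet, cifar, bird (CUB), and aircraft, or use multimodal_few_shot to combine several of them, as all commands in this README do. You can also add your own datasets.
    • --multimodal_few_shot: the datasets (modes) to combine when --dataset is multimodal_few_shot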
    • Checkpoints: specify the path to a pre-trained checkpoint
      • --checkpoint: load all the parameters (e.g. train_dir/mmaml_5mode_5w1s/maml_gatedconv_60000.pt).
    • Hyperparameters
      • --num-batches: number of batches
      • --meta-batch-size: number of tasks per batch
      • --slow-lr: learning rate for the global update of MAML
      • --fast-lr: learning rate for the adapted models
      • --num-updates: how many update steps in the inner loop
      • --num-classes-per-batch: how many classes per task (N-way)
      • --num-samples-per-class: how many samples per class for training (K-shot)
      • --num-val-samples: how many samples per class for validation
      • --max_steps: the max training iterations
    • Logging
      • --log-interval: number of batches between tensorboard writes
      • --save-interval: number of batches between model saves
    • Model
      • --maml-model: set to True to train a MAML model
      • --mmaml-model: set to True to train an MMAML (our) model
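The hyperparameters above map onto MAML's two update levels: --fast-lr and --num-updates control the inner loop on each task, while --slow-lr and --meta-batch-size control the global update. A minimal, stdlib-only sketch, assuming a toy one-parameter regression model y = w * x instead of the repository's classifier; all function names are hypothetical. Like MAML, the outer gradient differentiates through the inner loop, and for this linear model that second-order term has a closed form.

```python
import random


def loss_grad_curv(w, data):
    """MSE of y = w * x on `data`, plus its first and (constant) second derivative."""
    n = len(data)
    loss = sum((w * x - y) ** 2 for x, y in data) / n
    grad = sum(2 * x * (w * x - y) for x, y in data) / n
    curv = sum(2 * x * x for x, _ in data) / n
    return loss, grad, curv


def inner_loop(w, support, fast_lr, num_updates):
    # --fast-lr and --num-updates: gradient steps on one task's support set.
    for _ in range(num_updates):
        w -= fast_lr * loss_grad_curv(w, support)[1]
    return w


def meta_step(w, tasks, slow_lr, fast_lr, num_updates):
    # One outer (global) update; len(tasks) plays the role of --meta-batch-size.
    meta_grad = 0.0
    for support, query in tasks:
        adapted = inner_loop(w, support, fast_lr, num_updates)
        _, grad_query, _ = loss_grad_curv(adapted, query)
        _, _, curv_support = loss_grad_curv(w, support)
        # Exact chain rule through the inner loop: for this linear model each
        # inner step multiplies d(adapted)/dw by (1 - fast_lr * curv_support).
        meta_grad += grad_query * (1 - fast_lr * curv_support) ** num_updates
    return w - slow_lr * meta_grad / len(tasks)   # --slow-lr


def make_task(rng, slope):
    # Toy regression task y = slope * x: 5 support and 5 query points.
    data = [(x, slope * x) for x in (rng.uniform(-1, 1) for _ in range(10))]
    return data[:5], data[5:]


rng = random.Random(0)
w = 0.0
for _ in range(200):
    batch = [make_task(rng, rng.uniform(1.0, 3.0)) for _ in range(4)]
    w = meta_step(w, batch, slow_lr=0.1, fast_lr=0.1, num_updates=1)
print(f"meta-learned initialization: w = {w:.2f}")
```

The learned w settles near the middle of the sampled slopes, from which a few inner steps with the fast learning rate reach any individual task.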

Interpret TensorBoard

Launch TensorBoard and open the specified port in a browser; the Scalars tab shows the different accuracies and losses.

You can reproduce our results with the following training commands.

2 Modes (Omniglot and Mini-ImageNet)

Setup Method Command
5w1s MAML python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot miniimagenet --maml-model True --num-batches 600000 --output-folder maml_2mode_5w1s
5w1s Ours python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot miniimagenet --mmaml-model True --num-batches 600000 --output-folder mmaml_2mode_5w1s
5w5s MAML python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot miniimagenet --maml-model True --num-batches 600000 --num-samples-per-class 5 --output-folder maml_2mode_5w5s
5w5s Ours python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot miniimagenet --mmaml-model True --num-batches 600000 --num-samples-per-class 5 --output-folder mmaml_2mode_5w5s
20w1s MAML python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot miniimagenet --maml-model True --num-batches 400000 --meta-batch-size 5 --num-classes-per-batch 20 --output-folder maml_2mode_20w1s
20w1s Ours python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot miniimagenet --mmaml-model True --num-batches 400000 --meta-batch-size 5 --num-classes-per-batch 20 --output-folder mmaml_2mode_20w1s

3 Modes (Omniglot, Mini-ImageNet, and FC100)

Setup Method Command
5w1s MAML python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot miniimagenet cifar --maml-model True --num-batches 600000 --output-folder maml_3mode_5w1s
5w1s Ours python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot miniimagenet cifar --mmaml-model True --num-batches 600000 --output-folder mmaml_3mode_5w1s
5w5s MAML python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot miniimagenet cifar --maml-model True --num-batches 600000 --num-samples-per-class 5 --output-folder maml_3mode_5w5s
5w5s Ours python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot miniimagenet cifar --mmaml-model True --num-batches 600000 --num-samples-per-class 5 --output-folder mmaml_3mode_5w5s
20w1s MAML python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot miniimagenet cifar --maml-model True --num-batches 400000 --meta-batch-size 5 --num-classes-per-batch 20 --output-folder maml_3mode_20w1s
20w1s Ours python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot miniimagenet cifar --mmaml-model True --num-batches 400000 --meta-batch-size 5 --num-classes-per-batch 20 --output-folder mmaml_3mode_20w1s

5 Modes (Omniglot, Mini-ImageNet, FC100, Aircraft, and CUB)

Setup Method Command
5w1s MAML python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot miniimagenet cifar bird aircraft --maml-model True --num-batches 600000 --output-folder maml_5mode_5w1s
5w1s Ours python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot miniimagenet cifar bird aircraft --mmaml-model True --num-batches 600000 --output-folder mmaml_5mode_5w1s
5w5s MAML python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot miniimagenet cifar bird aircraft --maml-model True --num-batches 600000 --num-samples-per-class 5 --output-folder maml_5mode_5w5s
5w5s Ours python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot miniimagenet cifar bird aircraft --mmaml-model True --num-batches 600000 --num-samples-per-class 5 --output-folder mmaml_5mode_5w5s
20w1s MAML python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot miniimagenet cifar bird aircraft --maml-model True --num-batches 400000 --meta-batch-size 5 --num-classes-per-batch 20 --output-folder maml_5mode_20w1s
20w1s Ours python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot miniimagenet cifar bird aircraft --mmaml-model True --num-batches 400000 --meta-batch-size 5 --num-classes-per-batch 20 --output-folder mmaml_5mode_20w1s
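All commands in the three tables follow a single pattern, so they can be regenerated programmatically. A minimal sketch in which build_command, MODE_DATASETS, and SETUP_FLAGS are hypothetical names and every flag value is copied from the tables:

```python
MODE_DATASETS = {
    2: ["omniglot", "miniimagenet"],
    3: ["omniglot", "miniimagenet", "cifar"],
    5: ["omniglot", "miniimagenet", "cifar", "bird", "aircraft"],
}
SETUP_FLAGS = {
    "5w1s": ["--num-batches", "600000"],
    "5w5s": ["--num-batches", "600000", "--num-samples-per-class", "5"],
    "20w1s": ["--num-batches", "400000", "--meta-batch-size", "5",
              "--num-classes-per-batch", "20"],
}


def build_command(num_modes, method, setup):
    """Argument list for one cell of the tables: method is "maml" or "mmaml"."""
    model_flag = {"maml": "--maml-model", "mmaml": "--mmaml-model"}[method]
    return ["python", "main.py", "--dataset", "multimodal_few_shot",
            "--multimodal_few_shot", *MODE_DATASETS[num_modes],
            model_flag, "True", *SETUP_FLAGS[setup],
            "--output-folder", f"{method}_{num_modes}mode_{setup}"]


for setup in SETUP_FLAGS:
    print(" ".join(build_command(2, "mmaml", setup)))
```

Each returned list could be passed to subprocess.run(cmd, check=True) from the repository root; the output folder follows the <method>_<N>mode_<setup> naming used in the tables, so different settings never share a folder.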

Multi-MAML

Setup Dataset Command
5w1s Omniglot python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot --maml-model True --fast-lr 0.4 --num-updates 1 --num-batches 600000 --output-folder multi_omniglot_5w1s
5w1s Mini-ImageNet python main.py --dataset multimodal_few_shot --multimodal_few_shot miniimagenet --maml-model True --fast-lr 0.01 --meta-batch-size 4 --num-batches 320000 --output-folder multi_miniimagenet_5w1s
5w1s FC100 python main.py --dataset multimodal_few_shot --multimodal_few_shot cifar --maml-model True --fast-lr 0.01 --meta-batch-size 4 --num-batches 320000 --output-folder multi_cifar_5w1s
5w1s Bird python main.py --dataset multimodal_few_shot --multimodal_few_shot bird --maml-model True --fast-lr 0.01 --meta-batch-size 4 --num-batches 320000 --output-folder multi_bird_5w1s
5w1s Aircraft python main.py --dataset multimodal_few_shot --multimodal_few_shot aircraft --maml-model True --fast-lr 0.01 --meta-batch-size 4 --num-batches 320000 --output-folder multi_aircraft_5w1s
5w5s Omniglot python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot --maml-model True --fast-lr 0.4 --num-updates 1 --num-batches 600000 --num-samples-per-class 5 --output-folder multi_omniglot_5w5s
5w5s Mini-ImageNet python main.py --dataset multimodal_few_shot --multimodal_few_shot miniimagenet --maml-model True --fast-lr 0.01 --meta-batch-size 4 --num-batches 320000 --num-samples-per-class 5 --output-folder multi_miniimagenet_5w5s
5w5s FC100 python main.py --dataset multimodal_few_shot --multimodal_few_shot cifar --maml-model True --fast-lr 0.01 --meta-batch-size 4 --num-batches 320000 --num-samples-per-class 5 --output-folder multi_cifar_5w5s
5w5s Bird python main.py --dataset multimodal_few_shot --multimodal_few_shot bird --maml-model True --fast-lr 0.01 --meta-batch-size 4 --num-batches 320000 --num-samples-per-class 5 --output-folder multi_bird_5w5s
5w5s Aircraft python main.py --dataset multimodal_few_shot --multimodal_few_shot aircraft --maml-model True --fast-lr 0.01 --meta-batch-size 4 --num-batches 320000 --num-samples-per-class 5 --output-folder multi_aircraft_5w5s
20w1s Omniglot python main.py --dataset multimodal_few_shot --multimodal_few_shot omniglot --maml-model True --fast-lr 0.1 --meta-batch-size 4 --num-batches 320000 --num-classes-per-batch 20 --output-folder multi_omniglot_20w1s
20w1s Mini-ImageNet python main.py --dataset multimodal_few_shot --multimodal_few_shot miniimagenet --maml-model True --fast-lr 0.01 --meta-batch-size 4 --num-batches 320000 --num-classes-per-batch 20 --output-folder multi_miniimagenet_20w1s
20w1s FC100 python main.py --dataset multimodal_few_shot --multimodal_few_shot cifar --maml-model True --fast-lr 0.01 --meta-batch-size 4 --num-batches 320000 --num-classes-per-batch 20 --output-folder multi_cifar_20w1s
20w1s Bird python main.py --dataset multimodal_few_shot --multimodal_few_shot bird --maml-model True --fast-lr 0.01 --meta-batch-size 4 --num-batches 320000 --num-classes-per-batch 20 --output-folder multi_bird_20w1s
20w1s Aircraft python main.py --dataset multimodal_few_shot --multimodal_few_shot aircraft --maml-model True --fast-lr 0.01 --meta-batch-size 4 --num-batches 320000 --num-classes-per-batch 20 --output-folder multi_aircraft_20w1s
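Each Multi-MAML row trains a separate model on a single dataset. Evaluating this baseline on the multimodal task distribution therefore means routing every task to the model trained on that task's own dataset, which presumes the dataset label is available at test time. A minimal sketch of that routing; the Adapter interface, lookup_adapter, and averaging accuracy over tasks are assumptions for illustration, not the repository's evaluation code.

```python
from typing import Callable, Dict, List, Tuple

Example = Tuple[str, int]            # (input, label); strings stand in for images
Classifier = Callable[[str], int]
# Hypothetical interface: an "adapter" adapts on a support set and returns a
# classifier. In the repository, each would correspond to one per-dataset run.
Adapter = Callable[[List[Example]], Classifier]


def evaluate_multi_maml(adapters: Dict[str, Adapter],
                        tasks: List[Tuple[str, List[Example], List[Example]]]) -> float:
    """Mean query accuracy when each task is routed to its own dataset's model."""
    accuracies = []
    for mode, support, query in tasks:
        classify = adapters[mode](support)    # needs the task's true dataset label
        correct = sum(classify(x) == label for x, label in query)
        accuracies.append(correct / len(query))
    return sum(accuracies) / len(accuracies)


def lookup_adapter(support: List[Example]) -> Classifier:
    # Toy stand-in for a trained MAML model: memorize each class prefix.
    table = {x.split("_")[0]: label for x, label in support}
    return lambda x: table.get(x.split("_")[0], -1)


adapters = {"omniglot": lookup_adapter, "aircraft": lookup_adapter}
tasks = [
    ("omniglot", [("a_1", 0), ("b_1", 1)], [("a_2", 0), ("b_2", 1)]),
    ("aircraft", [("jet_1", 0), ("prop_1", 1)], [("jet_2", 0), ("prop_2", 1)]),
]
print(evaluate_multi_maml(adapters, tasks))  # prints 1.0
```

MMAML, by contrast, is a single model that must infer the mode from the support set itself.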

Results

2 Modes (Omniglot and Mini-ImageNet)

Method 5-way 1-shot 5-way 5-shot 20-way 1-shot
MAML 66.80% 77.79% 44.69%
Multi-MAML 66.85% 73.07% 53.15%
MMAML (Ours) 69.93% 78.73% 47.80%

3 Modes (Omniglot, Mini-ImageNet, and FC100)

Method 5-way 1-shot 5-way 5-shot 20-way 1-shot
MAML 54.55% 67.97% 28.22%
Multi-MAML 55.90% 62.20% 39.77%
MMAML (Ours) 57.47% 70.15% 36.27%

5 Modes (Omniglot, Mini-ImageNet, FC100, Aircraft, and CUB)

Method 5-way 1-shot 5-way 5-shot 20-way 1-shot
MAML 44.09% 54.41% 28.85%
Multi-MAML 45.46% 55.92% 33.78%
MMAML (Ours) 49.06% 60.83% 33.97%
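The tables can also be read as gains. The short sketch below computes the percentage-point difference between MMAML and each baseline using only the numbers reported above; RESULTS, SETUPS, and gaps are hypothetical names. It makes one pattern explicit: MMAML improves on MAML in every setting, while Multi-MAML remains ahead in the 20-way 1-shot setting with 2 and 3 modes.

```python
# Accuracies (%) copied from the three tables above, in the order
# (5-way 1-shot, 5-way 5-shot, 20-way 1-shot).
RESULTS = {
    "2 modes": {"MAML": (66.80, 77.79, 44.69), "Multi-MAML": (66.85, 73.07, 53.15),
                "MMAML": (69.93, 78.73, 47.80)},
    "3 modes": {"MAML": (54.55, 67.97, 28.22), "Multi-MAML": (55.90, 62.20, 39.77),
                "MMAML": (57.47, 70.15, 36.27)},
    "5 modes": {"MAML": (44.09, 54.41, 28.85), "Multi-MAML": (45.46, 55.92, 33.78),
                "MMAML": (49.06, 60.83, 33.97)},
}
SETUPS = ("5w1s", "5w5s", "20w1s")


def gaps(baseline):
    """Percentage-point difference (MMAML minus `baseline`) for every setting."""
    return {
        modes: {s: round(m - b, 2) for s, m, b in zip(SETUPS, row["MMAML"], row[baseline])}
        for modes, row in RESULTS.items()
    }


for baseline in ("MAML", "Multi-MAML"):
    print(f"MMAML vs {baseline}:", gaps(baseline))
```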

Please check out our paper for more comprehensive results.

Cite the paper

If you find this work useful, please cite:

@inproceedings{vuorio2019multimodal,
  title={Multimodal Model-Agnostic Meta-Learning via Task-Aware Modulation},
  author={Vuorio, Risto and Sun, Shao-Hua and Hu, Hexiang and Lim, Joseph J.},
  booktitle={Neural Information Processing Systems},
  year={2019},
}

Authors

Shao-Hua Sun, Risto Vuorio, Hexiang Hu
