• Stars
    star
    1,082
  • Rank 42,777 (Top 0.9 %)
  • Language
    Python
  • License
    MIT License
  • Created about 7 years ago
  • Updated almost 4 years ago

Reviews

There are no reviews yet. Be the first to send feedback to the community and the maintainers!

Repository Details

Code for the NeurIPS 2017 Paper "Prototypical Networks for Few-shot Learning"

Prototypical Networks for Few-shot Learning

Code for the NIPS 2017 paper Prototypical Networks for Few-shot Learning.

If you use this code, please cite our paper:

@inproceedings{snell2017prototypical,
  title={Prototypical Networks for Few-shot Learning},
  author={Snell, Jake and Swersky, Kevin and Zemel, Richard},
  booktitle={Advances in Neural Information Processing Systems},
  year={2017}
 }

Training a prototypical network

Install dependencies

  • This code has been tested on Ubuntu 16.04 with Python 3.6 and PyTorch 0.4.
  • Install PyTorch and torchvision.
  • Install torchnet by running pip install git+https://github.com/pytorch/tnt.git@master.
  • Install the protonets package by running python setup.py install or python setup.py develop.

Set up the Omniglot dataset

  • Run sh download_omniglot.sh.

Train the model

  • Run python scripts/train/few_shot/run_train.py. This will run training and place the results into results.
    • You can specify a different output directory by passing in the option --log.exp_dir EXP_DIR, where EXP_DIR is your desired output directory.
    • If you are running on a GPU you can pass in the option --data.cuda.
  • Re-run in trainval mode python scripts/train/few_shot/run_trainval.py. This will save your model into results/trainval by default.

Evaluate

  • Run evaluation as: python scripts/predict/few_shot/run_eval.py --model.model_path results/trainval/best_model.pt.