• Stars: 242
• Rank: 167,048 (Top 4%)
• Language: Python
• License: MIT License
• Created: about 7 years ago
• Updated: over 5 years ago



pytorch-ewc

Unofficial PyTorch implementation of DeepMind's paper Overcoming Catastrophic Forgetting, PNAS 2017.
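For reference, EWC mitigates catastrophic forgetting by penalizing changes to parameters that were important for earlier tasks. When learning a task B after a task A, the regularized objective from the paper is

\mathcal{L}(\theta) = \mathcal{L}_B(\theta) + \sum_i \frac{\lambda}{2} F_i \left(\theta_i - \theta_{A,i}^{*}\right)^2

where F_i is the i-th diagonal entry of the Fisher information matrix estimated on task A, \theta_{A,i}^{*} is the i-th parameter learned on task A, and \lambda (the --lamda flag below) controls how strongly old tasks are protected.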


Results

Continual Learning without EWC (left) and with EWC (right).

Installation

$ git clone https://github.com/kuc2477/pytorch-ewc && cd pytorch-ewc
$ pip install -r requirements.txt

CLI

The command-line interface for this implementation is provided by main.py.

Usage

$ ./main.py --help
usage: EWC PyTorch Implementation [-h] [--hidden-size HIDDEN_SIZE]
                                  [--hidden-layer-num HIDDEN_LAYER_NUM]
                                  [--hidden-dropout-prob HIDDEN_DROPOUT_PROB]
                                  [--input-dropout-prob INPUT_DROPOUT_PROB]
                                  [--task-number TASK_NUMBER]
                                  [--epochs-per-task EPOCHS_PER_TASK]
                                  [--lamda LAMDA] [--lr LR]
                                  [--weight-decay WEIGHT_DECAY]
                                  [--batch-size BATCH_SIZE]
                                  [--test-size TEST_SIZE]
                                  [--fisher-estimation-sample-size FISHER_ESTIMATION_SAMPLE_SIZE]
                                  [--random-seed RANDOM_SEED] [--no-gpus]
                                  [--eval-log-interval EVAL_LOG_INTERVAL]
                                  [--loss-log-interval LOSS_LOG_INTERVAL]
                                  [--consolidate]

optional arguments:
  -h, --help            show this help message and exit
  --hidden-size HIDDEN_SIZE
  --hidden-layer-num HIDDEN_LAYER_NUM
  --hidden-dropout-prob HIDDEN_DROPOUT_PROB
  --input-dropout-prob INPUT_DROPOUT_PROB
  --task-number TASK_NUMBER
  --epochs-per-task EPOCHS_PER_TASK
  --lamda LAMDA
  --lr LR
  --weight-decay WEIGHT_DECAY
  --batch-size BATCH_SIZE
  --test-size TEST_SIZE
  --fisher-estimation-sample-size FISHER_ESTIMATION_SAMPLE_SIZE
  --random-seed RANDOM_SEED
  --no-gpus
  --eval-log-interval EVAL_LOG_INTERVAL
  --loss-log-interval LOSS_LOG_INTERVAL
  --consolidate
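
The two flags most specific to EWC are --lamda (the penalty strength λ) and --fisher-estimation-sample-size (how many samples are used to estimate the Fisher matrix). Below is a minimal sketch of how diagonal Fisher estimation and the consolidation penalty are typically implemented in PyTorch; the function names and signatures here are illustrative assumptions, not the repo's exact model.estimate_fisher() API.

import torch
import torch.nn.functional as F

def estimate_diag_fisher(model, data_loader, sample_size):
    """Empirical diagonal Fisher: E[(d log p(y|x) / d theta)^2],
    averaged over samples only (one value per parameter)."""
    fisher = {n: torch.zeros_like(p) for n, p in model.named_parameters()}
    seen = 0
    for x, y in data_loader:
        for xi, yi in zip(x, y):
            if seen >= sample_size:
                break
            model.zero_grad()
            # Log-likelihood of the observed label under the model.
            log_prob = F.log_softmax(model(xi.unsqueeze(0)), dim=1)[0, yi]
            log_prob.backward()
            for n, p in model.named_parameters():
                if p.grad is not None:
                    fisher[n] += p.grad.detach() ** 2
            seen += 1
        if seen >= sample_size:
            break
    return {n: f / max(seen, 1) for n, f in fisher.items()}

def ewc_penalty(model, fisher, old_params, lamda):
    """Quadratic penalty pulling parameters toward the values learned
    on the previous task, weighted per-parameter by the Fisher."""
    loss = 0.0
    for n, p in model.named_parameters():
        loss = loss + (fisher[n] * (p - old_params[n]) ** 2).sum()
    return (lamda / 2.0) * loss

# Typical use after finishing a task (hypothetical training loop):
#   fisher = estimate_diag_fisher(model, old_task_loader, sample_size=1024)
#   old_params = {n: p.detach().clone() for n, p in model.named_parameters()}
#   total_loss = new_task_loss + ewc_penalty(model, fisher, old_params, lamda=40.0)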

Train

$ python -m visdom.server &
$ ./main.py               # Train the network without consolidation.
$ ./main.py --consolidate # Train the network with consolidation.
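
The consolidation strength and task schedule can be tuned with the flags listed above, for example (values here are illustrative, not recommended defaults):

$ ./main.py --consolidate --lamda 40 --task-number 5 --epochs-per-task 3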

Update Logs

  • 2019.06.29
    • Fixed a critical bug in model.estimate_fisher(): squared gradients of the log-likelihood w.r.t. each layer were mean-reduced over all dimensions. It now estimates the Fisher matrix correctly by averaging only over the batch dimension (both this and the fix below are illustrated in the sketch after this list).
  • 2019.03.22
    • Fixed a critical bug in model.estimate_fisher(): the Fisher matrix was being estimated as the squared expectation of the gradient of the log-likelihood. It now uses the expectation of the squared gradient of the log-likelihood.
    • Changed the default optimizer from Adam to SGD.
    • Migrated the project to PyTorch 1.0.1 and visdom 0.1.8.8.
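
To make the two fixes concrete, here is a schematic contrast in PyTorch; g and its shape are illustrative stand-ins for per-sample gradients of the log-likelihood, not the repo's actual variables:

import torch

# g: per-sample gradients of the log-likelihood for one layer,
# shape [batch, num_params] (illustrative values).
g = torch.randn(32, 128)

# 2019.03.22 bug: squared expectation of the gradient, (E[g])^2.
fisher_wrong = g.mean(dim=0) ** 2

# 2019.06.29 bug: E[g^2] mean-reduced over every dimension,
# collapsing the estimate to a single scalar per layer.
fisher_scalar = (g ** 2).mean()

# Correct: expectation of the squared gradient, averaged over the
# batch dimension only, giving one Fisher value per parameter.
fisher = (g ** 2).mean(dim=0)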

Reference

  • Kirkpatrick et al., "Overcoming catastrophic forgetting in neural networks", PNAS 2017 (arXiv:1612.00796)

Author

Ha Junsoo / @kuc2477 / MIT License

More Repositories

1. pytorch-deep-generative-replay: PyTorch implementation of "Continual Learning with Deep Generative Replay", NIPS 2017 (Python, 149 stars)
2. pytorch-vae: PyTorch implementation of "Auto-Encoding Variational Bayes", arxiv:1312.6114 (Python, 46 stars)
3. pytorch-wgan-gp: PyTorch implementation of "Improved Training of Wasserstein GANs", arxiv:1704.00028 (Python, 27 stars)
4. pytorch-splitnet: PyTorch implementation of ICML 2017 paper "SplitNet: Learning to Semantically Split Deep Networks for Parameter Reduction and Model Parallelization" (Python, 17 stars)
5. dl-papers: 📝 Deep Learning papers that enlightened me (Shell, 12 stars)
6. pytorch-memn2n: PyTorch implementation of FAIR's paper "End-To-End Memory Networks", NIPS 2015 (Python, 11 stars)
7. la-dynamics-in-games (Jupyter Notebook, 8 stars)
8. backbone.csrf: Configures the X-CSRFToken header for all Backbone sync requests (JavaScript, 6 stars)
9. pytorch-ntm: PyTorch implementation of DeepMind's paper "Neural Turing Machines", arxiv:1410.5401 (Python, 4 stars)
10. dotfiles: 💻 My dotfiles for UNIX-like systems (Shell, 4 stars)
11. tensorflow-infogan: TensorFlow implementation of "InfoGAN: Interpretable Representation Learning by Information Maximizing Generative Adversarial Nets", NIPS 2016 (Python, 4 stars)
12. oracle: Delivers an oracle on whether your boss will come to work today (Python, 3 stars)
13. zarathustra-quotes: Inspiring quotes from Nietzsche's notorious "Thus Spoke Zarathustra" (2 stars)
14. anchor-frontend-pc: Anchor PC client (JavaScript, 2 stars)
15. pytorch-wrn: PyTorch implementation of "Wide Residual Networks", BMVC 2016 (Python, 2 stars)
16. django-record: Creates a snapshot record for an instance when it has been changed, either directly or indirectly (Python, 2 stars)
17. news: 📰 Asynchronous web subscription engine (Python, 2 stars)
18. dom-ged-genetic-approximator: DOM tree graph edit distance approximator implemented with a genetic algorithm (Python, 1 star)
19. naver-unse: Haskell wrapper for the unofficial Naver daily fortune-telling API (Haskell, 1 star)
20. vim-guerilla: Minimal, lightweight, cross-platform Vim configuration for guerilla devs (Vim Script, 1 star)