  • Stars: 759
  • Rank: 59,846 (Top 2%)
  • Language: Python
  • License: MIT License
  • Created over 8 years ago
  • Updated 3 months ago


Repository Details

Code and models accompanying "Deep Predictive Coding Networks for Video Prediction and Unsupervised Learning"

prednet

Code and models accompanying "Deep Predictive Coding Networks for Video Prediction and Unsupervised Learning" by Bill Lotter, Gabriel Kreiman, and David Cox.

The PredNet is a deep recurrent convolutional neural network that is inspired by the neuroscience concept of predictive coding (Rao and Ballard, 1999; Friston, 2005). Check out example prediction videos here.

The architecture is implemented in Keras as a custom layer (see note 1 below). The code and model data are compatible with Keras 2.0 and with Python 2.7 and 3.6. The latest version has been tested on Keras 2.2.4 with TensorFlow 1.6. For previous versions of the code compatible with Keras 1.2.1, use commit fbcdc18. To convert old PredNet model files and weights for Keras 2.0 compatibility, see convert_model_to_keras2 in keras_utils.py.

KITTI Demo

Code is included for training the PredNet on the raw KITTI dataset, covering downloading and processing the data as well as training and evaluating the model. The preprocessed data can also be downloaded directly using download_data.sh, and the trained weights by running download_models.sh. The model download includes the original weights trained for t+1 prediction, the fine-tuned weights trained to extrapolate predictions over multiple timesteps, and the "Lall" weights trained with a 0.1 loss weight on the upper layers (see the paper for details).
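To make the role of per-layer loss weights concrete, the sketch below combines per-layer, per-timestep error values into a single training loss. It is a hypothetical, framework-free helper, not code from this repository: the function name, the four-layer weight vectors, and the choice to give the first timestep zero weight (no real prediction exists before the first frame has been seen) are assumptions for illustration. Check kitti_train.py for the weighting actually used.

```python
from typing import Sequence


def weighted_prednet_loss(
    errors: Sequence[Sequence[float]],
    layer_weights: Sequence[float],
) -> float:
    """Combine errors[t][l] (mean error of layer l at timestep t) into one scalar.

    Hypothetical helper: layers are weighted by `layer_weights`; timesteps
    1..nt-1 are averaged with equal weight and timestep 0 is ignored.
    """
    nt = len(errors)
    if nt < 2:
        raise ValueError("need at least two timesteps")
    time_weight = 1.0 / (nt - 1)
    total = 0.0
    for t, layer_errors in enumerate(errors):
        if t == 0:
            continue  # assumed: no meaningful prediction before the first frame
        if len(layer_errors) != len(layer_weights):
            raise ValueError("one error value per layer is required")
        per_step = sum(w * e for w, e in zip(layer_weights, layer_errors))
        total += time_weight * per_step
    return total


# Assumed four-layer weightings (illustrative values):
LOWEST_LAYER_ONLY = [1.0, 0.0, 0.0, 0.0]  # only the pixel-level error counts
LALL_WEIGHTS = [1.0, 0.1, 0.1, 0.1]       # upper layers weighted by 0.1
```

With `LALL_WEIGHTS`, upper-layer errors still contribute to the loss, whereas the lowest-layer-only weighting ignores them entirely.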

Steps

  1. Download/process data

    python process_kitti.py

    This will scrape the KITTI website to download the raw data from the city, residential, and road categories (~165 GB) and then process the images (cropping, downsampling). Alternatively, the processed data (~3 GB) can be downloaded directly by executing download_data.sh.

  2. Train model

    python kitti_train.py

    This will train a PredNet model for t+1 prediction. See the Keras FAQ on how to run using a GPU. To download pre-trained weights instead, run download_models.sh.

  3. Evaluate model

    python kitti_evaluate.py

    This will output the mean-squared error of the predictions and generate plots comparing predictions to the ground truth.
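The evaluation metric itself needs no framework. The hypothetical helpers below work on plain nested lists (one flattened list of pixel values per frame) rather than the arrays kitti_evaluate.py uses. They compute the model's MSE while skipping the first frame, for which no real prediction exists, and, as an assumed point of comparison, the MSE of a naive baseline that copies the previous ground-truth frame.

```python
from typing import List, Tuple

Frame = List[float]  # one frame, flattened to a list of pixel values


def mse(a: List[Frame], b: List[Frame]) -> float:
    """Mean squared error over all pixels of two equally shaped frame lists."""
    if len(a) != len(b):
        raise ValueError("frame lists differ in length")
    total = 0.0
    count = 0
    for fa, fb in zip(a, b):
        if len(fa) != len(fb):
            raise ValueError("frames differ in size")
        total += sum((x - y) ** 2 for x, y in zip(fa, fb))
        count += len(fa)
    return total / count


def evaluate_sequence(truth: List[Frame], predicted: List[Frame]) -> Tuple[float, float]:
    """Return (model MSE, copy-previous-frame baseline MSE) for one sequence.

    Frame 0 is excluded from the model score because the prediction at
    t=0 is made before any input has been seen.
    """
    model_mse = mse(truth[1:], predicted[1:])
    # Assumed baseline: predict frame t by copying ground-truth frame t-1.
    prev_frame_mse = mse(truth[1:], truth[:-1])
    return model_mse, prev_frame_mse
```

A useful model should score well below the copy-last-frame baseline; matching it would suggest the network has learned little beyond repeating its input.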

Feature Extraction

Extracting the intermediate features for a given layer in the PredNet can be done using the appropriate output_mode argument. For example, to extract the hidden state of the LSTM (the "Representation" units) in the lowest layer, use output_mode = 'R0'. More details can be found in the PredNet docstring.
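Layer-specific output modes follow a unit-type-plus-layer-index pattern, as in 'R0'. The hypothetical helper below splits such a string into its parts and checks the layer index. Only the 'R' prefix and the 'prediction' mode are confirmed by this README; the other prefixes ('E', 'A', 'Ahat') and the 'error' and 'all' modes are assumptions based on the unit names used in the paper, so treat the PredNet docstring as the authoritative list.

```python
from typing import Optional, Tuple

# Assumed unit prefixes; 'Ahat' is checked before 'A' so the longer prefix wins.
UNIT_TYPES = ("Ahat", "R", "E", "A")
# 'prediction' appears in this README; 'error' and 'all' are assumptions.
GLOBAL_MODES = ("prediction", "error", "all")


def parse_output_mode(mode: str, nb_layers: int) -> Tuple[str, Optional[int]]:
    """Split e.g. 'R0' into ('R', 0); global modes return (mode, None)."""
    if mode in GLOBAL_MODES:
        return mode, None
    for unit in UNIT_TYPES:
        if mode.startswith(unit):
            suffix = mode[len(unit):]
            # The layer index must be a whole number below nb_layers.
            if suffix.isdecimal() and int(suffix) < nb_layers:
                return unit, int(suffix)
            break
    raise ValueError(f"unrecognized output_mode {mode!r} for {nb_layers} layers")
```

For a four-layer model, 'R0' selects the lowest representation layer, while a string such as 'R4' is rejected because layers are indexed from 0.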

Multi-Step Prediction

The PredNet argument extrap_start_time can be used to force multi-step prediction. Starting at this time step, the prediction from the previous time step will be treated as the actual input. For example, if the model is run on a sequence of 15 timesteps with extrap_start_time = 10, the last output will correspond to a t+5 prediction. In the paper, we train in this setting starting from the original t+1 trained weights (see kitti_extrap_finetune.py), and the resulting fine-tuned weights are included in download_models.sh. Note that when training with extrapolation, the "errors" are no longer tied to ground truth, so the loss should be calculated on the pixel predictions themselves. This can be done by using output_mode = 'prediction', as illustrated in kitti_extrap_finetune.py.
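The feedback behind extrap_start_time can be illustrated with a toy, framework-free rollout. Here each "frame" is a single number, and `linear_extrapolator` is a hypothetical stand-in for the network; neither is part of the PredNet code. Before extrap_start_time the loop feeds ground-truth frames; from then on it feeds the model's own previous prediction. `prediction_horizons` counts how far beyond the last real frame each output looks, reproducing the t+5 example above.

```python
from typing import Callable, List, Optional, Sequence, Tuple

Predictor = Callable[[List[float]], float]


def rollout(frames: Sequence[float], predict: Predictor,
            extrap_start_time: Optional[int] = None) -> Tuple[list, list]:
    """Return (inputs actually fed, predictions) for a toy sequence model."""
    if extrap_start_time is not None and extrap_start_time < 1:
        raise ValueError("extrap_start_time must be at least 1")
    inputs: list = []
    preds: list = []
    for t, frame in enumerate(frames):
        # The prediction for frame t is made before frame t is fed in.
        pred = predict(inputs) if inputs else None
        preds.append(pred)
        if extrap_start_time is not None and t >= extrap_start_time:
            inputs.append(pred)   # the model sees its own previous prediction
        else:
            inputs.append(frame)  # the model sees the real frame
    return inputs, preds


def prediction_horizons(nt: int, extrap_start_time: Optional[int] = None) -> list:
    """How many frames past the last real input each output step predicts."""
    horizons: list = [None]  # step 0 has no preceding input
    for t in range(1, nt):
        if extrap_start_time is None or t < extrap_start_time:
            last_real = t - 1
        else:
            last_real = extrap_start_time - 1
        horizons.append(t - last_real)
    return horizons


def linear_extrapolator(history: List[float]) -> float:
    """Hypothetical stand-in model: continue the most recent trend."""
    if len(history) < 2:
        return history[-1]
    return history[-1] + (history[-1] - history[-2])
```

With 15 timesteps and extrap_start_time = 10, the inputs from step 10 onward are the model's own predictions, so there is no ground truth inside the loop to compute "errors" against; the loss has to compare the predictions with the held-out frames instead.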

Additional Notes

When training on a new dataset, the image height and width must each be divisible by 2^(number of layers - 1) because of the repeated 2x2 max-pooling and upsampling operations.
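A quick sanity check for a new dataset could look like the sketch below. The function names are hypothetical, and cropping down to the nearest valid size is only one option (resizing or padding would also work). Each of the (number of layers - 1) pooling steps halves the resolution, so a four-layer model, for instance, needs both dimensions to be multiples of 8.

```python
from typing import Tuple


def required_multiple(nb_layers: int) -> int:
    """Each of the nb_layers - 1 pooling steps halves the resolution."""
    if nb_layers < 1:
        raise ValueError("nb_layers must be at least 1")
    return 2 ** (nb_layers - 1)


def is_valid_size(height: int, width: int, nb_layers: int) -> bool:
    """True if both dimensions survive every 2x2 pooling without remainder."""
    m = required_multiple(nb_layers)
    return height % m == 0 and width % m == 0


def largest_valid_crop(height: int, width: int, nb_layers: int) -> Tuple[int, int]:
    """Largest (height, width) no bigger than the input that meets the constraint.

    Note: returns 0 for a dimension smaller than the required multiple.
    """
    m = required_multiple(nb_layers)
    return (height // m) * m, (width // m) * m
```

For example, a 130x165 image would be cropped to 128x160 for a four-layer model.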


1 Note on implementation: PredNet inherits from the Recurrent layer class, i.e. it has an internal state and a step function. Given the top-down then bottom-up update sequence, it must currently be implemented in Keras as essentially a 'super' layer where all layers in the PredNet are in one PredNet 'layer'. This is less than ideal, but it seems like the most efficient way as of now. We welcome suggestions if anyone thinks of a better implementation.
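To see why the layers are bundled into one recurrent layer, the toy scalar step below follows the top-down then bottom-up order within a single time step: each representation R_l is updated from the already-updated R_{l+1}, and each layer's input A_{l+1} is computed from the error E_l produced in the same step. This is a hypothetical, framework-free illustration, not the repository's step function: scalars stand in for feature maps, and tanh/ReLU stand in for the ConvLSTM, convolution, pooling, and upsampling of the actual layer.

```python
import math
from typing import List, Tuple


def relu(x: float) -> float:
    return max(0.0, x)


def toy_prednet_step(frame: float, R: List[float],
                     E: List[float]) -> Tuple[List[float], List[float], float, List[str]]:
    """One time step of a scalar PredNet-like model with len(R) layers.

    R[l] and E[l] are the states from the previous time step. Returns the new
    states, the prediction for this frame, and the order of operations.
    """
    nb_layers = len(R)
    order: List[str] = []
    new_R = list(R)
    # Top-down pass: the highest layer is updated first, so each lower layer
    # uses the representation above it from *this* time step.
    for l in reversed(range(nb_layers)):
        top_down = new_R[l + 1] if l + 1 < nb_layers else 0.0
        new_R[l] = math.tanh(R[l] + E[l] + top_down)
        order.append(f"R{l}")
    # Bottom-up pass: predictions and errors flow upward from the input frame.
    new_E = [0.0] * nb_layers
    A = frame
    prediction = 0.0
    for l in range(nb_layers):
        A_hat = relu(new_R[l])
        if l == 0:
            prediction = A_hat
        # The paper keeps positive and negative error parts separate;
        # here they are summed into one scalar for simplicity.
        new_E[l] = relu(A - A_hat) + relu(A_hat - A)
        order.append(f"E{l}")
        if l + 1 < nb_layers:
            A = relu(new_E[l])  # stand-in for conv + pooling of E_l
            order.append(f"A{l + 1}")
    return new_R, new_E, prediction, order
```

For three layers the operation order is R2, R1, R0, then E0, A1, E1, A2, E2: a full downward sweep followed by a full upward sweep inside each step, which a plain stack of independent Keras recurrent layers cannot express.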
