• Stars: 241
• Rank: 167,643 (Top 4%)
• Language: Python
• License: MIT License
• Created: almost 6 years ago
• Updated: about 1 year ago

Repository Details

Code for "Unsupervised State Representation Learning in Atari"

Unsupervised State Representation Learning in Atari

Ankesh Anand*, Evan Racah*, Sherjil Ozair*, Yoshua Bengio, Marc-Alexandre Côté, R Devon Hjelm

This repo provides code for the benchmark and techniques introduced in the paper Unsupervised State Representation Learning in Atari.

Install

AtariARI Wrapper

You can do a minimal install to get just the AtariARI (Atari Annotated RAM Interface) wrapper by doing:

pip install 'gym[atari]'
pip install git+https://github.com/mila-iqia/atari-representation-learning.git

This minimal install requires only gym[atari] and lets you play around with the AtariARI wrapper. If you want to use the code for training representation learning methods and probing them, you will need the full installation:

Full installation (AtariARI Wrapper + Training & Probing Code)

# PyTorch and scikit-learn
conda install pytorch torchvision -c pytorch
conda install scikit-learn

# Baselines for Atari preprocessing
# TensorFlow is a dependency, but you don't need to install the GPU version
conda install tensorflow
pip install git+https://github.com/openai/baselines

# pytorch-a2c-ppo-acktr for RL utils
pip install git+https://github.com/ankeshanand/pytorch-a2c-ppo-acktr-gail

# Clone and install our package
git clone https://github.com/mila-iqia/atari-representation-learning.git
cd atari-representation-learning
pip install -r requirements.txt
pip install git+https://github.com/mila-iqia/atari-representation-learning.git

Usage

Atari Annotated RAM Interface (AtariARI):

AtariARI exposes ground-truth labels for different state variables for each observation. We have made AtariARI available as a Gym wrapper; to use it, simply wrap an Atari gym env with AtariARIWrapper.

import gym
from atariari.benchmark.wrapper import AtariARIWrapper
env = AtariARIWrapper(gym.make('MsPacmanNoFrameskip-v4'))
obs = env.reset()
obs, reward, done, info = env.step(1)

Now, info is a dictionary of the form:

{'ale.lives': 3,
 'labels': {'enemy_sue_x': 88,
  'enemy_inky_x': 88,
  'enemy_pinky_x': 88,
  'enemy_blinky_x': 88,
  'enemy_sue_y': 80,
  'enemy_inky_y': 80,
  'enemy_pinky_y': 80,
  'enemy_blinky_y': 50,
  'player_x': 88,
  'player_y': 98,
  'fruit_x': 0,
  'fruit_y': 0,
  'ghosts_count': 3,
  'player_direction': 3,
  'dots_eaten_count': 0,
  'player_score': 0,
  'num_lives': 2}}
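Since the labels arrive through info at every step, you can collect them alongside a rollout. Here is a minimal sketch (not part of the repo) that gathers the label dictionaries from a short random-agent episode:

import gym
from atariari.benchmark.wrapper import AtariARIWrapper

env = AtariARIWrapper(gym.make('MsPacmanNoFrameskip-v4'))
env.reset()

label_history = []
done = False
while not done:
    # random action; the labels describing the resulting frame come back in `info`
    obs, reward, done, info = env.step(env.action_space.sample())
    label_history.append(info["labels"])

print(len(label_history), "annotated frames")
print(label_history[0]["player_x"], label_history[0]["player_y"])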

Note: In our experiments, we use additional preprocessing for Atari environments, mainly following Mnih et al. (2015). See atariari/benchmark/envs.py for more info!

If you want the raw RAM annotations (which parts of ram correspond to each state variable), check out atariari/benchmark/ram_annotations.py
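For a quick look at the annotations themselves, you can load that module directly. Below is a small sketch, assuming it exposes a per-game dictionary (called atari_dict here) mapping label names to RAM byte indices; the exact names and layout are defined in ram_annotations.py:

# Assumed layout: {game_name: {label_name: ram_byte_index}}, with lowercase game keys.
from atariari.benchmark.ram_annotations import atari_dict

print(sorted(atari_dict.keys()))            # games covered by the benchmark
print(atari_dict["mspacman"]["player_x"])   # RAM byte annotated as Ms. Pac-Man's x position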

Probing


⚠️ Important ⚠️: The RAM labels are meant for full-sized Atari observations (210 × 160). Probing results won't be accurate if you downsample the observations.

We provide an interface for the included probing tasks.

First, get episodes for train, val, and test:

from atariari.benchmark.episodes import get_episodes

tr_episodes, val_episodes,\
tr_labels, val_labels,\
test_episodes, test_labels = get_episodes(env_name="PitfallNoFrameskip-v4", 
                                     steps=50000, 
                                     collect_mode="random_agent")

Then probe them using ProbeTrainer and your encoder (my_encoder):

from atariari.benchmark.probe import ProbeTrainer

probe_trainer = ProbeTrainer(my_encoder, representation_len=my_encoder.feature_size)
probe_trainer.train(tr_episodes, val_episodes,
                     tr_labels, val_labels,)
final_accuracies, final_f1_scores = probe_trainer.test(test_episodes, test_labels)

To see how we use ProbeTrainer, check out scripts/run_probe.py

Here is an example of my_encoder:

# get your encoder
import torch.nn as nn
import torch
class MyEncoder(nn.Module):
    def __init__(self, input_channels, feature_size):
        super().__init__()
        self.feature_size = feature_size
        self.input_channels = input_channels
        self.final_conv_size = 64 * 9 * 6
        self.cnn = nn.Sequential(
            nn.Conv2d(input_channels, 32, 8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 128, 4, stride=2),
            nn.ReLU(),
            nn.Conv2d(128, 64, 3, stride=1),
            nn.ReLU()
        )
        self.fc = nn.Linear(self.final_conv_size, self.feature_size)

    def forward(self, inputs):
        x = self.cnn(inputs)
        x = x.view(x.size(0), -1)
        return self.fc(x)
        

my_encoder = MyEncoder(input_channels=1, feature_size=256)
# load in weights
my_encoder.load_state_dict(torch.load(open("path/to/my/weights.pt", "rb")))

Spatio-Temporal DeepInfoMax:

src/ contains implementations of several representation learning methods, along with ST-DIM. Here's a sample usage:

python -m scripts.run_probe --method infonce-stdim --env-name {env_name}

where env_name is of the form {game}NoFrameskip-v4, such as PongNoFrameskip-v4.
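For example, to train and probe ST-DIM representations on Pong:

python -m scripts.run_probe --method infonce-stdim --env-name PongNoFrameskip-v4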

Citation

@article{anand2019unsupervised,
  title={Unsupervised State Representation Learning in Atari},
  author={Anand, Ankesh and Racah, Evan and Ozair, Sherjil and Bengio, Yoshua and C{\^o}t{\'e}, Marc-Alexandre and Hjelm, R Devon},
  journal={arXiv preprint arXiv:1906.08226},
  year={2019}
}

More Repositories

1. blocks - A Theano framework for building and training neural networks (Python, 1,155 stars)
2. welcome_tutorials - Various tutorials given for welcoming new students at MILA. (Jupyter Notebook, 985 stars)
3. fuel - A data pipeline framework for machine learning (Python, 864 stars)
4. babyai - BabyAI platform. A testbed for training agents to understand and execute language commands. (Python, 681 stars)
5. myia - Myia prototyping (Python, 455 stars)
6. summerschool2015 - Slides and exercises for the Deep Learning Summer School 2015 programming tutorials (Jupyter Notebook, 391 stars)
7. platoon - Multi-GPU mini-framework for Theano (Python, 195 stars)
8. spr - Code for "Data-Efficient Reinforcement Learning with Self-Predictive Representations" (Python, 155 stars)
9. blocks-examples - Examples and scripts using Blocks (Python, 147 stars)
10. summerschool2016 - Montréal Deep Learning Summer School 2016 material (Jupyter Notebook, 100 stars)
11. paperoni - Search for scientific papers on the command line (Python, 97 stars)
12. summerschool2017 - Material for the Montréal Deep Learning Summer School 2017 (Jupyter Notebook, 78 stars)
13. gene-graph-conv - Towards Gene Expression Convolutions using Gene Interaction Graphs (Jupyter Notebook, 74 stars)
14. milatools - Tools to connect to and interact with the Mila cluster (Python, 60 stars)
15. Conscious-Planning - Implementation for paper "A Consciousness-Inspired Planning Agent for Model-Based Reinforcement Learning". (Python, 58 stars)
16. SGI - Official code for "Pretraining Representations For Data-Efficient Reinforcement Learning" (NeurIPS 2021) (Python, 51 stars)
17. ddxplus (Python, 48 stars)
18. picklable-itertools - itertools. But picklable. (Python, 38 stars)
19. climate-cooperation-competition - AI for Global Climate Cooperation: Modeling Global Climate Negotiations, Agreements, and Long-Term Cooperation in RICE-N. ai4climatecoop.org (Python, 35 stars)
20. ivado-mila-dl-school-2019 - IVADO/Mila's Summer Deep Learning School (Jupyter Notebook, 35 stars)
21. ivado-mila-dl-school-2021 (Jupyter Notebook, 33 stars)
22. blocks-extras - A collection of extensions to the Blocks framework (Python, 27 stars)
23. DeepDrummer - Making the world a better place through AI-generated beats & grooves (Python, 26 stars)
24. covid_p2p_risk_prediction - COVID19 P2P Risk Prediction Model & Dataset (Python, 22 stars)
25. COVI-AgentSim - Covid-19 spread simulator with human mobility and intervention modeling. (Jupyter Notebook, 20 stars)
26. Skipper - A PyTorch Implementation of Skipper (Python, 20 stars)
27. cookiecutter-pyml (Python, 19 stars)
28. milabench - Repository of machine learning benchmarks (Python, 17 stars)
29. snektalk (Python, 15 stars)
30. dlschool-ivadofr-a18 - Ecole Mila/IVADO (Jupyter Notebook, 12 stars)
31. COVI-ML - Risk model training code for Covid-19 tracing application. (Python, 12 stars)
32. teamgrid - Multiagent gridworld for the TEAM project based on gym-minigrid (Python, 12 stars)
33. ivado-mila-dl-school-2019-vancouver (Jupyter Notebook, 11 stars)
34. mila-paper-webpage - Webpage template for MILA-affiliated papers (CSS, 11 stars)
35. dlschool-ivadofr-h18 - IVADO/MILA winter school in deep learning, 2018 (Jupyter Notebook, 10 stars)
36. giving - Reactive logging (Python, 9 stars)
37. training (Python, 8 stars)
38. mila-docs - Mila technical documentation (8 stars)
39. Casande-RL (Python, 8 stars)
40. hardpicks - Deep learning dataset and benchmark for first-break detection from hardrock seismic reflection data (Python, 7 stars)
41. ptera - Query and override internal variables in your programs (Python, 5 stars)
42. ResearchTemplate - WIP: Research Template Repository (Python, 5 stars)
43. mila_datamodules - Efficient Datamodules Customized for the Mila / CC clusters (Python, 4 stars)
44. digit-detection - IFT6759 - Advanced projects in machine learning (Door Number Detection project) (Shell, 4 stars)
45. Humanitarian_R-D (Jupyter Notebook, 3 stars)
46. ansible-role-clockwork - Ansible role to install and configure clockwork (Jinja, 3 stars)
47. SARC - FD#11499 (Python, 3 stars)
48. ansible-role-cobbler - Install and configure Cobbler service (Jinja, 3 stars)
49. slurm-queue-time-pred - Slurm wait time prediction (Python, 3 stars)
50. diffusion_for_multi_scale_molecular_dynamics (Python, 3 stars)
51. cableinspect-ad-code - Code to prepare data and reproduce results from CableInspect-AD paper (Python, 3 stars)
52. clockwork - Simple metrics to monitor slurm and produce reports. (Python, 2 stars)
53. ansible-role-infiniband - Ansible role to configure InfiniBand interfaces (Jinja, 2 stars)
54. tensorflow_dataloader (1 star)
55. bcachefs - C implementation with Python 3.7 bindings of the BCacheFS (C, 1 star)
56. ansible-collection-proxmox - Ansible Collection to manage containers and virtual machines with Proxmox VE (1 star)
57. mila-docs-chatbot (Python, 1 star)