  • Stars: 494
  • Rank: 89,130 (Top 2%)
  • Language: Python
  • License: Apache License 2.0
  • Created over 6 years ago
  • Updated over 3 years ago


Repository Details

Implementation for the paper "Compositional Attention Networks for Machine Reasoning" (Hudson and Manning, ICLR 2018)

Compositional Attention Networks for Real-World Reasoning

Drew A. Hudson & Christopher D. Manning

Please note: We have updated the GQA challenge deadline to be May 15. Best of Luck! :)

This is the implementation of Compositional Attention Networks for Machine Reasoning (ICLR 2018) on two visual reasoning datasets: the CLEVR dataset and the new GQA dataset (CVPR 2019). We propose a fully differentiable model that learns to perform multi-step reasoning. See our website and blog post for more information about the model!

In particular, the implementation includes the MAC cell at mac_cell.py. The code supports the standard cell as presented in the paper as well as additional extensions and variants. Run python main.py -h or see config.py for the complete list of options.

The adaptation of MAC as well as several baselines for the GQA dataset are located at the GQA branch.

Bibtex

For MAC:

@inproceedings{hudson2018compositional,
  title={Compositional Attention Networks for Machine Reasoning},
  author={Hudson, Drew A and Manning, Christopher D},
  booktitle={International Conference on Learning Representations (ICLR)},
  year={2018}
}

For the GQA dataset:

@inproceedings{hudson2018gqa,
  title={GQA: A New Dataset for Real-World Visual Reasoning and Compositional Question Answering},
  author={Hudson, Drew A and Manning, Christopher D},
  booktitle={Conference on Computer Vision and Pattern Recognition (CVPR)},
  year={2019}
}

Requirements

  • TensorFlow (originally developed with version 1.3, but should work with later versions as well).
  • Experiments were performed on a Maxwell Titan X GPU. We assume 12GB of GPU memory.
  • See requirements.txt for the required Python packages and run pip install -r requirements.txt to install them.

Pre-processing

Before training the model, we first have to download the CLEVR dataset and extract features for the images:

Dataset

To download and unpack the data, run the following commands:

wget https://dl.fbaipublicfiles.com/clevr/CLEVR_v1.0.zip
unzip CLEVR_v1.0.zip
mv CLEVR_v1.0 CLEVR_v1
mkdir CLEVR_v1/data
mv CLEVR_v1/questions/* CLEVR_v1/data/

The final command moves the dataset questions into the data directory, where we will put all the data files we use during training.
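As a quick sanity check after the commands above, the resulting directory layout can be verified with a few lines of Python. This is a minimal sketch: the question file names are assumed from the standard CLEVR v1.0 archive and the helper name `missing_clevr_paths` is hypothetical, so adjust both if your copy differs.

```python
from pathlib import Path

# Paths assumed from the commands above plus the standard CLEVR v1.0 archive
# layout (the question file names are an assumption, not taken from this README).
EXPECTED_PATHS = [
    "images/train",
    "images/val",
    "images/test",
    "data/CLEVR_train_questions.json",
    "data/CLEVR_val_questions.json",
    "data/CLEVR_test_questions.json",
]


def missing_clevr_paths(root="CLEVR_v1"):
    """Return the expected relative paths that do not exist under `root`."""
    root = Path(root)
    return [rel for rel in EXPECTED_PATHS if not (root / rel).exists()]
```

Calling `missing_clevr_paths()` from the directory that contains CLEVR_v1 returns an empty list when everything is in place, and otherwise the list of paths that still need attention.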

Feature extraction

Extract ResNet-101 features for the CLEVR train, val, and test images with the following commands:

python extract_features.py --input_image_dir CLEVR_v1/images/train --output_h5_file CLEVR_v1/data/train.h5 --batch_size 32
python extract_features.py --input_image_dir CLEVR_v1/images/val --output_h5_file CLEVR_v1/data/val.h5 --batch_size 32
python extract_features.py --input_image_dir CLEVR_v1/images/test --output_h5_file CLEVR_v1/data/test.h5 --batch_size 32
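The three commands above differ only in the split name, so they can also be generated in a loop. The sketch below uses only the flags shown above; the helper names `feature_extraction_commands` and `run_all` are hypothetical and not part of the repository.

```python
import subprocess


def feature_extraction_commands(root="CLEVR_v1", splits=("train", "val", "test"), batch_size=32):
    """Build one extract_features.py command (as an argument list) per split."""
    return [
        [
            "python", "extract_features.py",
            "--input_image_dir", f"{root}/images/{split}",
            "--output_h5_file", f"{root}/data/{split}.h5",
            "--batch_size", str(batch_size),
        ]
        for split in splits
    ]


def run_all(commands, dry_run=True):
    """Print each command; execute it only when dry_run is False."""
    for cmd in commands:
        print(" ".join(cmd))
        if not dry_run:
            # check=True stops at the first failing split.
            subprocess.run(cmd, check=True)
```

`run_all(feature_extraction_commands())` only prints the commands; pass `dry_run=False` to actually launch the extraction.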

Training

To train the model, run the following command:

python main.py --expName "clevrExperiment" --train --testedNum 10000 --epochs 25 --netLength 4 @configs/args.txt

First, the program preprocesses the CLEVR questions: it tokenizes them and maps the tokens to integers to prepare them for the network. It then stores a JSON file with that information, as well as the word-to-integer dictionaries, in the ./CLEVR_v1/data directory.
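To illustrate what tokenizing questions and mapping them to integers means, here is a minimal sketch of the general technique. It is not the repository's preprocessing code: the tokenization rule, the special tokens `<PAD>` and `<UNK>`, and all function names are assumptions made for illustration.

```python
import re


def tokenize(question):
    """Lowercase a question and split it into word and punctuation tokens."""
    return re.findall(r"[a-z0-9]+|[^\sa-z0-9]", question.lower())


def build_vocab(questions, specials=("<PAD>", "<UNK>")):
    """Assign each distinct token an integer id, in order of first appearance."""
    vocab = {tok: i for i, tok in enumerate(specials)}
    for q in questions:
        for tok in tokenize(q):
            # setdefault only inserts tokens that have not been seen yet.
            vocab.setdefault(tok, len(vocab))
    return vocab


def encode(question, vocab):
    """Map a question to a list of integer ids, using <UNK> for unseen tokens."""
    unk = vocab["<UNK>"]
    return [vocab.get(tok, unk) for tok in tokenize(question)]


vocab = build_vocab(["How many cubes are there?"])
print(encode("How many spheres?", vocab))  # [2, 3, 1, 7]
```

The word "spheres" does not appear in the vocabulary built from the first question, so it is mapped to the id of `<UNK>`.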

Then, the program trains the model. Weights are saved by default to ./weights/{expName} and statistics about the training are collected in ./results/{expName}, where expName is the name we choose to give to the current experiment.

Notes

  • The number of examples used for training and evaluation can be set by --trainedNum and --testedNum respectively.
  • You can use the -r flag to restore a previously trained model and continue training it.
  • We recommend varying the number of MAC cells used in the network through the --netLength option to explore reasoning processes of different lengths.
  • Good lengths for CLEVR are in the range of 4-16 (using more cells tends to converge faster and achieve slightly higher accuracy, while using fewer cells usually results in more easily interpretable attention maps).
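To compare several reasoning lengths, the training command can be generated once per --netLength value. The sketch below reuses only the flags from the training command above; the helper name `training_command` and the experiment naming scheme (`clevr_len4`, and so on) are hypothetical.

```python
def training_command(net_length, exp_prefix="clevr", epochs=25, tested_num=10000,
                     config="configs/args.txt"):
    """Build the main.py training command for a given number of MAC cells."""
    return [
        "python", "main.py",
        "--expName", f"{exp_prefix}_len{net_length}",  # one experiment name per length
        "--train",
        "--testedNum", str(tested_num),
        "--epochs", str(epochs),
        "--netLength", str(net_length),
        f"@{config}",
    ]


for n in (4, 8, 16):
    print(" ".join(training_command(n)))
```

Giving each run its own expName keeps the weights and results of the different lengths in separate directories.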

Model variants

We have explored several variants of our model. We provide a few examples in configs/args2-4.txt. For instance, you can run the first by:

python main.py --expName "experiment1" --train --testedNum 10000 --epochs 40 --netLength 6 @configs/args2.txt
  • args2 uses a non-recurrent variant of the control unit that converges faster.
  • args3 incorporates self-attention into the write unit.
  • args4 adds control-based gating over the memory.

See config.py for further available options (note that some of them are still experimental).
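The @configs/args2.txt argument in the command above reads additional options from a file. Assuming the program relies on the standard argparse `fromfile_prefix_chars` mechanism (an assumption; check config.py), the behaviour can be sketched as follows. The two flags are taken from this README, but their defaults here are placeholders, and `parse_with_file` is a hypothetical helper.

```python
import argparse
import os
import tempfile


def build_parser():
    """Parser that expands any argument of the form @path into the file's contents."""
    parser = argparse.ArgumentParser(fromfile_prefix_chars="@")
    # Placeholder defaults; the real ones live in config.py.
    parser.add_argument("--netLength", type=int, default=16)
    parser.add_argument("--epochs", type=int, default=25)
    return parser


def parse_with_file(cli_args, file_lines):
    """Write file_lines to a temporary args file and parse cli_args followed by @file."""
    with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as f:
        f.write("\n".join(file_lines) + "\n")
        path = f.name
    try:
        return build_parser().parse_args(cli_args + ["@" + path])
    finally:
        os.remove(path)


args = parse_with_file(["--netLength", "6"], ["--epochs", "40"])
print(args.netLength, args.epochs)  # 6 40
```

By default argparse treats each line of the file as a single argument, and when the same option appears more than once the later occurrence wins, so options in a file placed at the end of the command line override earlier flags with the same name.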

Evaluation

To evaluate the trained model, and get predictions and attention maps, run the following:

python main.py --expName "clevrExperiment" --finalTest --testedNum 10000 --netLength 16 -r --getPreds --getAtt @configs/args.txt

The command will restore the model we have trained and evaluate it on the validation set. JSON files with the predictions and the attention distributions produced by running the model are saved by default to ./preds/{expName}.

  • If you want attention maps (--getAtt), we advise limiting the number of evaluated examples to 5,000-20,000 to avoid large prediction files.

Visualization

After we evaluate the model with the command above, we can visualize the generated attention maps by running:

python visualization.py --expName "clevrExperiment" --tier val 

(The tier can be set to train or test as well.) The script supports filtering the visualized questions in various ways. See visualization.py for further details.

To get more interpretable visualizations, it is highly recommended to reduce the number of cells to 4-8 (--netLength). Using more cells allows the network to learn more effective ways to approach the task, but these tend to be less interpretable than those of shorter networks (with fewer cells).

Optionally, to make the image attention maps look a little bit nicer, you can do the following (using imagemagick):

for x in preds/clevrExperiment/*Img*.png; do magick convert "$x" -brightness-contrast 20x35 "$x"; done

Thank you for your interest in our model! Please contact me at [email protected] for any questions, comments, or suggestions! :-)
