Learning Conditioned Graph Structures for Interpretable Visual Question Answering
This code provides a PyTorch implementation of our graph learning method for Visual Question Answering, as described in Learning Conditioned Graph Structures for Interpretable Visual Question Answering.
[Figure: model diagram]
[Figure: examples of learned graph structures]
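The core idea of the method is to learn a graph structure over detected image objects that is conditioned on the question. As a rough illustration only (not the paper's actual architecture), the sketch below conditions object features on a question embedding with hypothetical learned projections (random matrices here) and row-normalises pairwise scores into a soft adjacency matrix:

```python
import numpy as np

def conditioned_adjacency(obj_feats, q_feat, rng=None):
    """Illustrative sketch: score every object pair conditioned on the
    question, then row-normalise the scores into a soft adjacency matrix.
    W_v and W_q stand in for learned projections (random for illustration)."""
    rng = rng or np.random.default_rng(0)
    n, d = obj_feats.shape
    dq = q_feat.shape[0]
    W_v = rng.standard_normal((d, 64)) * 0.1   # hypothetical object projection
    W_q = rng.standard_normal((dq, 64)) * 0.1  # hypothetical question projection
    # Condition each object feature on the question embedding (broadcast add).
    h = obj_feats @ W_v + q_feat @ W_q          # shape (n, 64)
    scores = h @ h.T                            # pairwise similarity, shape (n, n)
    # Row-wise softmax: each object distributes attention over all objects.
    e = np.exp(scores - scores.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)
```

Each row of the returned matrix sums to one, so it can be read as a learned, question-dependent neighbourhood for graph convolution.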
Getting Started
Reference
If you use our code or any of the ideas from our paper, please cite:
@article{learningconditionedgraph,
  author  = {Will Norcliffe-Brown and Efstathios Vafeias and Sarah Parisot},
  title   = {Learning Conditioned Graph Structures for Interpretable Visual Question Answering},
  journal = {arXiv preprint arXiv:1806.07243},
  year    = {2018}
}
Requirements
Data
To download and unzip the required datasets, change to the data folder and run
$ cd data; python download_data.py
To preprocess the image data and text data, run the following commands. Set the --data flag to trainval or test for preprocess_image.py, and to train, val, or test for preprocess_text.py, depending on which split you want to preprocess:
$ python preprocess_image.py --data trainval; python preprocess_text.py --data train
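Text preprocessing for VQA typically means building a token vocabulary and encoding each question as a fixed-length sequence of indices. The sketch below is a hypothetical illustration of that step (the names build_vocab and encode, and the max length of 14, are assumptions, not taken from preprocess_text.py):

```python
from collections import Counter

def build_vocab(questions, min_count=1):
    """Map each token to an integer id, reserving 0 for padding
    and 1 for unknown words (illustrative, not the repo's actual code)."""
    counts = Counter(tok for q in questions for tok in q.lower().split())
    vocab = {"<pad>": 0, "<unk>": 1}
    for tok, c in sorted(counts.items()):
        if c >= min_count:
            vocab[tok] = len(vocab)
    return vocab

def encode(question, vocab, max_len=14):
    """Encode a question as a fixed-length list of token ids, padded with 0."""
    ids = [vocab.get(tok, vocab["<unk>"]) for tok in question.lower().split()]
    ids = ids[:max_len]
    return ids + [vocab["<pad>"]] * (max_len - len(ids))
```

Fixed-length encodings like this let questions be batched into a single tensor for the RNN question encoder.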
Pretrained model
If you would like a pretrained model, one can be found here: example model. This model achieves 66.2% accuracy on the test set.
Training
To train a model on the train set with our default parameters, run
$ python run.py --train
and to train a model on the train and validation sets for evaluation on the test set, run
$ python run.py --trainval
Models can be validated via
$ python run.py --eval --model_path path_to_your_model
and a JSON file of results on the test set can be produced with
$ python run.py --test --model_path path_to_your_model
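The VQA evaluation server expects the results JSON to be a list of records pairing each question id with a predicted answer string. As a minimal sketch (write_results is a hypothetical helper, not part of run.py):

```python
import json

def write_results(predictions, path):
    """Write predictions ({question_id: answer}) in the list-of-records
    format used by the VQA evaluation server."""
    results = [{"question_id": qid, "answer": ans}
               for qid, ans in predictions.items()]
    with open(path, "w") as f:
        json.dump(results, f)
    return results
```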
To reproduce our results, train a model on the trainval set with the default parameters, run the test script, and evaluate the resulting JSON on the EvalAI website.
Authors
- Will Norcliffe-Brown
- Sarah Parisot
- Stathis Vafeias
License
This project is licensed under the Apache 2.0 license; see the Apache license.
Acknowledgements
Our code is based on this implementation of the 2017 VQA challenge winner: https://github.com/markdtw/vqa-winner-cvprw-2017