
Implementation of the Object Relation Transformer for Image Captioning

Object Relation Transformer

This is a PyTorch implementation of the Object Relation Transformer published in NeurIPS 2019. You can find the paper here. This repository is largely based on code from Ruotian Luo's Self-critical Sequence Training for Image Captioning GitHub repo, which can be found here.

The primary additions are as follows:

  • Relation transformer model
  • Script to create reports for runs on MSCOCO

Requirements

  • Python 2.7 (because there is no coco-caption version for Python 3)
  • PyTorch 0.4+ (along with torchvision)
  • h5py
  • scikit-image
  • typing
  • pyemd
  • gensim
  • cider (already added as a submodule). See .gitmodules and clone the referenced repo into the object_relation_transformer folder.
  • The coco-caption library, which is used for generating different evaluation metrics. To set it up, clone the repo into the object_relation_transformer folder. Make sure to keep the cloned repo folder name as coco-caption and also to run the get_stanford_models.sh script from within that repo.

Data Preparation

Download ResNet101 weights for feature extraction

Download the file resnet101.pth from here. Copy the weights to a folder imagenet_weights within the data folder:

mkdir data/imagenet_weights
cp /path/to/downloaded/weights/resnet101.pth data/imagenet_weights

Download and preprocess the COCO captions

Download the preprocessed COCO captions from Karpathy's homepage. Extract dataset_coco.json from the zip file and copy it into data/. This file provides preprocessed captions as well as the standard train-val-test splits.

Then run:

$ python scripts/prepro_labels.py --input_json data/dataset_coco.json --output_json data/cocotalk.json --output_h5 data/cocotalk

prepro_labels.py will map all words that occur <= 5 times to a special UNK token, and create a vocabulary for all the remaining words. The image information and vocabulary are dumped into data/cocotalk.json and discretized caption data are dumped into data/cocotalk_label.h5.
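The thresholding rule is simple to state in code. The sketch below is not the actual prepro_labels.py. It assumes the dataset_coco.json layout (an `images` list whose entries hold `sentences` with pre-tokenized `tokens`), counts words over every caption in the file, and uses the hypothetical names `build_vocab` and `UNK_TOKEN`.

```python
from collections import Counter

UNK_TOKEN = "UNK"  # hypothetical name for the special unknown-word token


def build_vocab(images, count_thr=5):
    """Return a vocabulary and UNK-mapped captions (sketch, not the real script).

    Words seen <= count_thr times across all captions are replaced by UNK_TOKEN.
    """
    counts = Counter(
        word
        for img in images
        for sent in img["sentences"]
        for word in sent["tokens"]
    )
    vocab = sorted(w for w, n in counts.items() if n > count_thr)
    if any(n <= count_thr for n in counts.values()):
        vocab.append(UNK_TOKEN)  # only needed if some word was dropped
    kept = set(vocab)
    captions = [
        [[w if w in kept else UNK_TOKEN for w in sent["tokens"]] for sent in img["sentences"]]
        for img in images
    ]
    return vocab, captions


# Toy input in the dataset_coco.json shape (illustrative data only).
images = [
    {"split": "train", "sentences": [{"tokens": ["a", "dog", "runs"]}] * 6},
    {"split": "val", "sentences": [{"tokens": ["a", "cat"]}]},
]
vocab, captions = build_vocab(images)
print(vocab)           # ['a', 'dog', 'runs', 'UNK']
print(captions[1][0])  # ['a', 'UNK']
```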

Next run:

$ python scripts/prepro_ngrams.py --input_json data/dataset_coco.json --dict_json data/cocotalk.json --output_pkl data/coco-train --split train

This preprocesses the dataset and builds the n-gram cache used to compute CIDEr scores.
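CIDEr weights each n-gram (n = 1 to 4) by an inverse document frequency, where each image's set of reference captions counts as one document. The cache therefore mostly amounts to document-frequency counts over the training split. Below is a minimal sketch under those assumptions. `document_frequency` is a hypothetical helper, and the real prepro_ngrams.py may also map rare words to UNK, fold in other splits, and store extra fields such as a reference-length term.

```python
from collections import Counter


def ngrams(tokens, n_max=4):
    """All n-grams of length 1..n_max in a token list, as tuples."""
    return [
        tuple(tokens[i:i + n])
        for n in range(1, n_max + 1)
        for i in range(len(tokens) - n + 1)
    ]


def document_frequency(images, split="train", n_max=4):
    """Count, for each n-gram, how many images' reference sets contain it."""
    df = Counter()
    for img in images:
        if img["split"] != split:
            continue
        seen = set()  # an n-gram counts at most once per image
        for sent in img["sentences"]:
            seen.update(ngrams(sent["tokens"], n_max))
        df.update(seen)
    return df


images = [
    {"split": "train", "sentences": [{"tokens": ["a", "dog"]}, {"tokens": ["a", "dog", "runs"]}]},
    {"split": "train", "sentences": [{"tokens": ["a", "cat"]}]},
    {"split": "val", "sentences": [{"tokens": ["a", "dog"]}]},
]
df = document_frequency(images)
print(df[("a",)], df[("a", "dog")], df[("runs",)])  # 2 1 1
```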

Download the COCO dataset and pre-extract the image features

Download the COCO images from the MSCOCO website. We need the 2014 training and 2014 validation images. Put the train2014/ and val2014/ folders in the same directory, denoted as $IMAGE_ROOT:

mkdir $IMAGE_ROOT
pushd $IMAGE_ROOT
wget http://images.cocodataset.org/zips/train2014.zip
unzip train2014.zip
wget http://images.cocodataset.org/zips/val2014.zip
unzip val2014.zip
popd
wget https://msvocds.blob.core.windows.net/images/262993_z.jpg
mv 262993_z.jpg $IMAGE_ROOT/train2014/COCO_train2014_000000167126.jpg

The last two commands are needed to address an issue with a corrupted image in the MSCOCO dataset (see here). The prepro script will fail otherwise.

Then run:

$ python scripts/prepro_feats.py --input_json data/dataset_coco.json --output_dir data/cocotalk --images_root $IMAGE_ROOT

prepro_feats.py extracts the ResNet101 features (both the fc feature and the last conv feature) of each image. The features are saved in data/cocotalk_fc and data/cocotalk_att, and the resulting files total about 200 GB. Running this script may take a day or more, depending on hardware.

(Check the prepro scripts for more options, like other ResNet models or other attention sizes.)
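The two feature types can be related with a small pure-Python sketch. It assumes, rather than reads from prepro_feats.py, that the fc feature is the spatial mean of ResNet101's last conv map and that the att feature is the same map average-pooled to a fixed grid (the "attention size" mentioned above). `fc_and_att` and `adaptive_bins` are hypothetical helpers; the bin edges follow the way PyTorch's adaptive average pooling computes them.

```python
import math


def adaptive_bins(in_size, out_size):
    """(start, end) index pairs used by adaptive average pooling."""
    return [
        (math.floor(i * in_size / out_size), math.ceil((i + 1) * in_size / out_size))
        for i in range(out_size)
    ]


def fc_and_att(conv, att_size):
    """Derive an fc vector and an att grid from a conv map shaped [C][H][W].

    fc:  mean over all H x W positions, one value per channel.
    att: each channel average-pooled to att_size x att_size, stored
         channels-last as [att_size][att_size][C].
    """
    height, width = len(conv[0]), len(conv[0][0])
    fc = [sum(sum(row) for row in ch) / (height * width) for ch in conv]
    att = []
    for r0, r1 in adaptive_bins(height, att_size):
        att_row = []
        for c0, c1 in adaptive_bins(width, att_size):
            cell = []
            for ch in conv:
                vals = [ch[r][c] for r in range(r0, r1) for c in range(c0, c1)]
                cell.append(sum(vals) / len(vals))
            att_row.append(cell)
        att.append(att_row)
    return fc, att


# One channel, 4x4 map holding 0..15, pooled to a 2x2 grid (toy data).
conv = [[[float(4 * r + c) for c in range(4)] for r in range(4)]]
fc, att = fc_and_att(conv, att_size=2)
print(fc)         # [7.5]
print(att[0][0])  # [2.5]
```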

Download the Bottom-up features

Download the pre-extracted features from here. For the paper, the adaptive features were used.

Do the following:

mkdir data/bu_data; cd data/bu_data
wget https://imagecaption.blob.core.windows.net/imagecaption/trainval.zip
unzip trainval.zip

The .zip file is around 22 GB. Then return to the base directory and run:

python scripts/make_bu_data.py --output_dir data/cocobu

This will create data/cocobu_fc, data/cocobu_att and data/cocobu_box.
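Conceptually, make_bu_data.py unpacks the downloaded TSV files into these three per-image outputs. The Python 3 sketch below assumes the layout commonly used for the bottom-up-attention features: tab-separated image_id, image_w, image_h, num_boxes, boxes and features columns, with boxes and features stored as base64-encoded float32 buffers. It also assumes the fc feature is the mean of the region features. All helper names are hypothetical.

```python
import base64
import csv
import io
from array import array

# Assumed column order of the bottom-up-attention TSV files.
FIELDNAMES = ["image_id", "image_w", "image_h", "num_boxes", "boxes", "features"]

# Real rows hold large base64 blobs, so lift csv's default field size limit.
csv.field_size_limit(2**31 - 1)


def decode_floats(b64, rows):
    """Decode a base64 float32 buffer into `rows` equal-length lists."""
    flat = array("f")
    flat.frombytes(base64.b64decode(b64))
    cols = len(flat) // rows
    return [list(flat[i * cols:(i + 1) * cols]) for i in range(rows)]


def split_row(row):
    """Split one TSV record into image id, fc vector, att regions and boxes."""
    n = int(row["num_boxes"])
    att = decode_floats(row["features"], n)   # one feature vector per region
    boxes = decode_floats(row["boxes"], n)    # four box coordinates per region
    fc = [sum(col) / n for col in zip(*att)]  # fc = mean over regions (assumed)
    return int(row["image_id"]), fc, att, boxes


def encode_floats(values):
    """Inverse of decode_floats, used here only to build a toy row."""
    return base64.b64encode(array("f", values).tobytes()).decode("ascii")


# Toy record with two regions and 2-dimensional features (illustrative data).
line = "\t".join([
    "42", "640", "480", "2",
    encode_floats([0.0, 0.0, 10.0, 10.0, 5.0, 5.0, 20.0, 20.0]),
    encode_floats([1.0, 2.0, 3.0, 4.0]),
])
reader = csv.DictReader(io.StringIO(line), delimiter="\t", fieldnames=FIELDNAMES)
image_id, fc, att, boxes = split_row(next(reader))
print(image_id, fc)  # 42 [2.0, 3.0]
```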

Generate the relative bounding box coordinates for the Relation Transformer

Run the following:

python scripts/prepro_bbox_relative_coords.py --input_json data/dataset_coco.json --input_box_dir data/cocobu_box --output_dir data/cocobu_box_relative --image_root $IMAGE_ROOT

This should take a couple of hours, depending on hardware.
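The internals of prepro_bbox_relative_coords.py are not described here. Because the script takes --image_root, one plausible reading is that it rescales each box by its image's width and height. This is an assumption, not a description of the actual code. The following sketch uses the hypothetical helper `relative_coords`:

```python
def relative_coords(boxes, image_w, image_h):
    """Rescale absolute (x1, y1, x2, y2) boxes to [0, 1] by image size (assumed)."""
    return [
        [x1 / image_w, y1 / image_h, x2 / image_w, y2 / image_h]
        for x1, y1, x2, y2 in boxes
    ]


boxes = [[0.0, 0.0, 320.0, 240.0], [160.0, 120.0, 640.0, 480.0]]
result = relative_coords(boxes, image_w=640, image_h=480)
print(result)  # [[0.0, 0.0, 0.5, 0.5], [0.25, 0.25, 1.0, 1.0]]
```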

Model Training and Evaluation

Standard cross-entropy loss training

python train.py --id relation_transformer_bu --caption_model relation_transformer --input_json data/cocotalk.json --input_fc_dir data/cocobu_fc --input_att_dir data/cocobu_att --input_box_dir data/cocobu_box --input_rel_box_dir data/cocobu_box_relative --input_label_h5 data/cocotalk_label.h5 --checkpoint_path log_relation_transformer_bu --noamopt --noamopt_warmup 10000 --label_smoothing 0.0 --batch_size 15 --learning_rate 5e-4 --num_layers 6 --input_encoding_size 512 --rnn_size 2048 --learning_rate_decay_start 0 --scheduled_sampling_start 0 --save_checkpoint_every 6000 --language_eval 1 --val_images_use 5000 --max_epochs 30 --use_box 1

The train script dumps checkpoints into the folder specified by --checkpoint_path (default: save/). To save disk space, only the latest checkpoint and the checkpoint that performs best on validation are kept.

To resume training, set --start_from to the directory containing infos.pkl and model.pth (usually you can simply set --start_from and --checkpoint_path to the same path).
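The "best plus latest" policy can be summarized as a tiny bookkeeping class. This is a hypothetical helper, not code from train.py. The file names follow the model.pth and infos_<id>-best.pkl names used elsewhere in this README. The model-best.pth name and the "higher score is better" comparison are assumptions.

```python
class CheckpointPolicy:
    """Decide which checkpoint files to write after each validation run."""

    def __init__(self, run_id):
        self.run_id = run_id
        self.best_score = None

    def files_to_write(self, val_score):
        # The latest checkpoint is always overwritten.
        files = ["model.pth", "infos_%s.pkl" % self.run_id]
        # The best checkpoint is replaced only when validation improves.
        if self.best_score is None or val_score > self.best_score:
            self.best_score = val_score
            files += ["model-best.pth", "infos_%s-best.pkl" % self.run_id]
        return files


policy = CheckpointPolicy("relation_transformer_bu")
print(policy.files_to_write(1.05))  # latest and best files
print(policy.files_to_write(1.01))  # latest files only
```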

If you have TensorFlow installed, the loss histories are automatically dumped into --checkpoint_path and can be visualized using TensorBoard.

The command above uses scheduled sampling. Set --scheduled_sampling_start to -1 to disable it.
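Scheduled sampling replaces some ground-truth input tokens with the model's own previous predictions during training. The replacement probability grows over epochs. Below is a minimal sketch of a step-wise schedule. The function, its parameter names and its default values are assumptions for illustration and are not read from opts.py.

```python
def scheduled_sampling_prob(epoch, start, increase_every=5, increase_prob=0.05, max_prob=0.25):
    """Probability of feeding back the model's own prediction at this epoch.

    Disabled when start < 0; otherwise raised by increase_prob every
    increase_every epochs after `start`, capped at max_prob.
    """
    if start < 0 or epoch <= start:
        return 0.0
    steps = (epoch - start) // increase_every
    return min(increase_prob * steps, max_prob)


print([scheduled_sampling_prob(e, start=0) for e in (0, 4, 5, 12, 40)])
# [0.0, 0.0, 0.05, 0.1, 0.25]
```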

If you'd like to evaluate BLEU/METEOR/CIDEr scores during training in addition to the validation cross-entropy loss, use the --language_eval 1 option, but don't forget to download the coco-caption code into the coco-caption directory.

For more options, see opts.py.

The above training script should achieve a CIDEr-D score of about 115.
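The cross-entropy command also passes --noamopt --noamopt_warmup 10000. Suppose this selects the warmup schedule from the original Transformer paper (often called the Noam schedule), with d_model taken as --input_encoding_size and a scale factor of 1 (all assumptions). The learning rate would then behave like this sketch:

```python
def noam_lr(step, d_model=512, warmup=10000, factor=1.0):
    """Linear warmup for `warmup` steps, then decay proportional to step ** -0.5."""
    step = max(step, 1)  # the formula is undefined at step 0
    return factor * d_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)


for step in (1, 5000, 10000, 40000):
    print(step, "%.2e" % noam_lr(step))
```

Under these assumptions, the rate peaks at step == warmup and decays with the inverse square root of the step afterwards.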

Self-critical RL training

After training with cross-entropy loss, additional self-critical training produces significant gains in CIDEr-D score.

First, copy the model pretrained with cross-entropy loss. (Copying the model is not mandatory; it only serves as a backup.)

$ bash scripts/copy_model.sh relation_transformer_bu relation_transformer_bu_rl

Then:

python train.py --id relation_transformer_bu_rl --caption_model relation_transformer --input_json data/cocotalk.json --input_fc_dir data/cocobu_fc --input_att_dir data/cocobu_att --input_box_dir data/cocobu_box --input_rel_box_dir data/cocobu_box_relative --input_label_h5 data/cocotalk_label.h5 --checkpoint_path log_relation_transformer_bu_rl --label_smoothing 0.0 --batch_size 10 --learning_rate 5e-4 --num_layers 6 --input_encoding_size 512 --rnn_size 2048 --learning_rate_decay_start 0 --scheduled_sampling_start 0 --start_from log_relation_transformer_bu_rl --save_checkpoint_every 6000 --language_eval 1 --val_images_use 5000 --self_critical_after 30 --max_epochs 60 --use_box 1

The above training script should achieve a CIDEr-D score of about 128.
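In self-critical training, the CIDEr-D score of a sampled caption is compared against the score of the model's own greedy-decoded caption, and the difference scales the policy-gradient update. The list-based sketch below assumes the CIDEr-D scores and token log-probabilities are already available. The real code presumably works on padded PyTorch tensors with masks, and `scst_loss` is a hypothetical name.

```python
def scst_loss(sample_logprobs, sample_scores, greedy_scores):
    """Self-critical policy-gradient loss, averaged over sampled tokens.

    sample_logprobs: one list of token log-probabilities per sampled caption.
    sample_scores:   CIDEr-D of each sampled caption (the reward).
    greedy_scores:   CIDEr-D of the greedy caption for the same image (the baseline).
    """
    total, n_tokens = 0.0, 0
    for logps, r_sample, r_greedy in zip(sample_logprobs, sample_scores, greedy_scores):
        advantage = r_sample - r_greedy  # > 0 if the sample beat greedy decoding
        total += -advantage * sum(logps)
        n_tokens += len(logps)
    return total / n_tokens


loss = scst_loss(
    sample_logprobs=[[-0.5, -1.0], [-0.2, -0.3, -0.5]],
    sample_scores=[1.2, 0.8],
    greedy_scores=[1.0, 1.0],
)
print(round(loss, 4))  # 0.02
```

Minimizing this loss raises the log-probability of the first caption, which beat the greedy baseline, and lowers the log-probability of the second.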

Evaluate on Karpathy's test split

To evaluate the cross-entropy model, run:

python eval.py --dump_images 0 --num_images 5000 --model log_relation_transformer_bu/model.pth --infos_path log_relation_transformer_bu/infos_relation_transformer_bu-best.pkl --image_root $IMAGE_ROOT --input_json data/cocotalk.json --input_label_h5 data/cocotalk_label.h5  --input_fc_dir data/cocobu_fc --input_att_dir data/cocobu_att --input_box_dir data/cocobu_box --input_rel_box_dir data/cocobu_box_relative --use_box 1 --language_eval 1

and for the cross-entropy + RL model, run:

python eval.py --dump_images 0 --num_images 5000 --model log_relation_transformer_bu_rl/model.pth --infos_path log_relation_transformer_bu_rl/infos_relation_transformer_bu_rl-best.pkl --image_root $IMAGE_ROOT --input_json data/cocotalk.json --input_label_h5 data/cocotalk_label.h5 --input_fc_dir data/cocobu_fc --input_att_dir data/cocobu_att --input_box_dir data/cocobu_box --input_rel_box_dir data/cocobu_box_relative --use_box 1 --language_eval 1

Visualization

Visualize caption predictions

Place all your images of interest into a folder, e.g. images, and run the eval script:

$ python eval.py --dump_images 1 --num_images 10 --model log_relation_transformer_bu/model.pth --infos_path log_relation_transformer_bu/infos_relation_transformer_bu-best.pkl --image_root $IMAGE_ROOT --input_json data/cocotalk.json --input_label_h5 data/cocotalk_label.h5  --input_fc_dir data/cocobu_fc --input_att_dir data/cocobu_att --input_box_dir data/cocobu_box --input_rel_box_dir data/cocobu_box_relative

This tells the eval script to process up to 10 images from the given folder. If you have a large GPU, you can speed up evaluation by increasing the batch_size option. Use --num_images -1 to process all images. The eval script creates a vis.json file inside the vis folder, which can then be visualized with the provided HTML interface:

$ cd vis
$ python -m SimpleHTTPServer

Now visit localhost:8000 in your browser and you should see your predicted captions.

Generate reports from runs on MSCOCO

The create_report.py script can be used to generate HTML reports containing results from different runs. Please see the script for specific usage examples.

The script takes as input one or more pickle files containing results from runs on the MSCOCO dataset. It reads in the pickle files and creates a set of HTML files with tables and graphs generated from the different captioning evaluation metrics, as well as the generated image captions and corresponding metrics for individual images.

If more than one pickle file with results is provided as input, the script will also generate a report containing a comparison between the metrics generated by each pair of methods.

Model Zoo and Results

The table below presents links to our pre-trained models, as well as results from our paper on the Karpathy test split. Similar results should be obtained by running the respective commands in neurips_training_runs.sh. As learning rate scheduling was not fully optimized, these values should only serve as a reference/expectation rather than what can be achieved with additional tuning.

The models are Copyright Verizon Media, licensed under the terms of the CC-BY-4.0 license. See associated license file.

| Algorithm | CIDEr-D | SPICE | BLEU-1 | BLEU-4 | METEOR | ROUGE-L |
|---|---|---|---|---|---|---|
| Up-Down + LSTM * | 106.6 | 19.9 | 75.6 | 32.9 | 26.5 | 55.4 |
| Up-Down + Transformer | 111.0 | 20.9 | 75.0 | 32.8 | 27.5 | 55.6 |
| Up-Down + Object Relation Transformer | 112.6 | 20.8 | 75.6 | 33.5 | 27.6 | 56.0 |
| Up-Down + Object Relation Transformer + Beamsize 2 | 115.4 | 21.2 | 76.6 | 35.5 | 28.0 | 56.6 |
| Up-Down + Object Relation Transformer + Self-Critical + Beamsize 5 | 128.3 | 22.6 | 80.5 | 38.6 | 28.7 | 58.4 |

* Note that the pre-trained Up-Down + LSTM model above produces slightly better results than reported, as it came from a different training run. We kept the older LSTM results in the table above for consistency with our paper.

Comparative Analysis

In addition, the paper presents a head-to-head comparison of the Object Relation Transformer against the "Up-Down + Transformer" model (whose results are also included in the table above). In the paper, we refer to this latter model as the "Baseline Transformer", as it does not use geometry in its attention definition. The purpose of the head-to-head comparison is to better understand, both quantitatively and qualitatively, the improvement obtained by adding geometric attention to the Transformer. The comparison consists of a set of evaluation metrics computed for each model on a per-image basis, as well as aggregated over all images. It also includes the results of paired t-tests, which test for statistically significant differences between the evaluation metrics of the two models. This comparison can be generated by running the commands in neurips_report_comands.sh, which first run the two aforementioned models on the MSCOCO test set and then generate the corresponding report containing the complete comparative analysis.
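For reference, a paired t-test on per-image scores reduces to the statistic below. How create_report.py computes it is not described here, so `paired_t` is a hypothetical helper. The p-value then follows from a t distribution with n - 1 degrees of freedom, which is what scipy.stats.ttest_rel reports for the same inputs.

```python
import math
import statistics


def paired_t(scores_a, scores_b):
    """Paired t statistic and degrees of freedom for two models' per-image scores.

    scores_a[i] and scores_b[i] must refer to the same image.
    """
    diffs = [a - b for a, b in zip(scores_a, scores_b)]
    n = len(diffs)
    t = statistics.mean(diffs) / (statistics.stdev(diffs) / math.sqrt(n))
    return t, n - 1


# Toy per-image CIDEr-D scores for two models (illustrative data only).
t, dof = paired_t([1.0, 1.2, 0.9, 1.1], [0.9, 1.0, 0.9, 1.0])
print(round(t, 3), dof)  # 2.449 3
```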

Citation

If you find this repo useful, please consider citing (no obligation at all):

@article{herdade2019image,
  title={Image Captioning: Transforming Objects into Words},
  author={Herdade, Simao and Kappeler, Armin and Boakye, Kofi and Soares, Joao},
  journal={arXiv preprint arXiv:1906.05963},
  year={2019}
}

Of course, please also cite the original papers of the models you are using (you can find references in the model files).

Contribute

Please refer to the contributing.md file for information about how to get involved. We welcome issues, questions, and pull requests.

Please be aware that we (the maintainers) are currently busy with other projects, so it may take a few days before we are able to get back to you. We do not foresee big changes to this repository going forward.

Maintainers

Kofi Boakye

Simao Herdade

Joao Soares

License

This project is licensed under the terms of the MIT open source license. Please refer to LICENSE for the full terms.

Acknowledgments

Thanks to Ruotian Luo for the original code.
