  • Language: Jupyter Notebook
  • License: MIT License
  • Created about 5 years ago
  • Updated over 2 years ago

Fast, general, and tested differentiable structured prediction in PyTorch

Torch-Struct: Structured Prediction Library

A library of tested GPU implementations of core structured prediction algorithms for deep learning applications.

  • HMM / LinearChain-CRF
  • HSMM / SemiMarkov-CRF
  • Dependency Tree-CRF
  • PCFG Binary Tree-CRF
  • ...

Designed to be used as efficient batched layers in other PyTorch code.

See the tutorial paper (arXiv:2002.00876) for a description of the methodology.

Getting Started

!pip install -qU git+https://github.com/harvardnlp/pytorch-struct
# Optional CUDA kernels for FastLogSemiring
!pip install -qU git+https://github.com/harvardnlp/genbmm
# For plotting.
!pip install -q matplotlib
import torch
from torch_struct import DependencyCRF, LinearChainCRF
import matplotlib.pyplot as plt
def show(x): plt.imshow(x.detach())
# Make some data.
vals = torch.zeros(2, 10, 10) + 1e-5
vals[:, :5, :5] = torch.rand(5)
vals[:, 5:, 5:] = torch.rand(5) 
dist = DependencyCRF(vals.log())
show(dist.log_potentials[0])


# Compute marginals
show(dist.marginals[0])

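According to the tutorial paper, Torch-Struct computes these marginals by backpropagating through the log partition. The marginal probability of a part is the derivative of log Z with respect to that part's log potential. The pure-Python sketch below checks this identity on a toy linear chain rather than calling `torch_struct`. The `phi[n][c_next][c_prev]` layout, the toy values and every function name are assumptions for illustration. Brute-force enumeration replaces the dynamic program, and a central finite difference stands in for autograd.

```python
import itertools
import math

# Toy log potentials phi[n][c_next][c_prev]: 2 edges (3 positions), C = 2 classes.
phi = [[[0.5, -1.0], [0.2, 0.3]],
       [[1.0, 0.0], [-0.5, 0.7]]]
C = 2


def sequences(phi):
    """Every label sequence over len(phi) + 1 positions."""
    return itertools.product(range(C), repeat=len(phi) + 1)


def score(phi, y):
    """Sum of edge log potentials along the label sequence y."""
    return sum(phi[n][y[n + 1]][y[n]] for n in range(len(phi)))


def log_partition(phi):
    """Brute-force log-sum-exp over all sequence scores."""
    scores = [score(phi, y) for y in sequences(phi)]
    m = max(scores)
    return m + math.log(sum(math.exp(s - m) for s in scores))


def edge_marginal(phi, n, a, b):
    """P(y[n+1] = a, y[n] = b) by direct enumeration."""
    log_z = log_partition(phi)
    return sum(math.exp(score(phi, y) - log_z)
               for y in sequences(phi) if y[n + 1] == a and y[n] == b)


def grad_log_partition(phi, n, a, b, eps=1e-6):
    """Central finite difference of log Z with respect to phi[n][a][b]."""
    plus = [[row[:] for row in edge] for edge in phi]
    minus = [[row[:] for row in edge] for edge in phi]
    plus[n][a][b] += eps
    minus[n][a][b] -= eps
    return (log_partition(plus) - log_partition(minus)) / (2 * eps)


# Both numbers agree up to finite-difference error.
print(edge_marginal(phi, 1, 0, 1), grad_log_partition(phi, 1, 0, 1))
```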

# Compute argmax
show(dist.argmax.detach()[0])

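The argmax is the max-plus counterpart of the partition function (the "Max and MAP computation" item under Low-level API). As a standalone illustration, here is the textbook Viterbi recursion with explicit backpointers on a toy chain, written in pure Python. The potential layout, the toy values and the names are assumptions. The library may recover the argmax by a different route, so this sketch shows only the underlying dynamic program.

```python
# Toy log potentials phi[n][c_next][c_prev]: 2 edges (3 positions), C = 3 classes.
phi = [[[0.5, -1.0, 0.1], [0.2, 0.3, -0.4], [0.0, 0.9, 0.2]],
       [[1.0, 0.0, -0.3], [-0.5, 0.7, 0.4], [0.6, -0.2, 0.1]]]
C = 3


def score(phi, y):
    """Sum of edge log potentials along the label sequence y."""
    return sum(phi[n][y[n + 1]][y[n]] for n in range(len(phi)))


def viterbi(phi):
    """Max-plus dynamic program with backpointers; returns (best score, best sequence)."""
    best = [0.0] * C                  # best prefix score ending in each class
    backpointers = []
    for edge in phi:
        new_best, pointers = [], []
        for a in range(C):            # class at the next position
            cands = [best[b] + edge[a][b] for b in range(C)]
            b_star = max(range(C), key=cands.__getitem__)
            new_best.append(cands[b_star])
            pointers.append(b_star)
        best = new_best
        backpointers.append(pointers)
    path = [max(range(C), key=best.__getitem__)]
    for pointers in reversed(backpointers):
        path.append(pointers[path[-1]])   # best previous class given the next one
    path.reverse()
    return best[path[-1]], path


best_score, best_path = viterbi(phi)
print(best_score, best_path, score(phi, best_path))
```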

# Compute scoring and enumeration (forward / inside)
log_partition = dist.partition
max_score = dist.log_prob(dist.argmax)
# Compute samples 
show(dist.sample((1,)).detach()[0, 0])

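Two computations in the cell above can be sketched on a toy linear chain in pure Python: the forward (inside) recursion behind `dist.partition`, and the score-minus-log-Z computation behind `dist.log_prob`. The sketch also includes a sampler based on forward-filtering backward-sampling. That standard technique is swapped in here for illustration only; the library describes its own sampling as going "through specialized backprop". The layout, values and names are assumptions, not the library's API.

```python
import math
import random

# Toy log potentials phi[n][c_next][c_prev]: 3 edges (4 positions), C = 2 classes.
phi = [[[0.3, -0.2], [0.8, 0.1]],
       [[-0.4, 0.6], [0.2, 0.0]],
       [[0.5, 0.5], [-1.0, 0.9]]]
C = 2


def logsumexp(xs):
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def forward(phi):
    """alpha[n][c]: log-sum-exp of the scores of all prefixes ending in class c at position n."""
    alpha = [[0.0] * C]
    for edge in phi:
        prev = alpha[-1]
        alpha.append([logsumexp([prev[b] + edge[a][b] for b in range(C)]) for a in range(C)])
    return alpha


def log_partition(phi):
    return logsumexp(forward(phi)[-1])


def log_prob(phi, y):
    """Score of y minus log Z (the analogue of dist.log_prob)."""
    return sum(phi[n][y[n + 1]][y[n]] for n in range(len(phi))) - log_partition(phi)


def sample(phi, rng):
    """Forward-filtering backward-sampling: draws y with probability exp(log_prob(phi, y))."""
    alpha = forward(phi)

    def draw(logits):
        z = logsumexp(logits)
        return rng.choices(range(C), weights=[math.exp(v - z) for v in logits])[0]

    y = [draw(alpha[-1])]                     # class at the last position
    for n in reversed(range(len(phi))):       # walk back one edge at a time
        y.append(draw([alpha[n][b] + phi[n][y[-1]][b] for b in range(C)]))
    return y[::-1]


rng = random.Random(0)
print(log_partition(phi), sample(phi, rng))
```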

# Padding/masking is built into the library.
dist = DependencyCRF(vals.log(), lengths=torch.tensor([10, 7]))
show(dist.marginals[0])
plt.show()
show(dist.marginals[1])

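The `lengths` argument lets one batch hold structures of different sizes. A common way to support this in a batched dynamic program is to run the recursion over the padded length and freeze each item's state once it passes its own length. Padded potentials then never reach the result. This approach is an assumption here, not taken from the library's source. For brevity the sketch below uses a linear chain rather than `DependencyCRF`, and `lengths` counts positions, in the same way that `torch.tensor([10, 7])` counts words in the cell above. The function name and toy values are hypothetical.

```python
import math

C = 2


def logsumexp(xs):
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def masked_log_partition(phi_batch, lengths):
    """Forward recursion over a padded batch of phi[n][c_next][c_prev] potentials.

    Item i has lengths[i] positions, so only its first lengths[i] - 1 edges count.
    """
    out = []
    for phi, length in zip(phi_batch, lengths):
        alpha = [0.0] * C
        for n, edge in enumerate(phi):
            new = [logsumexp([alpha[b] + edge[a][b] for b in range(C)]) for a in range(C)]
            alpha = new if n < length - 1 else alpha   # masked step keeps the old state
        out.append(logsumexp(alpha))
    return out


# Two items padded to 3 edges (4 positions); the second has 3 positions, so its last edge is padding.
edge = [[0.3, -0.2], [0.8, 0.1]]
junk = [[50.0, -50.0], [7.0, 3.0]]      # arbitrary padding values that must not matter
batch = [[edge, edge, edge], [edge, edge, junk]]
print(masked_log_partition(batch, [4, 3]))
```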

# Many other structured prediction approaches
chain = torch.zeros(2, 10, 10, 10) + 1e-5
chain[:, :, :, :] = vals.unsqueeze(-1).exp()
chain[:, :, :, :] += torch.eye(10, 10).view(1, 1, 10, 10) 
chain[:, 0, :, 0] = 1
chain[:, -1, 9, :] = 1
chain = chain.log()

dist = LinearChainCRF(chain)
show(dist.marginals.detach()[0].sum(-1))

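The `chain` tensor above has shape (2, 10, 10, 10). As we read `LinearChainCRF`'s documented event shape, (N-1) x C x C with entries ϕ(n, z_{n+1}, z_n), this corresponds to a batch of 2 with 10 edges (11 positions) and 10 classes. In a tagger, such potentials are typically assembled from per-position emission scores and a transition matrix. The pure-Python helper below sketches one such assembly. `chain_potentials`, its argument layout and the choice to fold position 0's emission into the first edge are all assumptions for illustration.

```python
def chain_potentials(emissions, transitions):
    """Hypothetical helper: combine emissions[n][c] (N x C) and
    transitions[c_prev][c_next] (C x C) into edge log potentials
    phi[n][c_next][c_prev] of shape (N-1) x C x C. Position 0's emission is
    folded into the first edge so every term is counted exactly once."""
    N, C = len(emissions), len(emissions[0])
    return [[[transitions[b][a] + emissions[n + 1][a] + (emissions[0][b] if n == 0 else 0.0)
              for b in range(C)]      # c_prev
             for a in range(C)]       # c_next
            for n in range(N - 1)]


emissions = [[0.1, 0.4, -0.2], [0.0, 0.3, 0.5], [-0.1, 0.2, 0.9], [0.6, -0.3, 0.0]]
transitions = [[0.2, -0.5, 0.1], [0.0, 0.4, -0.2], [0.3, 0.1, 0.0]]
phi = chain_potentials(emissions, transitions)
print(len(phi), len(phi[0]), len(phi[0][0]))   # (N-1) x C x C → 3 3 3
```

Summing `phi` along any label sequence reproduces that sequence's emission scores plus its transition scores. A per-sequence scorer can therefore be replaced by the edge-factored form the CRF expects.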

Library

Full docs: http://nlp.seas.harvard.edu/pytorch-struct/

Current distributions implemented:

  • LinearChainCRF
  • SemiMarkovCRF
  • DependencyCRF
  • NonProjectiveDependencyCRF
  • TreeCRF
  • NeuralPCFG / NeuralHMM
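
`NonProjectiveDependencyCRF` sums over dependency trees in which arcs may cross. A standard way to compute that partition function is the directed Matrix-Tree theorem. It states that the total weight of all trees rooted at a dummy root node equals the determinant of a Laplacian built from the arc weights (exponentiated potentials). The pure-Python sketch below shows the multi-root variant, in which the root may take several children, and checks it against brute-force enumeration over three words. It does not claim to mirror the library's implementation, including how or whether a single-root constraint is applied. All names and values are illustrative.

```python
import itertools


def det(m):
    """Determinant via Gaussian elimination with partial pivoting."""
    m = [row[:] for row in m]
    n, result = len(m), 1.0
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(m[r][col]))
        if m[pivot][col] == 0.0:
            return 0.0
        if pivot != col:
            m[col], m[pivot] = m[pivot], m[col]
            result = -result
        result *= m[col][col]
        for r in range(col + 1, n):
            factor = m[r][col] / m[col][col]
            for c in range(col, n):
                m[r][c] -= factor * m[col][c]
    return result


def matrix_tree_partition(w):
    """w[h][d] > 0 is the weight of arc h -> d; node 0 is the root, words are 1..n.
    Returns the total weight of all non-projective trees (multi-root variant)."""
    n = len(w) - 1
    lap = [[0.0] * n for _ in range(n)]
    for d in range(1, n + 1):
        for h in range(n + 1):
            if h == d:
                continue
            lap[d - 1][d - 1] += w[h][d]          # weighted in-degree, root arcs included
            if h > 0:
                lap[h - 1][d - 1] -= w[h][d]      # off-diagonal: minus the arc weight
    return det(lap)


def reaches_root(heads, d):
    """Follow head pointers from word d; False if they cycle before reaching node 0."""
    seen = set()
    while d != 0 and d not in seen:
        seen.add(d)
        d = heads[d - 1]
    return d == 0


def brute_force_partition(w):
    """Sum over every head assignment (heads[d - 1] = head of d) that forms a tree."""
    n, total = len(w) - 1, 0.0
    for heads in itertools.product(range(n + 1), repeat=n):
        if all(reaches_root(heads, d) for d in range(1, n + 1)):
            weight = 1.0
            for d in range(1, n + 1):
                weight *= w[heads[d - 1]][d]
            total += weight
    return total


w = [[0.0, 1.5, 0.7, 0.2],    # row h: weights of arcs leaving head h (node 0 = root)
     [0.0, 0.0, 2.0, 0.4],
     [0.0, 0.9, 0.0, 1.1],
     [0.0, 0.3, 0.6, 0.0]]
print(matrix_tree_partition(w), brute_force_partition(w))
```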

Each distribution includes:

  • Argmax, sampling, entropy, partition, masking, log_probs, k-max
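
The k-max item generalizes argmax to the k best structures. Below is a minimal sketch of the idea on a toy chain, assuming a k-max semiring. Each dynamic-programming cell holds a sorted list of the k best prefix scores, "addition" merges these lists and keeps the top k, and extending a prefix adds the edge score to every kept entry. The sketch returns scores only (no backpointers), and its names and values are illustrative rather than the library's API.

```python
import heapq

# Toy log potentials phi[n][c_next][c_prev]: 2 edges (3 positions), C = 3 classes.
phi = [[[0.5, -1.0, 0.1], [0.2, 0.3, -0.4], [0.0, 0.9, 0.2]],
       [[1.0, 0.0, -0.3], [-0.5, 0.7, 0.4], [0.6, -0.2, 0.1]]]
C = 3


def kmax_scores(phi, k):
    """Top-k sequence scores: like Viterbi, but each cell keeps the k best prefix
    scores (sorted descending) instead of a single maximum."""
    beams = [[0.0] for _ in range(C)]          # position 0: one empty prefix per class
    for edge in phi:
        beams = [heapq.nlargest(k, (s + edge[a][b] for b in range(C) for s in beams[b]))
                 for a in range(C)]
    return heapq.nlargest(k, (s for beam in beams for s in beam))


print(kmax_scores(phi, 4))
```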

Extensions:

  • Integration with torchtext, pytorch-transformers, dgl
  • Adapters for generative structured models (CFG / HMM / HSMM)
  • Common tree-structured parameterizations: TreeLSTM / SpanLSTM

Low-level API:

Everything is implemented through semiring dynamic programming.

  • Log Marginals
  • Max and MAP computation
  • Sampling through specialized backprop
  • Entropy and first-order semirings.
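
Per the list above, each quantity comes from one dynamic program evaluated in a different semiring. The pure-Python sketch below makes this concrete on a toy linear chain, using a single forward recursion with three semirings. The log semiring gives the log partition, the max semiring gives the best score, and a first-order expectation semiring gives the entropy through H = log Z - E[score]. `Semiring`, `chain_dp` and the three instances are hypothetical names, not the library's classes. The expectation semiring is kept in real space for readability; it would overflow on large scores, which a log-space version avoids.

```python
import math
from dataclasses import dataclass
from functools import reduce
from typing import Any, Callable


@dataclass(frozen=True)
class Semiring:
    plus: Callable[[Any, Any], Any]
    times: Callable[[Any, Any], Any]
    one: Any
    lift: Callable[[float], Any]   # maps an edge log potential into the semiring


LOG = Semiring(plus=lambda x, y: max(x, y) + math.log1p(math.exp(-abs(x - y))),
               times=lambda x, y: x + y, one=0.0, lift=lambda p: p)
MAX = Semiring(plus=max, times=lambda x, y: x + y, one=0.0, lift=lambda p: p)
# First-order expectation semiring on pairs (p, r), in real space:
# p accumulates exp(score) and r accumulates exp(score) * score.
EXPECTATION = Semiring(plus=lambda x, y: (x[0] + y[0], x[1] + y[1]),
                       times=lambda x, y: (x[0] * y[0], x[0] * y[1] + x[1] * y[0]),
                       one=(1.0, 0.0),
                       lift=lambda p: (math.exp(p), math.exp(p) * p))


def chain_dp(phi, num_classes, sr):
    """One forward recursion over phi[n][c_next][c_prev], parameterized by the semiring."""
    alpha = [sr.one] * num_classes
    for edge in phi:
        alpha = [reduce(sr.plus, (sr.times(alpha[b], sr.lift(edge[a][b]))
                                  for b in range(num_classes)))
                 for a in range(num_classes)]
    return reduce(sr.plus, alpha)


phi = [[[0.3, -0.2], [0.8, 0.1]],
       [[-0.4, 0.6], [0.2, 0.0]]]
log_z = chain_dp(phi, 2, LOG)              # log partition
best = chain_dp(phi, 2, MAX)               # score of the argmax
z, r = chain_dp(phi, 2, EXPECTATION)
entropy = math.log(z) - r / z              # H = log Z - E[score]
print(log_z, best, entropy)
```

Swapping the semiring changes what is computed without touching the recursion.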

Citation

@misc{alex2020torchstruct,
    title={Torch-Struct: Deep Structured Prediction Library},
    author={Alexander M. Rush},
    year={2020},
    eprint={2002.00876},
    archivePrefix={arXiv},
    primaryClass={cs.CL}
}

This work was partially supported by NSF grant IIS-1901030.
