  • Stars: 146
  • Rank: 252,769 (Top 5%)
  • Language: Python
  • License: BSD 3-Clause
  • Created: over 4 years ago
  • Updated: over 1 year ago


Repository Details

Resources for the "CTRLsum: Towards Generic Controllable Text Summarization" paper

CTRLsum

This is the PyTorch implementation of the paper:

CTRLsum: Towards Generic Controllable Text Summarization
Junxian He, Wojciech Kryściński, Bryan McCann, Nazneen Rajani, Caiming Xiong
arXiv 2020

This repo includes instructions for using pretrained CTRLsum models as well as training new models.

CTRLsum is a generic controllable summarization system that manipulates text summaries given control tokens in the form of keywords or prefixes. CTRLsum also achieves strong summarization performance in the uncontrolled setting (e.g., state-of-the-art on CNN/DailyMail).

🎥 Demo 1: Hugging Face Spaces (interactively generate summaries with the pretrained model)

🎥 Demo 2 (navigate the CTRLsum outputs used in our experiments)

Model checkpoints

Dataset         Download
CNN/DailyMail   download (.tar.gz)
arXiv           download (.tar.gz)
BIGPATENT       download (.tar.gz)

These checkpoints are also available in Hugging Face Transformers; see details below.

Updates

April 09, 2022

@aliencaocao made a repo here that converts our pretrained taggers to ONNX, which makes them much faster to load and run inference with.

October 07, 2021

Integrated into Hugging Face Spaces with Gradio. See the demo: Hugging Face Spaces

June 18, 2021

We released another web UI demo (here) to navigate most of the CTRLsum outputs generated in the experiments of the paper.

Mar 22, 2021

Hyunwoong Ko made a Python package, summarizers, based on CTRLsum. CTRLsum is also now supported in Hugging Face Transformers, thanks to Hyunwoong Ko. With these packages, CTRLsum can be used with just a few lines of code. See an example using Hugging Face Transformers.

Dependencies

The code requires Python 3, PyTorch (>=1.4.0), and fairseq (the code is tested against this commit).

Install dependencies:

# manually install fairseq
git clone https://github.com/pytorch/fairseq

# this repo is tested on a commit of fairseq from May 2020:
# fad3cf0769843e767155f4d0af18a61b9a804f59
cd fairseq
git reset --hard fad3cf07

# the BART interface in fairseq does not support prefix-constrained decoding
# as of writing this README, so we need to make several modifications to
# fairseq before installing it
cp ../ctrlsum/fairseq_task.py fairseq/tasks/fairseq_task.py
cp ../ctrlsum/sequence_generator.py fairseq/
cp ../ctrlsum/hub_interface.py fairseq/models/bart/

# install fairseq
pip install --editable ./

cd ..

# install other requirements
pip install -r requirements.txt
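Optionally, once fairseq is installed you can sanity-check a downloaded checkpoint directly through fairseq's BART hub interface. This is a minimal sketch, assuming the extracted checkpoint directory contains a checkpoint_best.pt plus its dictionary files and that keywords are prepended to the source with the separator shown in the examples below; the scripts in this repo remain the supported entry point:

# minimal sanity check (illustrative; the file name and the keyword format are assumptions)
from fairseq.models.bart import BARTModel

ctrlsum = BARTModel.from_pretrained(
    "path/to/extracted/checkpoint_dir",    # directory unpacked from the .tar.gz above
    checkpoint_file="checkpoint_best.pt",  # assumed checkpoint file name
)
ctrlsum.eval()

# keywords are prepended to the source article with the separator token
source = "health care => <source article text>"
print(ctrlsum.sample([source], beam=5))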

Example Usage of Pretrained Models

Option 1. Generate summaries interactively; users can specify the control tokens (keywords, prompts, or both):

CUDA_VISIBLE_DEVICES=xx python scripts/generate_bart_interactive.py --exp [checkpoint directory] \
	--dataset example_dataset \
	--src test.oraclewordnssource

The command above reads source articles from datasets/example_dataset/test.oraclewordnssource. Users can then interact with the system on the command line by entering the id of the example to be shown, as well as the control tokens:

[demo GIF: an interactive CTRLsum session in the command line]

Option 2. Generate summaries from a file which includes keywords:

# the following command generates summaries from `datasets/example_dataset/test.oraclewordnssource`
# the input data format is the keywords concatenated with the source via a separator token;
# please refer to the given example data files (and the illustration after this command)
# the predicted summaries are saved into the checkpoint directory
CUDA_VISIBLE_DEVICES=xx python scripts/generate_bart.py --exp [checkpoint directory] \
	--dataset example_dataset \
	--src test.oraclewordnssource 
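Each line of such an input file is one example: the keywords, the separator token, and then the source article on a single line. As an illustration (assuming the " => " separator also used in the Hugging Face Transformers examples below; the files in datasets/example_dataset are the authoritative reference):

health care insurance => <full source article on one line>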

Option 3. Through Huggingface Transformers

Our pretrained model checkpoints are available in Hugging Face Transformers; the model names are hyunwoongko/ctrlsum-cnndm, hyunwoongko/ctrlsum-arxiv, and hyunwoongko/ctrlsum-bigpatent. An example code snippet (quoted from here):

1. Create models and tokenizers

>>> from transformers import AutoModelForSeq2SeqLM, PreTrainedTokenizerFast

>>> model = AutoModelForSeq2SeqLM.from_pretrained("hyunwoongko/ctrlsum-cnndm")
>>> # model = AutoModelForSeq2SeqLM.from_pretrained("hyunwoongko/ctrlsum-arxiv")
>>> # model = AutoModelForSeq2SeqLM.from_pretrained("hyunwoongko/ctrlsum-bigpatent")

>>> tokenizer = PreTrainedTokenizerFast.from_pretrained("hyunwoongko/ctrlsum-cnndm")
>>> # tokenizer = PreTrainedTokenizerFast.from_pretrained("hyunwoongko/ctrlsum-arxiv")
>>> # tokenizer = PreTrainedTokenizerFast.from_pretrained("hyunwoongko/ctrlsum-bigpatent")

2. Unconditioned summarization

>>> data = tokenizer("My name is Kevin. I love dogs. I loved dogs from 1996. Today, I'm going to walk on street with my dogs", return_tensors="pt")
>>> input_ids, attention_mask = data["input_ids"], data["attention_mask"]
>>> tokenizer.batch_decode(model.generate(input_ids, attention_mask=attention_mask, num_beams=5))[0]
'</s>My name is Kevin. I loved dogs from 1996.</s>'

3. Conditioned summarization

  • You can input condition tokens using the TOKEN => CONTENTS structure
>>> data = tokenizer("today plan => My name is Kevin. I love dogs. I loved dogs from 1996. Today, I'm going to walk on street with my dogs", return_tensors="pt")
>>> input_ids, attention_mask = data["input_ids"], data["attention_mask"]
>>> tokenizer.batch_decode(model.generate(input_ids, attention_mask=attention_mask, num_beams=5))[0]
"</s> Today, I'm going to walk on street with my dogs. I loved dogs from 1996</s>"

4. Prompt summarization

  • You can also pass decoder_input_ids to provide an input prompt.
>>> data = tokenizer("Q:What is my name? A: => My name is Kevin. I love dogs. I loved dogs from 1996. Today, I'm going to walk on street with my dogs", return_tensors="pt")
>>> input_ids, attention_mask = data["input_ids"], data["attention_mask"]
>>> tokenizer.batch_decode(model.generate(input_ids, attention_mask=attention_mask, num_beams=5, decoder_input_ids=tokenizer("Q:What is My name? A:", return_tensors="pt")["input_ids"][:, :-1]))[0]
'<s>Q:What is My name? A: Kevin.</s>'
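As noted in Option 1, keywords and prompts can also be combined. The following is an illustrative sketch that reuses the same calls as the quoted snippet above (it is not part of the quoted snippet, and the output depends on the model):

>>> # keywords before " => " in the encoder input, plus a decoder prompt via decoder_input_ids
>>> data = tokenizer("name => My name is Kevin. I love dogs. I loved dogs from 1996. Today, I'm going to walk on street with my dogs", return_tensors="pt")
>>> prompt_ids = tokenizer("Q:What is my name? A:", return_tensors="pt")["input_ids"][:, :-1]
>>> tokenizer.batch_decode(model.generate(data["input_ids"], attention_mask=data["attention_mask"], num_beams=5, decoder_input_ids=prompt_ids))[0]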

Option 4. Through the Summarizers Python Package

The Python package summarizers lets you use pretrained CTRLsum with a few lines of code, as in the sketch below.
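A minimal sketch of the intended usage, based on the summarizers README at the time of writing (the Summarizers class and the query/prompt arguments are assumptions taken from that README; consult the summarizers repository for the authoritative API):

# illustrative only; interface assumed from the summarizers package README
from summarizers import Summarizers

summ = Summarizers()  # loads a CTRLsum checkpoint under the hood

article = "My name is Kevin. I love dogs. I loved dogs from 1996."
print(summ(article))                                   # uncontrolled summary
print(summ(article, query="dogs"))                     # keyword-controlled summary
print(summ(article, prompt="Q: What is my name? A:"))  # prompt-controlled summary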

Train CTRLsum

Data Processing

Prepare your data files in datasets/[dataset name], which should contain six files named [train/val/test].[source/target]. These files are raw text with one example per row. We use the cnndm dataset as an example for preprocessing (see here for instructions on obtaining the cnndm dataset):

# this command runs the preprocessing pipeline, including tokenization, truncation, and
# keyword extraction. It will generate all data files required to train CTRLsum into
# `datasets/cnndm`. Examples of the resulting files can be found in `datasets/example_dataset`.
# Some optional arguments can be found in preprocess.py
python scripts/preprocess.py cnndm --mode pipeline

# gpt2 encoding
bash scripts/gpt2_encode.sh cnndm

# binarize dataset for fairseq
bash scripts/binarize_dataset.sh cnndm

Among the generated files in datasets/cnndm, the suffix oracleword denotes the keywords file (after keyword dropout), oraclewordsource denotes the concatenated keywords and source, and oraclewordns denotes the original keywords without keyword dropout. The .jsonl files are potentially used later to train the tagger.

Train the summarization model on multiple GPUs:

bash scripts/train_bart.sh -g [GPUs] -d [dataset name] -b [bart checkpoint path (.pt file)]

GPUs are GPU ids separated by ,. All our experiments run on 8 GPUs with 8 gradient accumulation steps, resulting in an effective batch size of 1024x8x8 = 65,536 tokens in total. You probably need to increase the update_freq variable in train_bart.sh if you use fewer GPUs, so that this effective batch size is matched (e.g., with 4 GPUs, doubling update_freq to 16 keeps GPUs x update_freq constant). The saved models are in the checkpoint directory. The training arguments can be found in train_bart.sh.
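For example, an 8-GPU run on cnndm might be invoked as follows (illustrative; the BART checkpoint path is a placeholder for wherever you extracted fairseq's pretrained BART archive):

# illustrative invocation; adjust GPU ids and paths to your setup
bash scripts/train_bart.sh -g 0,1,2,3,4,5,6,7 -d cnndm -b /path/to/bart.large/model.pt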

Train the keyword tagger (optional):

Note that the keyword tagger is required only in the uncontrolled summarization setting and in certain control settings that require automatic keywords (like length control in the paper).

# this requires 4 GPUs for training by default;
# you need to change the --nproc_per_node value if you
# train with a different number of GPUs
bash scripts/train_seqlabel.sh -g [GPUs] -d [dataset name]

The effective batch size we used for different datasets can be found in the training script as number of GPUs x batch x update_freq.

Evaluate CTRLsum

Here we include evaluation for uncontrolled summarization settings.

Obtain automatic keywords from a trained tagger:

# run prediction with the tagger, which outputs a confidence value for every token
# `checkpoint directory` is the directory that contains the `pytorch_model.bin` checkpoint
# the results are saved in the checkpoint directory as test_predictions.txt
bash scripts/train_seqlabel.sh -g [GPUs] -d [dataset name] -p [checkpoint directory]


# obtain keywords by selecting confident words. `threshold`, `maximum-word`, and `summary-size`
# are the three hyperparameters in this step; please check Appendix A in the paper for the specific
# values we used for different datasets (the performance is relatively robust to them)
# this command yields a file `.predwordsource` in `datasets/[dataset name]`, which can be
# used as input to the summarization model to obtain uncontrolled summaries
python scripts/preprocess.py [dataset name] \
		--split test \
		--mode process_tagger_prediction \
		--tag-pred [the tagger prediction file path, named as test_predictions.txt] \
		--threshold [confidence threshold] \
		--maximum-word [maximum number of keywords] \
		--summary-size [number of sentences from which to identify keywords]
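For instance, a filled-in call might look like the following (the numeric values are illustrative placeholders, not the settings from Appendix A):

python scripts/preprocess.py cnndm \
		--split test \
		--mode process_tagger_prediction \
		--tag-pred [tagger checkpoint directory]/test_predictions.txt \
		--threshold 0.25 \
		--maximum-word 30 \
		--summary-size 10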

Metrics:

We report ROUGE scores and BERTScore in the paper. The ROUGE scores in the paper are computed using files2rouge, which is a wrapper of a wrapper of the original ROUGE Perl scripts. Please refer to scripts/test_bart.sh for our evaluation script:

# you will need the Stanford CoreNLP Java toolkit to run this; we use it for tokenization
# this script computes ROUGE and (optionally) BERTScore.
bash scripts/test_bart.sh -g [GPUs] -s [source file name, NOT full path] -d [dataset] -p [ctrlsum checkpoint directory]
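For example, to evaluate uncontrolled summaries generated from the automatic keywords above (illustrative; the .predwordsource file name comes from the keyword-extraction step, and the paths are placeholders):

# illustrative invocation; adjust GPU ids and paths to your setup
bash scripts/test_bart.sh -g 0 -s test.predwordsource -d cnndm -p /path/to/ctrlsum_cnndm_checkpoint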

Citation

@article{he2020ctrlsum,
  title={CTRLsum: Towards Generic Controllable Text Summarization},
  author={He, Junxian and Kry{\'s}ci{\'n}ski, Wojciech and McCann, Bryan and Rajani, Nazneen and Xiong, Caiming},
  journal={arXiv},
  year={2020}
}

More Repositories

1. LAVIS (Jupyter Notebook, 9,587 stars): LAVIS - A One-stop Library for Language-Vision Intelligence
2. CodeGen (Python, 4,594 stars): CodeGen is a family of open-source model for program synthesis. Trained on TPU-v4. Competitive with OpenAI Codex.
3. BLIP (Jupyter Notebook, 3,879 stars): PyTorch code for BLIP: Bootstrapping Language-Image Pre-training for Unified Vision-Language Understanding and Generation
4. akita (TypeScript, 3,442 stars): 🚀 State Management Tailored-Made for JS Applications
5. Merlion (Python, 3,355 stars): Merlion: A Machine Learning Framework for Time Series Intelligence
6. ja3 (Python, 2,666 stars): JA3 is a standard for creating SSL client fingerprints in an easy to produce and shareable way.
7. CodeT5 (Python, 2,437 stars): Home of CodeT5: Open Code LLMs for Code Understanding and Generation
8. decaNLP (Python, 2,301 stars): The Natural Language Decathlon: A Multitask Challenge for NLP
9. TransmogrifAI (Scala, 2,234 stars): TransmogrifAI (pronounced trăns-mŏgˈrə-fī) is an AutoML library for building modular, reusable, strongly typed machine learning workflows on Apache Spark with minimal hand-tuning
10. policy_sentry (Python, 1,986 stars): IAM Least Privilege Policy Generator
11. cloudsplaining (JavaScript, 1,972 stars): Cloudsplaining is an AWS IAM Security Assessment tool that identifies violations of least privilege and generates a risk-prioritized report.
12. awd-lstm-lm (Python, 1,900 stars): LSTM and QRNN Language Model Toolkit for PyTorch
13. ctrl (Python, 1,766 stars): Conditional Transformer Language Model for Controllable Generation
14. lwc (JavaScript, 1,619 stars): ⚡️ LWC - A Blazing Fast, Enterprise-Grade Web Components Foundation
15. WikiSQL (HTML, 1,606 stars): A large annotated semantic parsing corpus for developing natural language interfaces.
16. sloop (Go, 1,457 stars): Kubernetes History Visualization
17. CodeTF (Python, 1,375 stars): CodeTF: One-stop Transformer Library for State-of-the-art Code LLM
18. ALBEF (Python, 1,276 stars): Code for ALBEF: a new vision-language pre-training method
19. pytorch-qrnn (Python, 1,255 stars): PyTorch implementation of the Quasi-Recurrent Neural Network - up to 16 times faster than NVIDIA's cuDNN LSTM
20. ai-economist (Python, 964 stars): Foundation is a flexible, modular, and composable framework to model socio-economic behaviors and dynamics with both agents and governments. This framework can be used in conjunction with reinforcement learning to learn optimal economic policies, as done by the AI Economist (https://www.einstein.ai/the-ai-economist).
21. design-system-react (JavaScript, 919 stars): Salesforce Lightning Design System for React
22. jarm (Python, 914 stars)
23. tough-cookie (TypeScript, 858 stars): RFC6265 Cookies and CookieJar for Node.js
24. OmniXAI (Jupyter Notebook, 853 stars): OmniXAI: A Library for eXplainable AI
25. reactive-grpc (Java, 826 stars): Reactive stubs for gRPC
26. xgen (Python, 717 stars): Salesforce open-source LLMs with 8k sequence length.
27. UniControl (Python, 614 stars): Unified Controllable Visual Generation Model
28. vulnreport (HTML, 593 stars): Open-source pentesting management and automation platform by Salesforce Product Security
29. hassh (Python, 529 stars): HASSH is a network fingerprinting standard which can be used to identify specific Client and Server SSH implementations. The fingerprints can be easily stored, searched and shared in the form of a small MD5 fingerprint.
30. progen (Python, 518 stars): Official release of the ProGen models
31. base-components-recipes (JavaScript, 509 stars): A collection of base component recipes for Lightning Web Components on Salesforce Platform
32. Argus (Java, 501 stars): Time series monitoring and alerting platform.
33. CodeRL (Python, 488 stars): This is the official code for the paper CodeRL: Mastering Code Generation through Pretrained Models and Deep Reinforcement Learning (NeurIPS22).
34. matchbox (Python, 488 stars): Write PyTorch code at the level of individual examples, then run it efficiently on minibatches.
35. PCL (Python, 483 stars): PyTorch code for "Prototypical Contrastive Learning of Unsupervised Representations"
36. DialogStudio (Python, 472 stars): DialogStudio: Towards Richest and Most Diverse Unified Dataset Collection and Instruction-Aware Models for Conversational AI
37. cove (Python, 470 stars)
38. warp-drive (Python, 452 stars): Extremely Fast End-to-End Deep Multi-Agent Reinforcement Learning Framework on a GPU (JMLR 2022)
39. PyRCA (Python, 408 stars): PyRCA: A Python Machine Learning Library for Root Cause Analysis
40. observable-membrane (TypeScript, 368 stars): A Javascript Membrane implementation using Proxies to observe mutation on an object graph
41. DeepTime (Python, 351 stars): PyTorch code for Learning Deep Time-index Models for Time Series Forecasting (ICML 2023)
42. ULIP (Python, 316 stars)
43. MultiHopKG (Jupyter Notebook, 300 stars): Multi-hop knowledge graph reasoning learned via policy gradient with reward shaping and action dropout
44. logai (Python, 298 stars): LogAI - An open-source library for log analytics and intelligence
45. CodeGen2 (Python, 272 stars): CodeGen2 models for program synthesis
46. provis (Python, 269 stars): Official code repository of "BERTology Meets Biology: Interpreting Attention in Protein Language Models."
47. causalai (Jupyter Notebook, 256 stars): Salesforce CausalAI Library: A Fast and Scalable framework for Causal Analysis of Time Series and Tabular Data
48. jaxformer (Python, 255 stars): Minimal library to train LLMs on TPU in JAX with pjit().
49. EDICT (Jupyter Notebook, 247 stars)
50. rules_spring (Starlark, 224 stars): Bazel rule for building Spring Boot apps as a deployable jar
51. ETSformer (Python, 221 stars): PyTorch code for ETSformer: Exponential Smoothing Transformers for Time-series Forecasting
52. TabularSemanticParsing (Jupyter Notebook, 220 stars): Translating natural language questions to a structured query language
53. themify (TypeScript, 216 stars): 👨‍🎨 CSS Themes Made Easy. A robust, opinionated solution to manage themes in your web application
54. simpletod (Python, 212 stars): Official repository for "SimpleTOD: A Simple Language Model for Task-Oriented Dialogue"
55. grpc-java-contrib (Java, 208 stars): Useful extensions for the grpc-java library
56. GeDi (Python, 207 stars): GeDi: Generative Discriminator Guided Sequence Generation
57. aws-allowlister (Python, 207 stars): Automatically compile an AWS Service Control Policy that ONLY allows AWS services that are compliant with your preferred compliance frameworks.
58. generic-sidecar-injector (Go, 203 stars): A generic framework for injecting sidecars and related configuration in Kubernetes using Mutating Webhook Admission Controllers
59. mirus (Java, 201 stars): Mirus is a cross data-center data replication tool for Apache Kafka
60. CoST (Python, 196 stars): PyTorch code for CoST: Contrastive Learning of Disentangled Seasonal-Trend Representations for Time Series Forecasting (ICLR 2022)
61. factCC (Python, 192 stars): Resources for the "Evaluating the Factual Consistency of Abstractive Text Summarization" paper
62. runway-browser (JavaScript, 188 stars): Interactive visualization framework for Runway models of distributed systems
63. glad (Python, 186 stars): Global-Locally Self-Attentive Dialogue State Tracker
64. cloud-guardrails (HCL, 181 stars): Rapidly apply hundreds of security controls in Azure
65. ALPRO (Python, 177 stars): Align and Prompt: Video-and-Language Pre-training with Entity Prompts
66. densecap (Jupyter Notebook, 176 stars)
67. kafka-junit (Java, 167 stars): This library wraps Kafka's embedded test cluster, allowing you to more easily create and run integration tests using JUnit against a "real" kafka server running within the context of your tests. No need to stand up an external kafka cluster!
68. booksum (Python, 167 stars)
69. sfdx-lwc-jest (JavaScript, 162 stars): Run Jest against LWC components in SFDX workspace environment
70. hierarchicalContrastiveLearning (Python, 149 stars)
71. cos-e (Python, 144 stars): Commonsense Explanations Dataset and Code
72. secure-filters (JavaScript, 138 stars): Anti-XSS Security Filters for EJS and More
73. metabadger (Python, 129 stars): Prevent SSRF attacks on AWS EC2 via automated upgrades to the more secure Instance Metadata Service v2 (IMDSv2).
74. dockerfile-image-update (Java, 127 stars): A tool that helps you get security patches for Docker images into production as quickly as possible without breaking things
75. Converse (Python, 125 stars)
76. refocus (JavaScript, 125 stars): The Go-To Platform for Visualizing Service Health
77. CoMatch (Python, 117 stars): Code for CoMatch: Semi-supervised Learning with Contrastive Graph Regularization
78. BOLAA (Python, 114 stars)
79. fsnet (Python, 111 stars)
80. rng-kbqa (Python, 110 stars)
81. near-membrane (TypeScript, 110 stars): JavaScript Near Membrane Library that powers Lightning Locker Service
82. botsim (Jupyter Notebook, 108 stars): BotSIM - a data-efficient end-to-end Bot SIMulation toolkit for evaluation, diagnosis, and improvement of commercial chatbots
83. bazel-eclipse (Java, 108 stars): This repo holds two IDE projects. One is the Eclipse Feature for developing Bazel projects in Eclipse. The Bazel Eclipse Feature supports importing, building, and testing Java projects that are built using the Bazel build system. The other is the Bazel Java Language Server, which is a build integration for IDEs such as VS Code.
84. MUST (Python, 103 stars): PyTorch code for MUST
85. bro-sysmon (Zeek, 100 stars): How to Zeek Sysmon Logs!
86. Timbermill (Java, 99 stars): A better logging service
87. AuditNLG (Python, 97 stars): AuditNLG: Auditing Generative AI Language Modeling for Trustworthiness
88. eslint-plugin-lwc (JavaScript, 96 stars): Official ESLint rules for LWC
89. best (TypeScript, 95 stars): 🏆 Delightful Benchmarking & Performance Testing
90. craft (Go, 93 stars): CRAFT removes the language barrier to create Kubernetes Operators.
91. eslint-config-lwc (JavaScript, 93 stars): Opinionated ESLint configurations for LWC projects
92. online_conformal (Jupyter Notebook, 90 stars): Methods for online conformal prediction.
93. lobster-pot (Go, 88 stars): Scans every git push to your Github organisations to find unwanted secrets.
94. ml4ir (Jupyter Notebook, 85 stars): Machine Learning for Information Retrieval
95. violet-conversations (JavaScript, 84 stars): Sophisticated Conversational Applications/Bots
96. apex-mockery (Apex, 83 stars): Lightweight mocking library in Apex
97. fast-influence-functions (Python, 83 stars)
98. MoPro (Python, 79 stars): MoPro: Webly Supervised Learning
99. TaiChi (Python, 79 stars): Open source library for few shot NLP
100. helm-starter-istio (Shell, 78 stars): An Istio starter template for Helm