  • Language
    Python
  • License
    Apache License 2.0
  • Created over 2 years ago
  • Updated 9 months ago

🤖 RL4LMs 🚀

A modular RL library to fine-tune language models to human preferences


We provide easily customizable building blocks for training language models, including implementations of on-policy algorithms, reward functions, metrics, datasets, and LM-based actor-critic policies.

Paper Link: https://arxiv.org/abs/2210.01241

Website Link: https://rl4lms.apps.allenai.org/

Thoroughly tested and benchmarked with over 2000 experiments 🔥 (GRUE benchmark 🏆) on a comprehensive set of:

  • 7 different Natural Language Processing (NLP) Tasks:
    • Summarization
    • Generative Commonsense Reasoning
    • IMDB Sentiment-based Text Continuation
    • Table-to-text generation
    • Abstractive Question Answering
    • Machine Translation
    • Dialogue Generation
  • More than 20 NLG metrics of different types, which can be used as reward functions:
    • Lexical metrics (e.g., ROUGE, BLEU, SacreBLEU, METEOR)
    • Semantic metrics (e.g., BERTScore, BLEURT)
    • Task-specific metrics (e.g., PARENT, CIDEr, SPICE)
    • Scores from pre-trained classifiers (e.g., sentiment scores)
  • On-policy algorithms: PPO, A2C, TRPO, and the novel NLPO (Natural Language Policy Optimization)
  • Actor-critic policies supporting causal LMs (e.g., GPT-2/3) and seq2seq LMs (e.g., T5, BART)
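As we understand the paper, NLPO's key idea is to restrict the action space with top-p masking: a masking policy (a periodically updated copy of the current policy) keeps only the smallest set of most-probable tokens whose cumulative probability reaches p. The sketch below computes such a mask over a plain list of token probabilities; it is an illustration of top-p masking in standalone Python, not RL4LMs' implementation, and `top_p_mask` is a hypothetical name.

```python
from typing import List


def top_p_mask(probs: List[float], p: float) -> List[bool]:
    """Keep the smallest set of highest-probability tokens whose cumulative
    probability reaches p (True = allowed action, False = masked out)."""
    # Token indices ordered from most to least probable
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    mask = [False] * len(probs)
    cumulative = 0.0
    for i in order:
        mask[i] = True
        cumulative += probs[i]
        if cumulative >= p:
            break  # enough probability mass is covered
    return mask


print(top_p_mask([0.1, 0.5, 0.3, 0.1], p=0.7))  # [False, True, True, False]
```

Masked-out tokens would then have their logits set to a very large negative value before sampling, so the policy can only choose among the retained tokens.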

All of these building blocks are customizable, allowing users to train transformer-based LMs to optimize any arbitrary reward function on any dataset of their choice.

Recent updates (v0.2.0) on 23-Nov-22

  • Added the DailyDialog task
  • Fixed compatibility issues with some seq2seq models such as BART, BlenderBot, etc.
  • Implemented data parallel support
  • Refactored policy classes

Recent updates (v0.2.1)

  • Minor logging updates

Install

Local Installation

git clone https://github.com/allenai/RL4LMs.git
cd RL4LMs
pip install -e .

Docker

We also provide a Dockerfile for development inside Docker containers, with all dependencies included.

docker build . -t rl4lms

Additional dependencies

CoreNLP libraries are required only for certain metric computations (e.g., SPICE). They can be downloaded with:

cd rl4lms/envs/text_generation/caption_metrics/spice && bash get_stanford_models.sh


Quick Start - Train PPO/NLPO using pre-defined YAML configs

We provide a simple training API, invoked via a train script, that can train PPO, NLPO, or a supervised model from a YAML config file.

For example, to train T5-base on CNN/DM summarization with PPO, using ROUGE-1 as the reward function, run:

python scripts/training/train_text_generation.py --config_path scripts/training/task_configs/summarization/t5_ppo.yml

Config files for all tasks can be found in scripts/training/task_configs.

YAML file schema - Configuring building blocks

The config file contains hyperparameter settings for each of the building blocks described below:

  • Dataset/Task: A dataset containing samples with input prompts and reference sentences. Available datasets are listed in the DataPoolRegistry class in the registry (see Adding dataset below to create your own).

    datapool:
      id: cnn_daily_mail
      args:
        prompt_prefix: "Summarize: "
  • Tokenizer: A pre-trained tokenizer used to (de)tokenize input and output sequences, with settings for padding and truncation.

    tokenizer:
      model_name: t5-base
      padding_side: left
      truncation_side: left
      pad_token_as_eos_token: False
  • Reward Function: A reward function that computes token-level scores at each time step of the MDP. Available reward functions are listed in the RewardFunctionRegistry class (see Adding reward function below to create your own).

    reward_fn:
      id: rouge
      args:
        rouge_type: "rouge1"
  • Environment: Configures a gym-style text generation environment that simulates MDP episodes. Rollouts are generated from training samples in the dataset, each consisting of input and reference texts. The env is wrapped with SubprocVecEnv from stable-baselines3, which runs n_envs episodes in parallel using multiprocessing to compute step-wise rewards.
    Further configuration settings include:

    • max_episode_length: maximum length of an episode
    • max_prompt_length: maximum length of the input text to consider
    • terminate_on_eos: whether to terminate the episode as soon as the EOS action is taken
    • prompt_truncation_side: truncation side for the prompt text
    • context_start_token: ID of the context token (the initial token given to the decoder in encoder-decoder models)
    env:
      n_envs: 10
      args:
        max_prompt_length: 512
        max_episode_length: 100
        terminate_on_eos: True
        prompt_truncation_side: "right"
        context_start_token: 0
  • On-policy alg: We provide implementations of 4 on-policy algorithms (PPO, NLPO, A2C, and TRPO), adapted from stable-baselines3 and tailored to NLP tasks. They can be used out of the box with either a causal LM policy or a seq2seq LM policy (see Adding custom on-policy algorithms and Adding custom policies below to create your own).

    • We also provide a supervised trainer for benchmarking purposes. Supervised warm-start models are already uploaded to the Hugging Face Hub and specified in the respective config files.

    • Hyperparameters for the algorithm can be specified under alg/args.

    • All RL algorithms use an adaptive KL controller to keep the LM close to the original LM; it is configured by the initial KL coefficient (alg/kl_div/coeff) and the target KL (alg/kl_div/target_kl).

    • We support two types of LM policy: a causal LM policy (for decoder-only models) and a seq2seq LM policy (for encoder-decoder models). For NLPO, we also provide maskable variants of these. A policy is attached to an algorithm by specifying alg/policy/id and alg/policy/args.

      alg:
        id: ppo
        args: 
          n_steps: 512
          batch_size: 64
          verbose: 1
          learning_rate: 0.000002
          n_epochs: 5
          ent_coef: 0.0
        kl_div:
          coeff: 0.001
          target_kl: 0.2
        policy:
          id: seq2seq_lm_actor_critic_policy
          args:
            model_name: t5-base
            apply_model_parallel: True
            prompt_truncation_side: "right"
            generation_kwargs:
              do_sample: True
              top_k: 50
              min_length: 50
              max_new_tokens: 100          
  • Trainer Config: We provide an on-policy trainer, a feature-complete wrapper that instantiates the building blocks from their corresponding configs and runs an outer training loop of train and eval iterations (the number of iterations is set by train_evaluation/n_iters).

    • Each iteration collects alg/args/n_steps × env/n_envs rollout steps and performs updates on them with the chosen algorithm.
    • Every eval_every iterations, the LM is evaluated on the validation split using the metrics listed in train_evaluation/metrics, with the generation kwargs given in train_evaluation/generation_kwargs (these override the rollout kwargs in alg/policy/generation_kwargs, for inference only).
    # train and evaluation
    train_evaluation:
      eval_batch_size: 100
      n_iters: 100
      eval_every: 10
      save_every: 1
      metrics:
        - id: meteor
          args: {}
        - id: rouge
        - id: bleu
          args: {}
        - id: bert_score
          args:
            language: en
        - id: diversity
          args: {}
      generation_kwargs: 
        do_sample: True
        top_k: 0
        temperature: 0.7
        min_length: 50
        max_new_tokens: 100
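
As a back-of-the-envelope check of the trainer schedule described above, the following sketch derives per-iteration and total rollout steps from the example YAML values (written as a plain dict here to avoid a YAML dependency). The update count assumes stable-baselines3 PPO-style minibatching (n_epochs passes over the rollout buffer in batch_size chunks), and the evaluation schedule assumes 1-based iteration numbering; neither detail is taken from RL4LMs' trainer code, and schedule_summary is a hypothetical helper.

```python
import math
from typing import Any, Dict


def schedule_summary(config: Dict[str, Any]) -> Dict[str, Any]:
    """Derive rollout/update/eval counts from a config dict (sketch)."""
    alg_args = config["alg"]["args"]
    n_envs = config["env"]["n_envs"]
    train_eval = config["train_evaluation"]

    # One trainer iteration = n_steps rollout steps in each of n_envs envs
    steps_per_iter = alg_args["n_steps"] * n_envs
    # Assumed PPO-style updates: n_epochs passes over the rollout buffer,
    # each pass split into minibatches of batch_size samples
    minibatches = math.ceil(steps_per_iter / alg_args["batch_size"])
    updates_per_iter = minibatches * alg_args["n_epochs"]
    # Assumed schedule: evaluate after every eval_every-th iteration (1-based)
    eval_iters = [i for i in range(1, train_eval["n_iters"] + 1)
                  if i % train_eval["eval_every"] == 0]
    return {
        "steps_per_iter": steps_per_iter,
        "total_steps": steps_per_iter * train_eval["n_iters"],
        "updates_per_iter": updates_per_iter,
        "eval_iters": eval_iters,
    }


# Values copied from the YAML examples above
config = {
    "env": {"n_envs": 10},
    "alg": {"args": {"n_steps": 512, "batch_size": 64, "n_epochs": 5}},
    "train_evaluation": {"n_iters": 100, "eval_every": 10},
}
summary = schedule_summary(config)
print(summary["steps_per_iter"], summary["total_steps"], summary["updates_per_iter"])
# 5120 512000 400
```

With these settings, each iteration generates 5,120 token-level steps, and the whole run generates 512,000 steps across 100 iterations.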

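The adaptive KL controller mentioned under On-policy alg can be sketched as a proportional controller in the style of Ziegler et al. (2019). The KL coefficient grows when the observed KL exceeds target_kl and shrinks otherwise, and it scales a per-token KL penalty subtracted from the task reward. The clip bound and gain below are assumed values, and AdaptiveKLController and kl_penalized_reward are hypothetical names, not RL4LMs' API.

```python
from dataclasses import dataclass


@dataclass
class AdaptiveKLController:
    """Proportional KL controller (sketch, not RL4LMs' implementation).

    coeff:     current KL coefficient, initialised like alg/kl_div/coeff
    target_kl: desired KL to the reference LM, like alg/kl_div/target_kl
    clip:      bound on the relative error used per update (assumed)
    gain:      step size of the multiplicative update (assumed)
    """
    coeff: float
    target_kl: float
    clip: float = 0.2
    gain: float = 0.1

    def step(self, observed_kl: float) -> float:
        # Relative error between observed and target KL, clipped to [-clip, clip]
        error = (observed_kl - self.target_kl) / self.target_kl
        error = max(-self.clip, min(self.clip, error))
        # KL above target -> coefficient grows; KL below target -> it shrinks
        self.coeff *= 1.0 + self.gain * error
        return self.coeff


def kl_penalized_reward(task_reward: float, logprob_policy: float,
                        logprob_ref: float, coeff: float) -> float:
    # Per-token KL estimate log pi(a|s) - log pi_ref(a|s), scaled by coeff
    return task_reward - coeff * (logprob_policy - logprob_ref)


ctrl = AdaptiveKLController(coeff=0.001, target_kl=0.2)
new_coeff = ctrl.step(0.5)  # KL above target, so the coefficient rises
```

Clipping the relative error keeps a single noisy KL estimate from changing the coefficient drastically in one update.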
Custom Building Blocks 🔧

RL4LMs is fully customizable: you can add new tasks/datasets, reward functions, evaluation metrics, on-policy algorithms, and actor-critic policies.

Adding dataset

Users can create their own datasets by subclassing TextGenPool and overriding the class method prepare(cls, split: str, **args) -> 'TextGenPool' so that it returns an instance of TextGenPool. An example is shown below:

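from rl4lms.data_pools.text_generation_pool import Sample, TextGenPool


class MyDataPool(TextGenPool):
    @classmethod
    def prepare(cls, split: str, **args) -> "TextGenPool":
        # Placeholder: load the raw records for this split from your own source
        items = ...
        samples = []
        for ix, item in enumerate(items):
            # One Sample per record: an input prompt plus its reference text(s)
            sample = Sample(id=f"{split}_{ix}",
                            prompt_or_input_text=item["document"],
                            references=[item["target"]])
            samples.append(sample)
        # Build the pool from the list of samples
        pool_instance = cls(samples)
        return pool_instance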
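
The loading step inside prepare is left to the user. One possible helper, assuming (hypothetically) that your data is a JSONL file whose records carry "split", "document", and "target" keys, filters the records for the requested split. It accepts any iterable of lines, so it works with an open file handle or an in-memory list; load_jsonl_split and the record schema are illustrative assumptions.

```python
import json
from typing import Any, Dict, Iterable, List


def load_jsonl_split(lines: Iterable[str], split: str) -> List[Dict[str, Any]]:
    """Return the records of one split from JSONL lines (hypothetical schema).

    Each non-empty line is assumed to be a JSON object with "split",
    "document" and "target" keys; adapt the keys to your own data.
    """
    records = []
    for line in lines:
        line = line.strip()
        if not line:
            continue  # tolerate blank lines, e.g. a trailing newline
        record = json.loads(line)
        if record.get("split") == split:
            records.append(record)
    return records
```

Inside prepare, the items could then come from `with open(path, encoding="utf-8") as f: items = load_jsonl_split(f, split)`, where path points to your own file.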

Adding reward function

Custom reward functions can be implemented by subclassing RewardFunction (a callable), which takes the observation ($s$), the action ($a$), the next observation ($s'$), done (whether the episode is finished), and meta info (other information about the textual input). Here, Observation is a data class object containing the generated text (at a particular step), the prompt text, the context text (at that step), and the reference texts, which can be used to compute token-level or sentence-level rewards.

from typing import Any, Dict

from rl4lms.envs.text_generation.observation import Observation
from rl4lms.envs.text_generation.reward import RewardFunction


class MyRewardFunction(RewardFunction):
    def __init__(self, *args) -> None:
        super().__init__()

    def __call__(self, prev_observation: Observation,
                 action: int,
                 current_observation: Observation,
                 done: bool,
                 meta_info: Dict[str, Any] = None) -> float:
        # Sentence-level reward: only score the text once the episode ends
        if done:
            reward = ...  # placeholder: compute your reward from current_observation
            return reward
        # No intermediate reward at non-terminal steps
        return 0.0
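
As an illustration of what the `if done:` branch might compute, here is a self-contained sentence-level reward: the best unigram F1 between the generated text and any reference. It is a simple stand-in for lexical metrics such as ROUGE-1, not the library's own ROUGE reward; unigram_f1 is a hypothetical helper, and you would pass it the generated and reference texts from the Observation (check the Observation class for the exact attribute names).

```python
from collections import Counter
from typing import List


def unigram_f1(generated: str, references: List[str]) -> float:
    """Best unigram F1 between the generated text and any reference.

    Lowercases and splits on whitespace; no stemming or punctuation handling.
    """
    gen_tokens = generated.lower().split()
    best = 0.0
    for ref in references:
        ref_tokens = ref.lower().split()
        # Clipped overlap: each token counts at most min(gen, ref) times
        overlap = sum((Counter(gen_tokens) & Counter(ref_tokens)).values())
        if overlap == 0:
            continue  # also avoids division by zero for empty texts
        precision = overlap / len(gen_tokens)
        recall = overlap / len(ref_tokens)
        best = max(best, 2 * precision * recall / (precision + recall))
    return best


print(unigram_f1("the cat sat", ["the cat sat on the mat"]))  # precision 1.0, recall 0.5
```

Taking the maximum over references is one common convention for multi-reference scoring; averaging is another.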

💡 In addition to traditional NLG metrics, for quick prototyping we provide two synthetic reward functions, which train LMs to generate numbers in increasing order and to generate dates. These can be used to quickly test different algorithms and policies. Corresponding configs are provided for both tasks (numbers, dates).
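
To make the synthetic "increasing numbers" task concrete, one plausible scoring rule is the fraction of adjacent number pairs in the generated text that are strictly increasing. This is an assumed formulation for illustration, not necessarily the reward RL4LMs implements, and increasing_numbers_reward is a hypothetical name.

```python
import re


def increasing_numbers_reward(text: str) -> float:
    """Fraction of adjacent number pairs in `text` that strictly increase."""
    # Extract non-negative integers in order of appearance
    numbers = [int(tok) for tok in re.findall(r"\d+", text)]
    if len(numbers) < 2:
        return 0.0  # no pairs to compare
    increasing = sum(1 for a, b in zip(numbers, numbers[1:]) if b > a)
    return increasing / (len(numbers) - 1)


print(increasing_numbers_reward("1 2 3 4"))  # 1.0
print(increasing_numbers_reward("3 1 2"))    # 0.5
```

Because the reward is cheap to compute and has an obvious optimum, it makes a quick sanity check that an algorithm and policy are learning at all.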

Adding custom metrics

Users can create their own evaluation metrics, which are then used to periodically evaluate the model on the validation split of the dataset. To do so, subclass BaseMetric, whose compute method takes prompt texts, generated texts, reference texts, meta_infos, the current LM model, and the split name as inputs, and returns a dict mapping each metric name to a tuple of (sentence-level scores, corpus-level score). An example is as follows:

from typing import Any, Dict, List

from transformers import PreTrainedModel

from rl4lms.envs.text_generation.metric import BaseMetric


class MyMetric(BaseMetric):
    def __init__(self) -> None:
        super().__init__()

    def compute(self,
                prompt_texts: List[str],
                generated_texts: List[str],
                reference_texts: List[List[str]],
                meta_infos: List[Dict[str, Any]] = None,
                model: PreTrainedModel = None,
                split_name: str = None):
        # Maps metric name -> (sentence-level scores, corpus-level score)
        metric_dict = {
            "custom_metrics/my_metric": ([0.4, 0.7, 0.9], 0.7)
        }
        return metric_dict
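
The body of compute usually reduces to turning per-sample scores into that (sentence-level, corpus-level) tuple. Below is a minimal, self-contained sketch of that logic for a hypothetical exact-match metric. It uses the mean as the corpus-level aggregate, which is an assumption: metrics such as BLEU aggregate differently. The function name and metric key are illustrative.

```python
from typing import Dict, List, Tuple


def exact_match_scores(generated_texts: List[str],
                       reference_texts: List[List[str]]
                       ) -> Dict[str, Tuple[List[float], float]]:
    # Sentence-level: 1.0 if the generation matches any of its references
    per_sample = [
        1.0 if gen.strip() in {ref.strip() for ref in refs} else 0.0
        for gen, refs in zip(generated_texts, reference_texts)
    ]
    # Corpus-level: mean of the sentence-level scores (0.0 for empty input)
    corpus = sum(per_sample) / len(per_sample) if per_sample else 0.0
    return {"custom_metrics/exact_match": (per_sample, corpus)}


print(exact_match_scores(["a", "b "], [["a"], ["c", "d"]]))
# {'custom_metrics/exact_match': ([1.0, 0.0], 0.5)}
```

Inside MyMetric.compute, the same return value could be produced by calling this helper on generated_texts and reference_texts.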

Adding custom on-policy algorithms

In addition to the supported on-policy algorithms (PPO, NLPO, A2C, TRPO), users can implement their own on-policy algorithms by subclassing stable-baselines3's OnPolicyAlgorithm. Since we provide wrappers for on-policy algorithms that handle rollouts using LM policies, the environment, reward computation, etc., users only need to implement the train() method with their custom loss functions.

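from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm


class MyOnPolicyAlgorithm(OnPolicyAlgorithm):
    def __init__(self, n_epochs: int = 5, batch_size: int = 64, **args):
        super().__init__(**args)
        # The base OnPolicyAlgorithm does not define these, so store them here
        self.n_epochs = n_epochs
        self.batch_size = batch_size

    def train(self) -> None:
        # Train for n_epochs epochs
        for epoch in range(self.n_epochs):
            # Do a complete pass on the rollout buffer in minibatches
            for rollout_data in self.rollout_buffer.get(self.batch_size):
                # Placeholder: compute your custom loss from rollout_data,
                # then take an optimizer step
                pass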
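
For reference, the loss that a PPO-style train() computes per minibatch is the clipped surrogate objective. The sketch below evaluates it in plain Python over lists so that it runs standalone. In an actual train(), these quantities would be torch tensors derived from rollout_data and the policy, and the value-function and entropy terms are omitted here. The clip_range default of 0.2 is an assumption, and ppo_clip_loss is a hypothetical name.

```python
import math
from typing import Sequence


def ppo_clip_loss(new_log_probs: Sequence[float],
                  old_log_probs: Sequence[float],
                  advantages: Sequence[float],
                  clip_range: float = 0.2) -> float:
    """Negative mean of the PPO clipped surrogate objective."""
    total = 0.0
    for new_lp, old_lp, adv in zip(new_log_probs, old_log_probs, advantages):
        # Probability ratio pi_new(a|s) / pi_old(a|s)
        ratio = math.exp(new_lp - old_lp)
        clipped_ratio = max(1.0 - clip_range, min(1.0 + clip_range, ratio))
        # Pessimistic bound: take the smaller of the two surrogate terms
        total += min(ratio * adv, clipped_ratio * adv)
    # Negate because optimizers minimize
    return -total / len(advantages)
```

For a positive advantage, clipping caps the gain from raising an action's probability beyond 1 + clip_range. For a negative advantage, the unclipped term is kept, so the penalty for a large ratio is not softened.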

Adding custom policies

We provide LM-based actor-critic policy implementations that wrap causal LMs and seq2seq LMs. These can also be extended (e.g., to use a different critic architecture) by overriding the appropriate methods (e.g., evaluate_actions()).

Registry

Finally, register your custom components by adding them to the corresponding registry; they can then be used directly from configs, just like the pre-defined components 👋
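
The id/args pattern used throughout the YAML configs can be thought of as a lookup from an id string to a class, which is then instantiated with the config's args. The sketch below is a generic stand-in written for illustration (ComponentRegistry and ScaledLengthReward are hypothetical names). To register real components, add them to the corresponding registry in RL4LMs, e.g., RewardFunctionRegistry or DataPoolRegistry.

```python
from typing import Any, Dict, Type


class ComponentRegistry:
    """Minimal id -> class registry (illustrative, not RL4LMs' own)."""

    def __init__(self) -> None:
        self._classes: Dict[str, Type[Any]] = {}

    def add(self, component_id: str, cls: Type[Any]) -> None:
        # Refuse silent overwrites so two components cannot share an id
        if component_id in self._classes:
            raise KeyError(f"{component_id!r} is already registered")
        self._classes[component_id] = cls

    def get(self, component_id: str, args: Dict[str, Any]) -> Any:
        # Instantiate the registered class with the config's `args` mapping
        return self._classes[component_id](**args)


class ScaledLengthReward:
    """Hypothetical component used only to demonstrate the registry."""

    def __init__(self, scale: float = 1.0) -> None:
        self.scale = scale

    def __call__(self, text: str) -> float:
        return self.scale * len(text.split())


registry = ComponentRegistry()
registry.add("scaled_length", ScaledLengthReward)

# Mirrors a YAML block such as:  reward_fn: {id: scaled_length, args: {scale: 0.5}}
reward_cfg = {"id": "scaled_length", "args": {"scale": 0.5}}
reward_fn = registry.get(reward_cfg["id"], reward_cfg["args"])
print(reward_fn("one two three four"))  # 2.0
```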

Crowdsourcing templates

We have provided the crowdsourcing templates we used on Mechanical Turk, along with example inputs, in scripts/crowdworking_templates. You might find them a helpful starting point either for evaluating your own model's generations or for gathering training data for a learned reward function.


Logging and Experiment Results

Additionally, we support WANDB logging and warm-starting of training by storing checkpoints and other training artifacts in a user-specified path. This is especially useful for running preemptible jobs on large, scheduled clusters.

Artifacts include:

  • a JSONL file containing rollout infos at specified intervals
  • a JSONL file containing training infos at specified intervals
  • a JSONL file containing validation metrics at specified intervals
  • a JSONL file containing test metrics before and after training
  • a JSON file with validation predictions at specified intervals
  • a JSON file with test predictions before and after training
  • the trained LM model
  • the config JSON used to run the experiment

Complete usage is as follows:

WANDB_API_KEY=<YOUR-WANDB-API-KEY-HERE>  python scripts/training/train_text_generation.py \
--config_path <PATH-TO-CONFIG-FILE> \
--experiment_name <EXPERIMENT-NAME> \
--base_path_to_store_results <PATH-TO-STORE-RESULTS> \
--log_to_wandb

Citation

@inproceedings{Ramamurthy2022IsRL,
  title={Is Reinforcement Learning (Not) for Natural Language Processing?: Benchmarks, Baselines, and Building Blocks for Natural Language Policy Optimization},
  author={Rajkumar Ramamurthy and Prithviraj Ammanabrolu and Kiant{\'e} Brantley and Jack Hessel and Rafet Sifa and Christian Bauckhage and Hannaneh Hajishirzi and Yejin Choi},
  journal={arXiv preprint arXiv:2210.01241},
  url={https://arxiv.org/abs/2210.01241},
  year={2022}
}

Questions/Discussion/Ideas?

For discussion, questions, and exchanging ideas, join our Slack channel.
