  • Stars: 537
  • Rank: 82,649 (Top 2%)
  • Language: Python
  • License: MIT License
  • Created: over 3 years ago
  • Updated: 6 months ago


Repository Details

Implementation of ChatGPT-style RLHF (Reinforcement Learning with Human Feedback) on any generation model in Hugging Face's Transformers (bloomz-176B/bloom/gpt/bart/T5/MetaICL).

TextRL: Text Generation with Reinforcement Learning


TextRL is a Python library that aims to improve text generation using reinforcement learning, building upon Hugging Face's Transformers, PFRL, and OpenAI Gym. TextRL is designed to be easily customizable and can be applied to various text-generation models.


Table of Contents

  • Introduction
  • Example - gpt2
  • Example - flan-t5
  • Example - bigscience/bloomz-7b1-mt
  • Example - 176B BLOOM
  • Example - Controllable generation via RL to let Elon Musk speak ill of DOGE
  • Installation
  • Usage
  • Dump the trained model to a Hugging Face model
  • Key Parameters for RL Training

Introduction

TextRL utilizes reinforcement learning to fine-tune text generation models. It is built upon the following libraries:

  • Hugging Face's Transformers
  • PFRL
  • OpenAI Gym

Example - gpt2


GPT2 Example

import pfrl
from textrl import TextRLEnv, TextRLActor, train_agent_with_evaluation
from transformers import AutoModelForCausalLM, AutoTokenizer
import logging
import sys

logging.basicConfig(level=logging.INFO, stream=sys.stdout, format='')

checkpoint = "gpt2"

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype="auto", device_map="auto")

model = model.cuda()


class MyRLEnv(TextRLEnv):
    def get_reward(self, input_item, predicted_list, finish):  # predicted_list is the list of predicted tokens
        reward = [0]
        if finish:
            reward = [1]  # calculate reward score based on predicted_list
        return reward


observation_list = [{"input": "explain how attention works in seq2seq models"}]
env = MyRLEnv(model, tokenizer, observation_input=observation_list, max_length=20, compare_sample=2)
actor = TextRLActor(env, model, tokenizer,
                    act_deterministically=False,
                    temperature=1.0,
                    top_k=0,
                    top_p=1.0,
                    repetition_penalty=2)
agent = actor.agent_ppo(update_interval=2, minibatch_size=2, epochs=10)
print(actor.predict(observation_list[0]))

train_agent_with_evaluation(
    agent,
    env,
    steps=100,
    eval_n_steps=None,
    eval_n_episodes=1,
    eval_interval=2,
    outdir='gpt2-test',
)

print(actor.predict(observation_list[0]))

Example - flan-t5


flan-t5 Example

Colab example: google/flan-t5-base

import pfrl
from textrl import TextRLEnv, TextRLActor, train_agent_with_evaluation
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
import logging
import sys

logging.basicConfig(level=logging.INFO, stream=sys.stdout, format='')


tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")  
model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-base")
model.eval()
model.cuda()

sentiment = pipeline('sentiment-analysis', model="cardiffnlp/twitter-roberta-base-sentiment",
                     tokenizer="cardiffnlp/twitter-roberta-base-sentiment", device=0, return_all_scores=True)

class MyRLEnv(TextRLEnv):
    def get_reward(self, input_item, predicted_list, finish):  # predicted_list is the list of predicted tokens
        reward = 0
        if finish or len(predicted_list[0]) >= self.env_max_length:
            predicted_text = tokenizer.convert_tokens_to_string(predicted_list[0])
            # sentiment classifier as the reward model
            reward = sentiment(input_item['input'] + predicted_text)[0][0]['score'] * 10
        return reward

observation_list = [{'input': 'i think dogecoin is'}]
env = MyRLEnv(model, tokenizer, observation_input=observation_list, compare_sample=1)
actor = TextRLActor(env, model, tokenizer, optimizer='adamw',
                    temperature=0.8,
                    top_k=100,
                    top_p=0.85)
agent = actor.agent_ppo(update_interval=50, minibatch_size=3, epochs=10, lr=3e-4)
print(actor.predict(observation_list[0]))

pfrl.experiments.train_agent_with_evaluation(
    agent,
    env,
    steps=3000,
    eval_n_steps=None,
    eval_n_episodes=1,       
    train_max_episode_len=100,  
    eval_interval=10,
    outdir='checkpoint', 
)
agent.load("./checkpoint/best")
print(actor.predict(observation_list[0]))

Example - bigscience/bloomz-7b1-mt


bloomz-7b1-mt Example

import pfrl
from textrl import TextRLEnv, TextRLActor, train_agent_with_evaluation
from transformers import AutoModelForCausalLM, AutoTokenizer
import logging
import sys

logging.basicConfig(level=logging.INFO, stream=sys.stdout, format='')

checkpoint = "bigscience/bloomz-7b1-mt"

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype="auto", device_map="auto")

model = model.cuda()


class MyRLEnv(TextRLEnv):
    def get_reward(self, input_item, predicted_list, finish):  # predicted_list is the list of predicted tokens
        reward = [0]
        if finish:
            reward = [1]  # calculate reward score based on predicted_list
        return reward


observation_list = [{"input": "explain how attention works in seq2seq models"}]
env = MyRLEnv(model, tokenizer, observation_input=observation_list, max_length=20, compare_sample=2)
actor = TextRLActor(env, model, tokenizer,
                    act_deterministically=False,
                    temperature=1.0,
                    top_k=0,
                    top_p=1.0,
                    repetition_penalty=2)
agent = actor.agent_ppo(update_interval=2, minibatch_size=2, epochs=10)
print(actor.predict(observation_list[0]))

train_agent_with_evaluation(
    agent,
    env,
    steps=100,
    eval_n_steps=None,
    eval_n_episodes=1,
    eval_interval=2,
    outdir='bloom-test',
)

print(actor.predict(observation_list[0]))

Example - 176B BLOOM


bloomz-176B Example

We strongly recommend contributing to the public swarm to increase Petals capacity:

https://github.com/bigscience-workshop/petals

Install Petals first: pip install petals -U

import pfrl
from textrl import TextRLEnv, TextRLActor, train_agent_with_evaluation
from transformers import BloomTokenizerFast
from petals import DistributedBloomForCausalLM
import logging
import sys

logging.basicConfig(level=logging.INFO, stream=sys.stdout, format='')

MODEL_NAME = "bigscience/bloom-petals"
tokenizer = BloomTokenizerFast.from_pretrained(MODEL_NAME)
model = DistributedBloomForCausalLM.from_pretrained(MODEL_NAME)
model = model.cuda()


class MyRLEnv(TextRLEnv):
    def get_reward(self, input_item, predicted_list, finish):  # predicted_list is the list of predicted tokens
        reward = [0]
        if finish:
            reward = [1]  # calculate reward score based on predicted_list
        return reward


observation_list = [{"input": "explain how attention works in seq2seq models"}]
env = MyRLEnv(model, tokenizer, observation_input=observation_list, max_length=20, compare_sample=2)
actor = TextRLActor(env, model, tokenizer,
                    act_deterministically=False,
                    temperature=1.0,
                    top_k=0,
                    top_p=1.0,
                    repetition_penalty=2)
agent = actor.agent_ppo(update_interval=2, minibatch_size=2, epochs=10)

print(actor.predict(observation_list[0]))

train_agent_with_evaluation(
    agent,
    env,
    steps=100,
    eval_n_steps=None,
    eval_n_episodes=1,
    eval_interval=2,
    outdir='bloom-test',
)

print(actor.predict(observation_list[0]))

Example - Controllable generation via RL to let Elon Musk speak ill of DOGE


[Controllable generation via RL to let Elon Musk speak ill of DOGE](https://github.com/voidful/TextRL/blob/main/example/2022-12-10-textrl-elon-musk.ipynb)

Colab example: bigscience/bloom-560m

Colab example: huggingtweets/elonmusk

before: i think dogecoin is a great idea.
after: i think dogecoin is a great idea, but I think it is a little overused.
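
The linked notebook defines its own reward; as a rough, hypothetical sketch of the idea, a get_reward that scores generations with the same cardiffnlp sentiment classifier used in the flan-t5 example (whose first label, index 0, corresponds to "negative") could look like this:

# Illustrative reward only -- the notebook above defines its own version.
from textrl import TextRLEnv
from transformers import AutoTokenizer, pipeline

# Tokenizer of the generation model being fine-tuned (bigscience/bloom-560m in the Colab example).
tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")

sentiment = pipeline('sentiment-analysis', model="cardiffnlp/twitter-roberta-base-sentiment",
                     tokenizer="cardiffnlp/twitter-roberta-base-sentiment", device=0, return_all_scores=True)


class NegativeSentimentEnv(TextRLEnv):
    def get_reward(self, input_item, predicted_list, finish):
        reward = 0
        if finish or len(predicted_list[0]) >= self.env_max_length:
            predicted_text = tokenizer.convert_tokens_to_string(predicted_list[0])
            # Higher reward when the classifier scores prompt + generation as negative (label index 0).
            reward = sentiment(input_item['input'] + predicted_text)[0][0]['score'] * 10
        return reward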

Installation

pip install

pip install pfrl@git+https://github.com/voidful/pfrl.git
pip install textrl

Build from source

Clone this project and cd into it, then run:

pip install -e .

Usage

Initialize agent and environment

import torch
from textrl import TextRLEnv, TextRLActor, train_agent_with_evaluation
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "bigscience/bloomz-7b1-mt"

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype="auto", device_map="auto")

model = model.cuda()

Set up reward function for environment

  • predicted_list (list[str]): the list of predicted tokens
  • finish (bool): whether the end of the sentence has been reached

class MyRLEnv(TextRLEnv):
    def get_reward(self, input_item, predicted_list, finish):
        reward = [0]
        if finish:
            reward = [1]  # calculate reward score based on predicted_list
        return reward

Prepare for training

  • observation_list should be a list of all possible input strings for model training

    Example: observation_list = [{"input":'testing sent 1'},{"input":'testing sent 2'}]

env = MyRLEnv(model, tokenizer, observation_input=observation_list)
actor = TextRLActor(env, model, tokenizer)
agent = actor.agent_ppo(update_interval=10, minibatch_size=2000, epochs=20)

Train

n_episodes = 1000
max_episode_len = 200  # max sentence length

for i in range(1, n_episodes + 1):
    obs = env.reset()
    R = 0
    t = 0
    while True:
        action = agent.act(obs)
        obs, reward, done, pred = env.step(action)
        R += reward
        t += 1
        reset = t == max_episode_len
        agent.observe(obs, reward, done, reset)
        if done or reset:
            break
    if i % 10 == 0:
        print('episode:', i, 'R:', R)
    if i % 50 == 0:
        print('statistics:', agent.get_statistics())
print('Finished.')

Another way to train:

import logging
import sys

logging.basicConfig(level=logging.INFO, stream=sys.stdout, format='')

train_agent_with_evaluation(
    agent,
    env,
    steps=1000,
    eval_n_steps=None,
    eval_n_episodes=1500,
    train_max_episode_len=50,
    eval_interval=10000,
    outdir='somewhere',
)

Prediction

agent.load("somewhere/best")  # loading the best model
actor.predict({"input": "input text"})  # observation in the same format as observation_list entries


Dump the trained model to a Hugging Face model

textrl-dump --model ./model_path_before_rl --rl ./rl_path --dump ./output_dir
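
Once dumped, the output directory should load like a regular Transformers checkpoint. A minimal sketch, assuming ./output_dir contains a standard causal-LM checkpoint (if the tokenizer is not included in the dump, load it from the original base model instead):

from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the RL-tuned checkpoint produced by textrl-dump (the --dump directory above).
tokenizer = AutoTokenizer.from_pretrained("./output_dir")
model = AutoModelForCausalLM.from_pretrained("./output_dir")

inputs = tokenizer("explain how attention works in seq2seq models", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))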

Key Parameters for RL Training

To fine-tune a language model using RL, you need to modify the reward function:

from textrl import TextRLEnv

class MyRLEnv(TextRLEnv):
    def get_reward(self, input_item, predicted_list, finish):
        # input_item is the prompt input for the model; it will be one of your observations.
        # An observation is a list of sentences, e.g. ['inputted sentence', 'xxx', 'yyy'];
        # only the first entry ('inputted sentence') is fed to the model, and
        # the remaining entries can serve as references for reward calculation.

        # predicted_list is the list of sentences generated by the RL model;
        # it is used for ranking-based reward calculation.

        # finish is the end-of-sentence flag. get_reward is called while generating each token;
        # when finish is True, the sentence is complete, which is used for sentence-level reward calculation.

        # reward should be a list whose length equals that of predicted_list.
        return reward
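
For instance, a toy environment that gives every candidate in predicted_list its own score (here simply its token count, purely for illustration) would look like this:

from textrl import TextRLEnv


class LengthRewardEnv(TextRLEnv):
    """Toy example only: reward each candidate by how many tokens it generated."""

    def get_reward(self, input_item, predicted_list, finish):
        reward = [0] * len(predicted_list)
        if finish:
            # One score per candidate, so the reward list length matches predicted_list.
            reward = [len(predicted_tokens) for predicted_tokens in predicted_list]
        return reward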

Parameters for sampling diverse examples:

actor = TextRLActor(env, model, tokenizer,
                    act_deterministically=False,  # whether to always pick the max-probability token at each step
                    temperature=1,                # temperature for sampling
                    compare_sample=2,             # number of samples to rank
                    top_k=0,                      # top-k sampling
                    top_p=1.0,                    # top-p sampling
                    repetition_penalty=2)         # repetition penalty from the CTRL paper (https://arxiv.org/abs/1909.05858)

When training a reinforcement learning (RL) model, several key parameters need to be tuned to ensure optimal performance. Here is a list of important parameters and their descriptions:

  1. Update Interval: This determines how often the RL agent updates its policy based on collected experiences. A smaller update interval means the agent learns more frequently from recent experiences, while a larger interval allows more experiences to accumulate before learning. In the example above, the update interval is set to 10.
     update_interval=10
  2. Minibatch Size: The number of experiences sampled from the experience replay buffer to compute the gradient update. A larger minibatch size helps to stabilize learning and reduce variance, but at the cost of increased computational requirements.
     minibatch_size=2000
  3. Epochs: The number of times the agent iterates through the entire minibatch to update its policy. More epochs can lead to better learning but may increase the risk of overfitting.
     epochs=20
  4. Discount Factor (Gamma): This parameter determines how much future rewards are discounted when calculating the expected return. A value closer to 1 makes the agent more farsighted, while a value closer to 0 makes the agent more focused on immediate rewards.
     gamma=0.99
  5. Learning Rate: The step size used for updating the policy. A larger learning rate allows for faster convergence but may lead to instability in learning, while a smaller learning rate ensures stable learning at the cost of slower convergence.
     lr=1e-4
  6. Epsilon: A parameter used in the PPO algorithm to clip the policy ratio. This helps to control the magnitude of policy updates, preventing excessively large updates that can destabilize learning.
     epsilon=0.2
  7. Entropy Coefficient: This parameter encourages exploration by adding a bonus reward for taking less certain actions. A higher entropy coefficient promotes more exploration, while a lower coefficient focuses the agent on exploiting known strategies.
     entropy_coef=0.01
  8. Training Steps: The total number of steps the agent takes during training. More steps typically lead to better learning but may require more computational time.
     steps=1000
  9. Evaluation Interval: The number of training steps between evaluations. Increasing the evaluation interval reduces the computational time spent on evaluation, but it may also reduce the frequency at which you can monitor the agent's progress.
     eval_interval=10000
  10. Max Episode Length: The maximum number of steps allowed in a single episode during training. This can prevent the agent from getting stuck in long, unproductive episodes.
      train_max_episode_len=50

These parameters need to be carefully tuned based on the specific problem and environment to achieve the best performance. It is generally recommended to start with default values and then adjust them based on the observed learning behavior.
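
As a rough sketch of where these knobs go, continuing from the Usage section above (actor, env, and train_agent_with_evaluation are already defined there) and using only the arguments shown in this README (gamma, epsilon, and entropy_coef are PPO-level settings listed for tuning context and are not passed in these particular calls):

# Hypothetical consolidation of the values discussed above.
agent = actor.agent_ppo(update_interval=10,   # how often the policy is updated
                        minibatch_size=2000,  # experiences per gradient update
                        epochs=20,            # passes over each minibatch
                        lr=1e-4)              # learning rate

train_agent_with_evaluation(
    agent,
    env,
    steps=1000,                 # total training steps
    eval_n_steps=None,
    eval_n_episodes=1,
    train_max_episode_len=50,   # max steps per episode
    eval_interval=10000,        # training steps between evaluations
    outdir='somewhere',
)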

More Repositories

1

awesome-chatgpt-dataset

Unlock the Power of LLM: Explore These Datasets to Train Your Own ChatGPT!
Python
688
star
2

Codec-SUPERB

Audio Codec Speech processing Universal PERformance Benchmark
Python
201
star
3

tw_stocker

keep tracking and storing Taiwan stock information
Python
100
star
4

TFkit

🤖📇 handling multiple nlp task in one pipeline
Python
56
star
5

SpeechMix

Explore different ways to mix speech models (wav2vec2, HuBERT) and NLP models (BART, T5, GPT) together
Python
41
star
6

vall-e-encodec

Python
41
star
7

BertGenerate

Fine-tuning BERT for text generation
Jupyter Notebook
38
star
8

asr-trainer

one script for xls-r/xlsr/whisper fine-tuning
Python
37
star
9

aidev

Revolutionize your development workflow with AI-powered code assistance, automating mock tests, suggestions, and unit test generation in a single Python CLI tool.
Python
35
star
10

NLPrep

๐Ÿณ NLPrep - dataset tool for many natural language processing task
Python
28
star
11

BDG

Code for "A BERT-based Distractor Generation Scheme with Multi-tasking and Negative Answer Training Strategies."
Python
27
star
12

Phraseg

Phraseg - a toolkit for discovering new words and phrases
Jupyter Notebook
26
star
13

wav2vec2-xlsr-multilingual-56

56 language, 1 model Multilingual ASR
Python
23
star
14

FTA

Technical Analysis on Cryptocurrency
Python
23
star
15

ChineseErrorDataset

CGED & CSC
22
star
16

asrp

ASR text preprocessing utility
Python
20
star
17

nlp2go

๐Ÿƒ hosting nlp models in one line
CSS
20
star
18

ipa2

Tools for converting text to IPA in Python
Python
16
star
19

nlp2

⚙️ Tool for NLP - handle file and text
Python
15
star
20

awesome-question-answering-dataset

A list of awesome machine question answering datasets
15
star
21

pretrain_bart

training BART from scratch
Python
12
star
22

SnapShare

Linking Your Phone To Computer Browser With Socket.io.
JavaScript
10
star
23

causal-lm-trainer

Python
8
star
24

wav2vec-u-exp

Build and Run Wav2vec Unsupervised Experiment
Dockerfile
8
star
25

whisper-live-asr-demo

run whisper on CPU/GPU server
JavaScript
8
star
26

gpu-info-api

๐Ÿฑโ€๐Ÿ’ป GPU Info API is an API that provides detailed information about Nvidia, AMD, and Intel GPUs. The information is extracted from Wikipedia and stored in JSON format.
Python
8
star
27

t5lephone

phoneme byt5
Python
7
star
28

MMLM

Toward Multi Modality Language Model - implementation of GPT-4o/Project Astra
Python
7
star
29

llm-estimator

Effortlessly predict training time, loss, and cost for LLM model training
JavaScript
6
star
30

WikiExtractor

Extract Knowledge from wiki dump file
Python
6
star
31

react-media-viewer

Ready to go Media Player Component for React.
JavaScript
6
star
32

dtokenizer

discretize everything into tokens
Python
6
star
33

hubert-cluster-code

Extract clustering feature from hubert
5
star
34

pytorch-tta

Pytorch implementation of "Fast and Accurate Deep Bidirectional Language Representations for Unsupervised Learning".
Python
5
star
35

GSQA

Generative Spoken Question Answering
Python
4
star
36

taiwan-company-network

ๅฐ็ฃๅ…ฌๅธๆŠ•่ณ‡้—œไฟ‚ๅœ–
CSS
4
star
37

DevLEGO

Create your development Env like LEGO blocks, run your projects on any device - be it a PC, Web, Phone or Tablet!
Shell
4
star
38

awesome-evaluation-lm

Collection Of Automated Language Model Assessment
3
star
39

fastpages

Jupyter Notebook
3
star
40

modelhub

3
star
41

Gossiping-Chinese-Positive-Corpus

PTT Gossiping board Q&A - positive - Chinese corpus
3
star
42

survey-builder

survey builder for human evaluation
JavaScript
3
star
43

voidful

Python
3
star
44

audio-preprocessing-pipeline

Python
3
star
45

DG-Showcase

Showcase for "A BERT-based Distractor Generation Scheme with Multi-tasking and Negative Answer Training Strategies."
CSS
3
star
46

hubert-pretrain

using huggingface trainer to pre-train hubert
Python
2
star
47

dpr-multilingual

A multilingual version of DPR
2
star
48

telenotify

Python
2
star
49

tts-corpus-creator

collection of different source of TTS api for generating corpus.
Python
2
star
50

diff-aspect-set-dg

Python
2
star
51

depack

Extract files from any type of archive in command line
Python
2
star
52

Data2QA

Unified QA with different modality input
Python
2
star
53

bindtorchaudio

`bindtorchaudio` is a Python package that allows for easy installation of the `torchaudio` library, which provides audio processing functionalities for the PyTorch machine learning framework.
Python
2
star
54

seq2seq-lm-trainer

This is a simple example of using the T5 model for sequence-to-sequence tasks, leveraging Hugging Face's `Trainer` for efficient model training.
Python
2
star
55

PPA

Prompt Pool Agent
Python
2
star
56

bforce

bruteforce is all you need in an unstable system
Python
1
star
57

twcc-usage-slack-bot

TWCC GPU Usage Notification Slack Bot
Python
1
star
58

shows

lib for system monitoring with CPU/GPU/DISK/MEM/NET
Python
1
star
59

get-stat

lib for system monitoring in Python / Web API (CPU/GPU/DISK/MEM/NET/SERVICE)
1
star
60

NLPrep-Datasets

HTML
1
star
61

pearl

PEARL - Optimize Prompt Selection for Enhanced Answer Performance Using Reinforcement Learning
Python
1
star
62

uni-superb

Python
1
star
63

huggingface_notebook

Jupyter Notebook
1
star
64

superb-website

JavaScript
1
star
65

leverage-lm

small lm + RAG > LLM
1
star
66

fbcrawler

Python
1
star
67

SoundON

Python
1
star