• Stars
    1,137
  • Rank 40,971 (Top 0.9 %)
  • Language
    Python
  • License
    Apache License 2.0
  • Created 9 months ago
  • Updated about 1 month ago


Repository Details

ReFT: Representation Finetuning for Language Models

pyreft by stanfordnlp

State-of-the-art Representation Fine-Tuning (ReFT) methods

A powerful, parameter-efficient, and interpretable fine-tuning method

Want to try a fine-tuning method that uses a fraction of the parameter count of SoTA PEFTs while achieving potentially better performance? Introducing pyreft, a representation fine-tuning (ReFT) library that supports adapting internal language model representations via trainable interventions. With fewer fine-tuning parameters and more robust performance, pyreft can boost fine-tuning efficiency and decrease fine-tuning cost, while opening the door to studying the interpretability of the adapted parameters.

pyreft supports

  • Fine-tuning any pretrained LM on HuggingFace with ReFT
  • Setting ReFT hyperparameters via configs
  • Sharing fine-tuned results easily on HuggingFace (see the sketch below)
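
For example, you can persist a trained ReFT and reload it later. A minimal sketch: reft_model here is the one built in the Quickstart below, the save arguments are assumptions about pyreft's API, and ReftModel.load appears verbatim in the chat-model example later on this page:

# save the trained interventions locally ("./my_reft" is a hypothetical path)
reft_model.save(save_directory="./my_reft")

# reload them onto a freshly loaded base model
from pyreft import ReftModel
reft_model = ReftModel.load("./my_reft", model)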

Tip

Powerful and Parameter-Efficient: Read our ReFT paper for an introduction to representation fine-tuning (ReFT) and its performance.

Tip

Interpretable Finetuning: Read Composable ReFT for a sneak peek at the interpretable nature of ReFT.

Quickstart

Here is a verified conda env setup:

conda create --name awesome-reft python=3.10
conda activate awesome-reft

Then, install pyreft from pip+git:

pip install git+https://github.com/stanfordnlp/pyreft.git

Or install pyreft from pip (coming soon):

pip install pyreft

Prepare a model for training with a ReFT method by wrapping the base model and a ReFT configuration with get_reft_model. In the following example, we use ConsreftIntervention (a constant LoReFT intervention), which is simpler than the original LoReFT described in the paper:

import torch
import transformers

from pyreft import (
    get_reft_model,
    ReftConfig,
    ConsreftIntervention
)

# loading huggingface model
model_name_or_path = "yahma/llama-7b-hf"
model = transformers.AutoModelForCausalLM.from_pretrained(
    model_name_or_path, torch_dtype=torch.bfloat16, device_map="cuda")

# wrap the model with rank-1 constant reft
reft_config = ReftConfig(representations={"layer": 15, "component": "block_output",
    "intervention": ConsreftIntervention(
    embed_dim=model.config.hidden_size, low_rank_dimension=1)})
reft_model = get_reft_model(model, reft_config)
reft_model.print_trainable_parameters()

"trainable intervention params: 4,097 || trainable model params: 0"
"model params: 6,738,415,616 || trainable%: 6.080064266549391e-05"

With this config, you are tuning 0.00006% of the parameters, 4,097 to be exact (as the sketch above verifies). Then, the reft_model can be used for any downstream task. We can train a rank-1 ReFT to make the model produce some constant output:

from pyreft import (
    ReftTrainerForCausalLM,
    make_last_position_supervised_data_module
)
tokenizer = transformers.AutoTokenizer.from_pretrained(
    model_name_or_path, model_max_length=2048, padding_side="right", use_fast=False)
tokenizer.pad_token = tokenizer.unk_token

# get training data to train our intervention to remember the following sequence
memo_sequence = """
Welcome to the Natural Language Processing Group at Stanford University!
We are a passionate, inclusive group of students and faculty, postdocs
and research engineers, who work together on algorithms that allow computers
to process, generate, and understand human languages. Our interests are very
broad, including basic scientific research on computational linguistics,
machine learning, practical applications of human language technology,
and interdisciplinary work in computational social science and cognitive
science. We also develop a wide variety of educational materials
on NLP and many tools for the community to use, including the Stanza
toolkit which processes text in over 60 human languages.
"""
data_module = make_last_position_supervised_data_module(
    tokenizer=tokenizer,
    model=model,
    inputs=["GO->"],
    outputs=[memo_sequence])

# train
training_args = transformers.TrainingArguments(
    num_train_epochs=1000.0,
    output_dir="./tmp",
    learning_rate=2e-3,
    logging_steps=50)
trainer = ReftTrainerForCausalLM(
    model=reft_model, tokenizer=tokenizer,
    args=training_args, **data_module)
_ = trainer.train()

Once you are done with your training, you can check your model generations:

prompt = tokenizer("GO->", return_tensors="pt").to("cuda")
base_unit_location = prompt["input_ids"].shape[-1] - 1  # last position
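# unit_locations below nests as [intervention][batch][positions]:
# one intervention, one prompt, and a single (last) position to edit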
_, reft_response = reft_model.generate(
    prompt, unit_locations={"sources->base": (None, [[[base_unit_location]]])},
    intervene_on_prompt=True, max_new_tokens=512, do_sample=False, 
    eos_token_id=tokenizer.eos_token_id, early_stopping=True
)
print(tokenizer.decode(reft_response[0], skip_special_tokens=True))

"""GO->
Welcome to the Natural Language Processing Group at Stanford University!
We are a passionate, inclusive group of students and faculty, postdocs
and research engineers, who work together on algorithms that allow computers
to process, generate, and understand human languages. Our interests are very
broad, including basic scientific research on computational linguistics,
machine learning, practical applications of human language technology,
and interdisciplinary work in computational social science and cognitive
science. We also develop a wide variety of educational materials
on NLP and many tools for the community to use, including the Stanza
toolkit which processes text in over 60 human languages."""

We successfully compress the text into 4,097 parameters! We perform more rigorous memorisation tests like this one in ReFT Interp.

You can do ReFT with any language modeling task or SFT. Check out our examples folder! Following the steps in train.py, you can train a 7B chat model that gets close to ChatGPT-3.5-1103 (81.9 vs. 86.3 Alpaca-eval scores) in under 18 minutes with a single A100 GPU + ReFT, training Llama-2 on the Ultrafeedback dataset.

Loading our 18-min-cooked Loreft1k-Llama-2-7b-hf from HuggingFace

For a full tutorial, please take a look at chat_model.ipynb.

Loading the base LM first:

import torch, transformers
from pyreft import (
    ReftModel,
    get_intervention_locations
)

prompt_no_input_template = """Below is an instruction that \
describes a task. Write a response that appropriately \
completes the request.

### Instruction:
%s

### Response:
"""

device = "cuda" if torch.cuda.is_available() else "cpu"

model_name_or_path = "meta-llama/Llama-2-7b-hf"
reft_model_name_or_path = "zhengxuanzenwu/Loreft1k-Llama-2-7b-hf"
tokenizer = transformers.AutoTokenizer.from_pretrained(
    model_name_or_path, model_max_length=2048, padding_side="right", use_fast=False)
tokenizer.pad_token = tokenizer.unk_token

model = transformers.AutoModelForCausalLM.from_pretrained(
    model_name_or_path, torch_dtype=torch.bfloat16, device_map=device)

Then, loading ReFT artifacts:

reft_model = ReftModel.load(
    "zhengxuanzenwu/Loreft1k-Llama-2-7b-hf", model, from_huggingface_hub=True)
reft_model.set_device(device)

Start chatting with it:

instruction = "Tell me about the NLP Group at Stanford University."

# tokenize and prepare the input
prompt = prompt_no_input_template % instruction
prompt = tokenizer(prompt, return_tensors="pt").to(device)
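# positions="f5+l5": intervene on the first 5 and the last 5 prompt positions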
intervention_locations = torch.tensor([get_intervention_locations(
    last_position=prompt["input_ids"].shape[-1], positions="f5+l5",
    num_interventions=len(reft_model.interventions))]).permute(1, 0, 2).tolist()

# generate
_, reft_response = reft_model.generate(
    prompt, 
    unit_locations={"sources->base": (None, intervention_locations)},
    intervene_on_prompt=True, max_new_tokens=512, do_sample=False, 
    no_repeat_ngram_size=5, repetition_penalty=1.1,
    eos_token_id=tokenizer.eos_token_id, early_stopping=True
)
print(tokenizer.decode(reft_response[0], skip_special_tokens=True))

Note that Llama-2 models can follow instructions zero-shot. We encourage people to try this on other, more primitive base LMs and see if ReFT can work well there!

Usage and License Notices: Our chat model is intended and licensed for research use only. The model is licensed under CC BY-NC 4.0 (allowing only non-commercial use) and should not be used outside of research purposes.

Why should you use ReFT instead of PEFTs?

There are various benefits, such as saving memory and storage. Beyond that, ReFT is more interpretable and extensible than PEFT: the interventions we learn are simply a causal abstraction of the training task, and no model weights are modified. The intervention site search space is also large: interventions can be placed at any set of token positions, which makes ReFT more flexible.
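
As a minimal sanity check (reusing model from the Quickstart, and assuming, consistent with the "trainable model params: 0" printout above, that get_reft_model freezes the base weights):

# every base-model weight stays frozen; only the interventions carry gradients
n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"trainable base-model params: {n_trainable}")  # expected: 0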

We showcase ReFT performance on various benchmarks against popular PEFTs such as LoRA and its newer variants (e.g., DoRA) in our paper.

Learn more through examples

Example          Description
pyvene           The backbone of the pyreft library
LoReFT           Reproduce our ReFT paper's main results
Alpaca           Instruction-tune LMs with ReFT
ReFT Interp      Some hints on why ReFT works
Composable ReFT  Some hints on why ReFT is an interpretable method

Citation

Make sure you cite the ReFT paper:

@article{wuandarora2024reft,
  title={{ReFT}: Representation Finetuning for Language Models},
  author={Wu, Zhengxuan and Arora, Aryaman and Wang, Zheng and Geiger, Atticus and Jurafsky, Dan and Manning, Christopher D. and Potts, Christopher},
  journal={arXiv preprint arXiv:2404.03592},
  url={https://arxiv.org/abs/2404.03592},
  year={2024}
}

And please cite the pyvene library paper as well:

@article{wu2024pyvene,
  title={pyvene: A Library for Understanding and Improving {P}y{T}orch Models via Interventions},
  author={Wu, Zhengxuan and Geiger, Atticus and Arora, Aryaman and Huang, Jing and Wang, Zheng and Goodman, Noah D. and Manning, Christopher D. and Potts, Christopher},
  journal={arXiv preprint arXiv:2403.07809},
  url={https://arxiv.org/abs/2403.07809},
  year={2024}
}

Outreach

If you are interested in integrating this library into your workflow or in reimplementing it for improved efficiency, please feel free to contact us! We may have additional insights to share.

Star History

Star History Chart

More Repositories

1. dspy (Python, 18,220 stars)
   DSPy: The framework for programming—not prompting—foundation models
2. CoreNLP (Java, 9,678 stars)
   CoreNLP: A Java suite of core NLP tools for tokenization, sentence segmentation, NER, parsing, coreference, sentiment analysis, etc.
3. stanza (Python, 7,278 stars)
   Stanford NLP Python library for tokenization, sentence segmentation, NER, and parsing of many human languages
4. GloVe (C, 6,867 stars)
   Software in C and data files for the popular GloVe model for distributed word representations, a.k.a. word vectors or embeddings
5. cs224n-winter17-notes (TeX, 1,587 stars)
   Course notes for CS224N Winter17
6. treelstm (Lua, 875 stars)
   Tree-structured Long Short-Term Memory networks (http://arxiv.org/abs/1503.00075)
7. pyvene (Python, 625 stars)
   Stanford NLP Python Library for Understanding and Improving PyTorch Models via Interventions
8. string2string (Jupyter Notebook, 533 stars)
   String-to-String Algorithms for Natural Language Processing
9. python-stanford-corenlp (Python, 516 stars)
   Python interface to CoreNLP using a bidirectional server-client interface.
10. mac-network (Python, 494 stars)
    Implementation for the paper "Compositional Attention Networks for Machine Reasoning" (Hudson and Manning, ICLR 2018)
11. phrasal (Java, 208 stars)
    A large-scale statistical machine translation system written in Java.
12. spinn (Python, 205 stars)
    SPINN (Stack-augmented Parser-Interpreter Neural Network): fast, batchable, context-aware TreeRNNs
13. coqa-baselines (Python, 176 stars)
    The baselines used in the CoQA paper
14. cocoa (Python, 158 stars)
    Framework for learning dialogue agents in a two-player game setting.
15. stanza-old (Python, 138 stars)
    Stanford NLP group's shared Python tools.
16. chirpycardinal (Python, 131 stars)
    Stanford's Alexa Prize socialbot
17. stanfordnlp (Python, 114 stars)
    [Deprecated] This library has been renamed to "Stanza". Latest development at: https://github.com/stanfordnlp/stanza
18. wge (Python, 109 stars)
    Workflow-Guided Exploration: sample-efficient RL agent for web tasks
19. pdf-struct (Python, 81 stars)
    Logical structure analysis for visually structured documents
20. edu-convokit (Jupyter Notebook, 75 stars)
    Edu-ConvoKit: An Open-Source Framework for Education Conversation Data
21. cs224n-web (HTML, 60 stars)
    http://cs224n.stanford.edu
22. ColBERT-QA (40 stars)
    Code for Relevance-guided Supervision for OpenQA with ColBERT (TACL'21)
23. stanza-train (Python, 37 stars)
    Model training tutorials for the Stanza Python NLP Library
24. phrasenode (Python, 37 stars)
    Mapping natural language commands to web elements
25. contract-nli-bert (Python, 29 stars)
    A baseline system for ContractNLI (https://stanfordnlp.github.io/contract-nli/)
26. color-describer (OpenEdge ABL, 26 stars)
    Code for Learning to Generate Compositional Color Descriptions
27. stanza-resources (23 stars)
28. python-corenlp-protobuf (Python, 20 stars)
    Python bindings for Stanford CoreNLP's protobufs.
29. miniwob-plusplus-demos (17 stars)
    Demos for the MiniWoB++ benchmark
30. multi-distribution-retrieval (Python, 14 stars)
    Code for our paper Resources and Evaluations for Multi-Distribution Dense Information Retrieval
31. huggingface-models (Python, 11 stars)
    Scripts for pushing models to huggingface repos
32. nlp-meetup-demo (Java, 8 stars)
33. sentiment-treebank (Python, 8 stars)
    Updated version of SST
34. en-worldwide-newswire (Python, 7 stars)
    An English NER dataset built from foreign newswire
35. plot-data (Jupyter Notebook, 6 stars)
    datasets for plotting
36. contract-nli (HTML, 4 stars)
    ContractNLI: A Dataset for Document-level Natural Language Inference for Contracts
37. plot-interface (JavaScript, 3 stars)
    Web interface for the plotting project
38. handparsed-treebank (Perl, 2 stars)
    Extra hand parsed data for training models
39. coqa (Shell, 2 stars)
    CoQA -- A Conversational Question Answering Challenge
40. pdf-struct-models (HTML, 2 stars)
    A repository for hosting models for https://github.com/stanfordnlp/pdf-struct
41. chirpy-parlai-blenderbot-fork (Python, 2 stars)
    A fork of ParlAI supporting Chirpy Cardinal's custom neural generator
42. wob-data (Python, 2 stars)
    Data for QAWoB and FlightWoB web interaction benchmarks from the World of Bits paper (Shi et al., 2017).
43. pdf-struct-dataset (HTML, 1 star)
    Dataset for pdf-struct (https://github.com/stanfordnlp/pdf-struct)
44. nn-depparser (Python, 1 star)
    A re-implementation of nndep using PyTorch.