  • Stars: 441
  • Rank: 98,861 (Top 2%)
  • Language: Python
  • License: MIT License
  • Created: about 1 year ago
  • Updated: 3 months ago


Repository Details

Train and Infer Powerful Sentence Embeddings with AnglE | 🔥 SOTA on STS and MTEB Leaderboard

EN | 简体中文

AnglE📐: Angle-optimized Text Embeddings

It is Angle 📐, not Angel 👼.

🔥 A New SOTA for Semantic Textual Similarity!

🔥 Our universal sentence embedding WhereIsAI/UAE-Large-V1 achieves SOTA on the MTEB Leaderboard with an average score of 64.64!

Paper: https://arxiv.org/abs/2309.12871 · PRs welcome: http://makeapullrequest.com

📊 Results on MTEB Leaderboard

📊 Results on STS benchmark

🤗 Pretrained Models

🤗 HF | LoRA Weight | Dependent Backbone | LLM | Language | Prompt | Pooling Strategy | Examples
WhereIsAI/UAE-Large-V1 | N | N | N | EN | Prompts.C for retrieval purposes, None for others | cls | Search Demo
SeanLee97/angle-llama-13b-nli | Y | NousResearch/Llama-2-13b-hf | Y | EN | Prompts.A | last token | /
SeanLee97/angle-llama-7b-nli-v2 | Y | NousResearch/Llama-2-7b-hf | Y | EN | Prompts.A | last token | /
SeanLee97/angle-llama-7b-nli-20231027 | Y | NousResearch/Llama-2-7b-hf | Y | EN | Prompts.A | last token | /
SeanLee97/angle-bert-base-uncased-nli-en-v1 | N | N | N | EN | N | cls_avg | /
SeanLee97/angle-roberta-wwm-base-zhnli-v1 | N | N | N | ZH-CN | N | cls | /
SeanLee97/angle-llama-7b-zhnli-v1 | Y | NousResearch/Llama-2-7b-hf | Y | ZH-CN | Prompts.B | last token | /

💡 If the selected model is a LoRA weight, you must also specify its dependent backbone when loading it.
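
For example, to load a LoRA entry from the table above, pass its dependent backbone to AnglE.from_pretrained and point pretrained_lora_path at the LoRA weight; this is the same pattern used in the Angle-LLaMA section below, shown here only as a brief illustration:

from angle_emb import AnglE, Prompts

# load a LoRA weight on top of its dependent backbone (pair taken from the table above)
angle = AnglE.from_pretrained('NousResearch/Llama-2-7b-hf',
                              pretrained_lora_path='SeanLee97/angle-llama-7b-nli-v2')
angle.set_prompt(prompt=Prompts.A)
vec = angle.encode({'text': 'hello world'}, to_numpy=True)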

📝 Training Details:

1) SeanLee97/angle-llama-7b-nli-20231027

We fine-tuned AnglE-LLaMA on 4× RTX 3090 Ti (24GB) GPUs. The training script is as follows:

CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 --master_port=1234 train_angle.py \
--task NLI-STS --save_dir ckpts/NLI-STS-angle-llama-7b \
--w2 35 --learning_rate 2e-4 --maxlen 45 \
--lora_r 32 --lora_alpha 32 --lora_dropout 0.1 \
--save_steps 200 --batch_size 160 --seed 42 --do_eval 0 --load_kbit 4 --gradient_accumulation_steps 4 --epochs 1 

The evaluation script is as follows:

CUDA_VISIBLE_DEVICES=0,1 python eval.py \
    --load_kbit 16 \
    --model_name_or_path NousResearch/Llama-2-7b-hf \
    --lora_weight SeanLee97/angle-llama-7b-nli-20231027

Results

English STS Results

Model | STS12 | STS13 | STS14 | STS15 | STS16 | STSBenchmark | SICKRelatedness | Avg.
SeanLee97/angle-llama-7b-nli-20231027 | 78.68 | 90.58 | 85.49 | 89.56 | 86.91 | 88.92 | 81.18 | 85.90
SeanLee97/angle-llama-7b-nli-v2 | 79.00 | 90.56 | 85.79 | 89.43 | 87.00 | 88.97 | 80.94 | 85.96
SeanLee97/angle-llama-13b-nli | 79.33 | 90.65 | 86.89 | 90.45 | 87.32 | 89.69 | 81.32 | 86.52
SeanLee97/angle-bert-base-uncased-nli-en-v1 | 75.09 | 85.56 | 80.66 | 86.44 | 82.47 | 85.16 | 81.23 | 82.37

Chinese STS Results

Model | ATEC | BQ | LCQMC | PAWSX | STS-B | SOHU-dd | SOHU-dc | Avg.
^shibing624/text2vec-bge-large-chinese | 38.41 | 61.34 | 71.72 | 35.15 | 76.44 | 71.81 | 63.15 | 59.72
^shibing624/text2vec-base-chinese-paraphrase | 44.89 | 63.58 | 74.24 | 40.90 | 78.93 | 76.70 | 63.30 | 63.08
SeanLee97/angle-roberta-wwm-base-zhnli-v1 | 49.49 | 72.47 | 78.33 | 59.13 | 77.14 | 72.36 | 60.53 | 67.06
SeanLee97/angle-llama-7b-zhnli-v1 | 50.44 | 71.95 | 78.90 | 56.57 | 81.11 | 68.11 | 52.02 | 65.59

^ denotes baselines; their results are taken from https://github.com/shibing624/text2vec.

Usage

AnglE supports two APIs: the transformers API and the AnglE API. To use the AnglE API, install AnglE first:

python -m pip install -U angle-emb

UAE

  1. For Retrieval Purposes

For retrieval purposes, please use the prompt Prompts.C.

from angle_emb import AnglE, Prompts

angle = AnglE.from_pretrained('WhereIsAI/UAE-Large-V1', pooling_strategy='cls').cuda()
angle.set_prompt(prompt=Prompts.C)
vec = angle.encode({'text': 'hello world'}, to_numpy=True)
print(vec)
vecs = angle.encode([{'text': 'hello world1'}, {'text': 'hello world2'}], to_numpy=True)
print(vecs)
  2. For Non-Retrieval Purposes
from angle_emb import AnglE

angle = AnglE.from_pretrained('WhereIsAI/UAE-Large-V1', pooling_strategy='cls').cuda()
vec = angle.encode('hello world', to_numpy=True)
print(vec)
vecs = angle.encode(['hello world1', 'hello world2'], to_numpy=True)
print(vecs)
Difference between retrieval and non-retrieval sentence embeddings:

In UAE, we use different approaches for retrieval and non-retrieval tasks, each serving a different purpose.

Retrieval tasks aim to find relevant documents; as a result, the related documents may not be strictly semantically similar to the query.

For instance, when querying "How about ChatGPT?", the related documents are those that contain information related to "ChatGPT," such as "ChatGPT is amazing..." or "ChatGPT is bad...".

Conversely, non-retrieval tasks, such as semantic textual similarity, require sentences that are semantically similar.

For example, a sentence semantically similar to "How about ChatGPT?" could be "What is your opinion about ChatGPT?".

To distinguish between these two types of tasks, we use different prompts.

For retrieval tasks, we use the prompt "Represent this sentence for searching relevant passages: {text}" (Prompts.C in angle_emb).

For non-retrieval tasks, we set the prompt to empty, i.e., just input your text without specifying a prompt.

So, if your scenario is retrieval-related, it is highly recommended to set the prompt with angle.set_prompt(prompt=Prompts.C). If not, leave the prompt empty or use angle.set_prompt(prompt=None).
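
To make the distinction concrete, here is a minimal sketch (not from the original README) that ranks two candidate passages against a query. It assumes the retrieval prompt Prompts.C is applied to the query while the passages are encoded without a prompt; the texts and the cos_sim helper are purely illustrative.

import numpy as np
from angle_emb import AnglE, Prompts

angle = AnglE.from_pretrained('WhereIsAI/UAE-Large-V1', pooling_strategy='cls').cuda()

# encode the query with the retrieval prompt
angle.set_prompt(prompt=Prompts.C)
query_vec = angle.encode({'text': 'How about ChatGPT?'}, to_numpy=True)

# encode the candidate passages without a prompt
angle.set_prompt(prompt=None)
passage_vecs = angle.encode(['ChatGPT is amazing...', 'The weather is nice today.'], to_numpy=True)

def cos_sim(a, b):
    # cosine similarity between one query vector and a matrix of passage vectors
    a = a / np.linalg.norm(a, axis=-1, keepdims=True)
    b = b / np.linalg.norm(b, axis=-1, keepdims=True)
    return (a @ b.T).ravel()

print(cos_sim(query_vec, passage_vecs))  # higher score = more relevant passage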

Angle-LLaMA

  1. AnglE
from angle_emb import AnglE, Prompts

angle = AnglE.from_pretrained('NousResearch/Llama-2-7b-hf', pretrained_lora_path='SeanLee97/angle-llama-7b-nli-v2')

print('All predefined prompts:', Prompts.list_prompts())
angle.set_prompt(prompt=Prompts.A)
print('prompt:', angle.prompt)
vec = angle.encode({'text': 'hello world'}, to_numpy=True)
print(vec)
vecs = angle.encode([{'text': 'hello world1'}, {'text': 'hello world2'}], to_numpy=True)
print(vecs)
  2. transformers
from angle_emb import Prompts
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel, PeftConfig

peft_model_id = 'SeanLee97/angle-llama-7b-nli-v2'
config = PeftConfig.from_pretrained(peft_model_id)
tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path).bfloat16().cuda()
model = PeftModel.from_pretrained(model, peft_model_id).cuda()

def decorate_text(text: str):
    return Prompts.A.format(text=text)

inputs = 'hello world!'
tok = tokenizer([decorate_text(inputs)], return_tensors='pt')
for k, v in tok.items():
    tok[k] = v.cuda()
# last-token pooling: take the final layer's hidden state of the last token
vec = model(output_hidden_states=True, **tok).hidden_states[-1][:, -1].float().detach().cpu().numpy()
print(vec)

Angle-BERT

  1. AnglE
from angle_emb import AnglE

angle = AnglE.from_pretrained('SeanLee97/angle-bert-base-uncased-nli-en-v1', pooling_strategy='cls_avg').cuda()
vec = angle.encode('hello world', to_numpy=True)
print(vec)
vecs = angle.encode(['hello world1', 'hello world2'], to_numpy=True)
print(vecs)
  2. transformers
import torch
from transformers import AutoModel, AutoTokenizer

model_id = 'SeanLee97/angle-bert-base-uncased-nli-en-v1'
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModel.from_pretrained(model_id).cuda()

inputs = 'hello world!'
tok = tokenizer([inputs], return_tensors='pt')
for k, v in tok.items():
    tok[k] = v.cuda()
hidden_state = model(**tok).last_hidden_state
# cls_avg pooling: average the [CLS] embedding with the mean-pooled token embeddings
vec = (hidden_state[:, 0] + torch.mean(hidden_state, dim=1)) / 2.0
print(vec)
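
A common follow-up is to score sentence similarity with these embeddings. The sketch below is illustrative only: the embed helper, the padding mask, and the sentence pair are not from the original README (the mask simply keeps padded tokens out of the mean when batching several sentences).

import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer

model_id = 'SeanLee97/angle-bert-base-uncased-nli-en-v1'
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModel.from_pretrained(model_id).cuda()

def embed(texts):
    # cls_avg pooling, as in the snippet above, with padding masked out of the mean
    tok = tokenizer(texts, padding=True, return_tensors='pt')
    tok = {k: v.cuda() for k, v in tok.items()}
    with torch.no_grad():
        hidden_state = model(**tok).last_hidden_state
    mask = tok['attention_mask'].unsqueeze(-1).float()
    avg = (hidden_state * mask).sum(dim=1) / mask.sum(dim=1)
    return (hidden_state[:, 0] + avg) / 2.0

vecs = embed(['How about ChatGPT?', 'What is your opinion about ChatGPT?'])
print(F.cosine_similarity(vecs[0:1], vecs[1:2]))  # higher = more similar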

Train Custom AnglE Model

1. Train NLI

  1. Prepare your GPU environment

  2. Install Python dependencies

python -m pip install -r requirements.txt

  3. Download data
  • Download multi_nli + snli:
$ cd data
$ sh download_data.sh
  • Download sts datasets
$ cd SentEval/data/downstream
$ bash download_dataset.sh

2. Custom Train

Open In Colab

from datasets import load_dataset
from angle_emb import AnglE, AngleDataTokenizer


# 1. load pretrained model
angle = AnglE.from_pretrained('SeanLee97/angle-bert-base-uncased-nli-en-v1', max_length=128, pooling_strategy='cls').cuda()

# 2. load dataset
# `text1`, `text2`, and `label` are three required columns.
ds = load_dataset('mteb/stsbenchmark-sts')
ds = ds.map(lambda obj: {"text1": str(obj["sentence1"]), "text2": str(obj['sentence2']), "label": obj['score']})
ds = ds.select_columns(["text1", "text2", "label"])

# 3. transform data
train_ds = ds['train'].shuffle().map(AngleDataTokenizer(angle.tokenizer, angle.max_length), num_proc=8)
valid_ds = ds['validation'].map(AngleDataTokenizer(angle.tokenizer, angle.max_length), num_proc=8)
test_ds = ds['test'].map(AngleDataTokenizer(angle.tokenizer, angle.max_length), num_proc=8)

# 4. fit
angle.fit(
    train_ds=train_ds,
    valid_ds=valid_ds,
    output_dir='ckpts/sts-b',
    batch_size=32,
    epochs=5,
    learning_rate=2e-5,
    save_steps=100,
    eval_steps=1000,
    warmup_steps=0,
    gradient_accumulation_steps=1,
    loss_kwargs={
        # weights (w1/w2/w3) and temperatures for the combined AnglE training objective
        'w1': 1.0,
        'w2': 1.0,
        'w3': 1.0,
        'cosine_tau': 20,
        'ibn_tau': 20,
        'angle_tau': 1.0
    },
    fp16=True,
    logging_steps=100
)

# 5. evaluate
corrcoef, accuracy = angle.evaluate(test_ds, device=angle.device)
print('corrcoef:', corrcoef)
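
As a quick sanity check after fitting (not part of the original README), you can encode a test pair with the in-memory model and compare the cosine similarity against the gold STS-B score; this assumes the angle object reflects the fitted weights.

import numpy as np

sample = ds['test'][0]  # has `text1`, `text2`, and `label` columns (see step 2)
vecs = angle.encode([sample['text1'], sample['text2']], to_numpy=True)
a, b = vecs[0], vecs[1]
cos = float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))
print('predicted similarity:', cos, '| gold score:', sample['label'])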

Citation

You are welcome to use our code and pre-trained models. If you do, please support us by citing our work as follows:

@article{li2023angle,
  title={AnglE-optimized Text Embeddings},
  author={Li, Xianming and Li, Jing},
  journal={arXiv preprint arXiv:2309.12871},
  year={2023}
}

ChangeLogs

📅 Date | Description
2023 Dec 4 | Release a universal English sentence embedding model: WhereIsAI/UAE-Large-V1
2023 Nov 2 | Release an English pretrained model: SeanLee97/angle-llama-13b-nli
2023 Oct 28 | Release two Chinese pretrained models: SeanLee97/angle-roberta-wwm-base-zhnli-v1 and SeanLee97/angle-llama-7b-zhnli-v1; add Chinese README.md

More Repositories

1. xmnlp (Python, 1,227 stars): Chinese NLP toolkit providing word segmentation, POS tagging, named entity recognition, sentiment analysis, text correction, text-to-pinyin conversion, text summarization, radical lookup, sentence representation, and text similarity computation.
2. nlp_learning (Python, 234 stars): Learning natural language processing (NLP) with Python: language models, HMM, PCFG, Word2vec, cloze-style reading comprehension, naive Bayes classifier, TF-IDF, PCA, SVD.
3. QANet_dureader (Python, 223 stars): QANet + DuReader for Chinese machine reading comprehension.
4. TripleIE (Python, 96 stars): Automatic extraction of relation triples based on dependency parsing.
5. short-text-classification (Jupyter Notebook, 84 stars): Comparison of SVM, FastText, TextCNN, BiGRU, and CNN-BiGRU on short-text classification.
6. datastruct_and_algorithms (Python, 40 stars): Common algorithms (data structures, search, sorting, dynamic programming, ...) implemented in Python/C++.
7. clfzoo (Python, 36 stars): A deep text classifiers library.
8. llano (Python, 29 stars): Let ChatGPT (Large Language Models) serve as data annotator and zero-shot/few-shot information extractor.
9. nnclf (Python, 21 stars): Neural network classifiers implemented in PyTorch.
10. chinese_reading_comprehension (Python, 20 stars): Implementation of Attention-over-Attention Neural Networks for Reading Comprehension.
11. generate-lyrics-using-PyTorch (Python, 16 stars): Use an RNN to generate Chinese lyrics.
12. simnet (Python, 15 stars): A simple neural network framework implemented with NumPy.
13. duReader_pytorch (Python, 9 stars): Reading comprehension based on DuReader.
14. titanic_disaster (Jupyter Notebook, 7 stars): Hands-on data analysis for the Kaggle Titanic disaster competition; uses RandomForestRegressor to impute missing values and RandomForestClassifier for classification.
15. rnn-attention-classifier (Python, 6 stars): RNN + Attention text classification implemented in TensorFlow.
16. LanguageDetect.jl (Julia, 5 stars): Port of Google's language-detection library to Julia.
17. LL1-Parser (C++, 4 stars): A simple LL1 compiler.
18. datacleaner (Python, 4 stars): datacleaner, data cleaning in Python.
19. xiaoming (Python, 4 stars): A seq2seq + attention chatbot.
20. simple_svm (Python, 4 stars): SVM implemented in Python.
21. BertWordPieceTokenizer.jl (Julia, 3 stars): WordPiece Tokenizer for BERT models.
22. artf (Python, 3 stars): A lightweight TensorFlow library.
23. bimpm (Python, 2 stars): Implementation of BiMPM (Bilateral Multi-Perspective Matching for Natural Language Sentences).
24. SeanLee97 (1 star): Config files for my GitHub profile.
25. my_profile (1 star)
26. module-weather (Python, 1 star): Small module: fetches weather data from http://weather.com.cn/ with Python.
27. xmnlp-extend (1 star): Extension of xmnlp, supporting DeepNLP-based dependency parsing, named entity recognition, and more.
28. word2vec-test (Python, 1 star): Usage of word2vec.