
Painless Inference Acceleration (PIA)

A toolkit for LLM inference without 😭. It currently contains our work LOOKAHEAD, a framework that accelerates LLM inference without loss of accuracy; other works will be released soon.

News and Updates 🔥

  • [2024/01] We support all models of the Baichuan family (Baichuan-7b & 13b, Baichuan2-7b & 13b).

  • [2024/01] We fully support the repetition_penalty parameter.

  • [2024/01] We support Mistral & Mixtral (see the example).

  • [2023/12] We released our Lookahead paper on arXiv!

  • [2023/12] PIA released 💪!!! Fast, Faster, Fastest 🏆!!!

Models we support

  • GLM
  • Baichuan & Baichuan 2
  • BLOOM
  • ChatGLM 2 & 3
  • GPT-2
  • GPT-J
  • InternLM
  • LLaMA & LLaMA-2
  • Mistral
  • Mixtral
  • OPT
  • Qwen

Known issues & TODO

ISSUE 1. The repetition_penalty parameter is not fully supported; we will fix it in the future.

ISSUE 2. Lookahead may generate responses that differ from the original ones when low-precision data types (e.g., fp16 or bf16) are used; with fp32 the responses are the same.

ISSUE 3. The Baichuan tokenizer cannot be initialized with the latest version of transformers (4.30.2 works).

ISSUE 4. Qwen model may generate slightly different responses with lookahead when the repetition_penalty parameter is set.
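ISSUE 2 fits the usual behavior of floating-point rounding. The sketch below only illustrates that assumption and does not diagnose PIA's kernels. In half precision, addition is not associative. The same logits can therefore round differently when they are computed in a different accumulation order, which may happen when one forward pass verifies many draft tokens, and that can occasionally change the argmax. The sketch emulates fp16 rounding in plain Python with the standard `struct` 'e' (IEEE 754 half-precision) format. `to_fp16` and `fp16_sum` are hypothetical helpers.

```python
import struct
from typing import Iterable


def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("e", struct.pack("e", x))[0]


def fp16_sum(values: Iterable[float]) -> float:
    """Sum left to right, rounding to fp16 after every addition."""
    total = to_fp16(0.0)
    for v in values:
        total = to_fp16(total + to_fp16(v))
    return total


# The same three numbers, accumulated in two different orders.
left_first = fp16_sum([2048.0, 1.0, 1.0])   # (2048 + 1) + 1
right_first = fp16_sum([1.0, 1.0, 2048.0])  # (1 + 1) + 2048
print(left_first, right_first)              # 2048.0 2050.0
```

Near 2048 the gap between adjacent fp16 values is 2. As a result, 2048 + 1 rounds back to 2048 (ties go to even), while adding the two 1s first gives exactly 2050. Python's own float (fp64) gives 2050.0 for both orders.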

TODO1: Support the latest version of 🤗 transformers. Currently it's based on 4.30.2.

TODO2: Integrate our work FastCoT

TODO3: Optimize batch inference implementation with flash-attention.

Performance Comparison

Performance is measured as the throughput of generated tokens, in tokens per second (token/s).
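Below is a minimal timing harness for this metric. It assumes that only newly generated tokens are counted and that prompt tokens are excluded. `generated_tokens_per_second` is a hypothetical helper that accepts any callable returning the new token ids for one prompt. It is not the repository's benchmark code.

```python
import time
from typing import Callable, Sequence


def generated_tokens_per_second(generate: Callable[[str], Sequence[int]],
                                prompts: Sequence[str]) -> float:
    """Throughput over newly generated tokens only (prompt tokens are not counted).

    `generate` is any callable that returns the *new* token ids for one prompt.
    """
    total_new_tokens = 0
    start = time.perf_counter()
    for prompt in prompts:
        total_new_tokens += len(generate(prompt))
    elapsed = time.perf_counter() - start
    return total_new_tokens / elapsed
```

In 🤗 transformers, `generate` for decoder-only models returns the prompt followed by the continuation. If PIA's `generate` behaves the same way (an assumption), the Quick Start's new tokens would be `output_ids[0][inputs.input_ids.shape[1]:]`.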

Public datasets and models

We use the first 1000 samples for evaluation and the rest for trie-tree cache construction. The hyper-parameters are decoding_length=64 and branch_length=8. The tag (fused) indicates that operators are fused with Triton; the implementation can be found in modeling_llama_batch.py.
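The evaluation protocol can be sketched as follows. This is a minimal illustration, not the benchmark code. `split_for_lookahead_eval` is a hypothetical helper. Passing `decoding_length` and `branch_length` through the `decoding_kwargs` argument of `generate` is an assumption, so check it against the scripts in benchmarks/.

```python
from typing import List, Sequence, Tuple


def split_for_lookahead_eval(samples: Sequence[str],
                             n_eval: int = 1000) -> Tuple[List[str], List[str]]:
    """First n_eval samples are evaluated; the remainder warms up the trie-tree cache.

    Hypothetical helper: the cache corpus would be inserted into the cache
    before generation on the evaluation samples is timed.
    """
    eval_set = list(samples[:n_eval])
    cache_corpus = list(samples[n_eval:])
    return eval_set, cache_corpus


# ASSUMPTION: these keys are forwarded via `decoding_kwargs`; verify the exact
# names in the repository's benchmark scripts before relying on them.
public_decoding_kwargs = {"use_lookahead": True, "decoding_length": 64, "branch_length": 8}
# Values from the "Private datasets and models" settings.
private_decoding_kwargs = {"use_lookahead": True, "decoding_length": 128, "branch_length": 32}
```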

model                    dataset    GPU       🤗 transformers (token/s)  Lookahead (token/s)
Llama2-7b-chat           Dolly-15k  A100-80G  40.6                       83.7 (x2.06)
Llama2-7b-chat (fused)   Dolly-15k  A100-80G  50.4                       106.8 (x2.12)
Llama2-13b-chat          Dolly-15k  A100-80G  34.0                       71.7 (x2.11)
Llama2-13b-chat (fused)  Dolly-15k  A100-80G  39.9                       84.6 (x2.12)
ChatGLM2-6b              Dolly-15k  A100-80G  45.6                       108.4 (x2.38)
Llama2-7b-chat           GSM-8k     A100-80G  41.4                       111.3 (x2.69)
Llama2-7b-chat (fused)   GSM-8k     A100-80G  53.7                       149.6 (x2.79)
Llama2-13b-chat          GSM-8k     A100-80G  31.2                       71.1 (x2.28)
Llama2-13b-chat (fused)  GSM-8k     A100-80G  42.9                       103.4 (x2.41)
ChatGLM2-6b              GSM-8k     A100-80G  43.3                       94.0 (x2.17)

We test 5 examples with Llama2-7b-chat on the Dolly dataset. Inference time without lookahead (the left figure) is 15.7 s (48.2 token/s), while inference time with lookahead is 6.4 s (112.9 token/s), a 2.34x speedup in throughput.
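The speedups quoted in this section are throughput ratios. The arithmetic below reproduces them from the numbers above. `speedup` and `generated_tokens` are hypothetical helpers.

```python
def speedup(baseline_tps: float, lookahead_tps: float) -> float:
    """Throughput ratio (token/s with lookahead over token/s without), 2 decimals."""
    return round(lookahead_tps / baseline_tps, 2)


def generated_tokens(seconds: float, tokens_per_second: float) -> float:
    """Approximate number of generated tokens implied by a time and a throughput."""
    return seconds * tokens_per_second


print(speedup(48.2, 112.9))   # 2.34, Llama2-7b-chat on 5 Dolly examples
print(speedup(40.6, 83.7))    # 2.06, first row of the table above
print(round(15.7 / 6.4, 2))   # 2.45, wall-clock ratio for the 5-example run
print(round(generated_tokens(15.7, 48.2)), round(generated_tokens(6.4, 112.9)))  # 757 723
```

The wall-clock ratio (about 2.45) differs from the quoted 2.34. The implied token counts (about 757 vs. 723) suggest that the two runs generated slightly different numbers of tokens, which would account for the gap.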

Private datasets and models

We use the first 1000 samples for evaluation and the rest for trie-tree cache construction. The hyper-parameters are decoding_length=128 and branch_length=32.

Our method can achieve significant acceleration in RAG (Retrieval-Augmented Generation) scenarios. However, no real-life RAG datasets are publicly available yet, so we evaluate only on our private datasets and models. AntGLM-10B is an LLM developed by Ant Group with the GLM architecture.

model       scenario            GPU       🤗 transformers (token/s)  Lookahead (token/s)
AntGLM-10b  Citizen Biz Agent   A100-80G  52.4                       280.9 (x5.36)
AntGLM-10b  Enterprise Info QA  A100-80G  50.7                       259.1 (x5.11)
AntGLM-10b  Health Suggestion   A100-80G  51.6                       240.2 (x4.66)

We test 5 examples with AntGLM-10B on the AntRag dataset. Inference time without lookahead (the left figure) is 16.9 s (33.8 token/s), while inference time with lookahead is 3.9 s (147.6 token/s), a 4.37x speedup in throughput.

Introduction

Our repo PIA (short for Painless Inference Acceleration) is a toolkit for LLM inference built on the 🤗 transformers library.

  • It uses an on-the-fly trie-tree cache to prepare hierarchical multi-branch drafts, without requiring assistant models (as in speculative decoding) or additional head training (as in block decoding). Thanks to the efficient hierarchical structure, it can look ahead tens of branches at once, which significantly increases the number of tokens generated per forward pass.

  • You can also benefit from our optimized fused operator kernels.
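The trie-tree cache in the first bullet can be pictured as a token-level trie. This toy sketch is built on assumptions. Every window of recent tokens from prompts and outputs is inserted. A draft is retrieved by matching the latest tokens and walking the subtree, with frequent continuations visited first. At most `decoding_length` tree nodes are used, and each branch holds at most `branch_length` tokens. `TrieCache`, `put` and `get` are hypothetical names. They are not the API of `pia.lookahead.common.lookahead_cache.LookaheadCache`.

```python
from typing import Dict, List, Sequence


class TrieNode:
    def __init__(self) -> None:
        self.children: Dict[int, "TrieNode"] = {}
        self.freq = 0  # number of inserted windows passing through this node


class TrieCache:
    """Toy trie-tree cache over token ids (hypothetical, for illustration only)."""

    def __init__(self, query_length: int = 2, branch_length: int = 8) -> None:
        self.root = TrieNode()
        self.query_length = query_length
        self.branch_length = branch_length

    def put(self, tokens: Sequence[int]) -> None:
        """Insert every window of query_length + branch_length tokens (on the fly)."""
        span = self.query_length + self.branch_length
        for start in range(len(tokens)):
            node = self.root
            for tok in tokens[start:start + span]:
                node = node.children.setdefault(tok, TrieNode())
                node.freq += 1

    def get(self, query: Sequence[int], decoding_length: int = 64) -> List[List[int]]:
        """Return draft branches that followed `query`, using at most
        `decoding_length` tree nodes and `branch_length` tokens per branch."""
        node = self.root
        for tok in query:
            node = node.children.get(tok)
            if node is None:
                return []  # no draft: fall back to ordinary decoding

        branches: List[List[int]] = []
        budget = decoding_length

        def dfs(cur: TrieNode, path: List[int]) -> None:
            nonlocal budget
            expanded = False
            if len(path) < self.branch_length:
                # Frequent continuations first; ties keep insertion order (stable sort).
                for tok, child in sorted(cur.children.items(), key=lambda kv: -kv[1].freq):
                    if budget == 0:
                        break
                    budget -= 1
                    expanded = True
                    path.append(tok)
                    dfs(child, path)
                    path.pop()
            if not expanded and path:
                branches.append(list(path))  # a leaf of the explored subtree

        dfs(node, [])
        return branches
```

Shared prefixes are counted once against the budget. This is why a tree of drafts can cover more candidate continuations than the same number of tokens spent on independent sequences.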

Note that our work is different from another method also named lookahead decoding.

Hierarchical multi-branch draft

[Figures: "flow" and "dynamic" illustrations of the hierarchical multi-branch draft]
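This is a minimal sketch of how a hierarchical multi-branch draft could be checked in one forward pass, assuming greedy decoding. Branches that share a prefix are merged into one tree, and each draft token attends only to itself and its ancestors. The longest path that matches the model's own argmax predictions is then accepted. In exact arithmetic this reproduces plain greedy decoding. The function names are hypothetical, the model's predictions are passed in as plain lists, and this is not PIA's implementation.

```python
from typing import Dict, List, Optional, Sequence, Tuple


def build_tree(branches: Sequence[Sequence[int]]) -> Tuple[List[int], List[int]]:
    """Merge branches with shared prefixes into flat tokens plus parent indices.

    parents[i] == -1 means node i directly follows the last context token.
    """
    tokens: List[int] = []
    parents: List[int] = []
    index: Dict[Tuple[int, ...], int] = {}  # path prefix -> node index
    for branch in branches:
        parent = -1
        for depth in range(len(branch)):
            key = tuple(branch[:depth + 1])
            if key not in index:
                index[key] = len(tokens)
                tokens.append(branch[depth])
                parents.append(parent)
            parent = index[key]
    return tokens, parents


def tree_attention_mask(parents: Sequence[int]) -> List[List[int]]:
    """mask[i][j] == 1 iff draft node j is node i itself or one of its ancestors."""
    n = len(parents)
    mask = [[0] * n for _ in range(n)]
    for i in range(n):
        j = i
        while j != -1:
            mask[i][j] = 1
            j = parents[j]
    return mask


def verify_greedy(tokens: Sequence[int], parents: Sequence[int],
                  next_after_context: int, next_after_node: Sequence[int]) -> List[int]:
    """Accept the longest draft path that greedy decoding would have produced.

    next_after_context: argmax token after the last context token.
    next_after_node[i]: argmax token after draft node i (from the same forward pass).
    Returns the accepted draft tokens plus one token predicted by the model.
    """
    accepted: List[int] = []
    expected = next_after_context
    current = -1  # -1 stands for the last context token
    while True:
        match: Optional[int] = next(
            (i for i in range(len(tokens)) if parents[i] == current and tokens[i] == expected),
            None,
        )
        if match is None:
            return accepted + [expected]  # the model's own prediction is always kept
        accepted.append(tokens[match])
        current = match
        expected = next_after_node[match]
```

For example, the branches [[2, 3, 4], [2, 5]] become the tokens [2, 3, 4, 5] with parents [-1, 0, 1, 0]. If the model predicts 2, then 5 after node 0, then 7 after node 3, the pass emits [2, 5, 7]. Even when the whole draft is rejected, one token is still produced, so a forward pass never yields fewer tokens than plain decoding.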

License (Terms of Use)

The license is CC BY 4.0 (https://creativecommons.org/licenses/by/4.0/).

Before using this project, please read LICENSE.txt. If you do not agree to the terms, legal disclaimers, and license set out in it, you may not use the contents of this project.

Installation

  1. Clone this repository and navigate to PainlessInferenceAcceleration:
git clone https://github.com/alipay/PainlessInferenceAcceleration.git
cd PainlessInferenceAcceleration
  2. Install the package:
python setup.py install

Quick Start

Below is a minimal example of inference with lookahead:

import torch
from transformers import AutoTokenizer


from pia.lookahead.common.lookahead_cache import LookaheadCache
from pia.lookahead.models.llama.modeling_llama import LlamaForCausalLM

model_dir = 'meta-llama/Llama-2-7b-chat-hf'
model = LlamaForCausalLM.from_pretrained(model_dir
                                         , cache_dir='./'
                                         , torch_dtype=torch.float16
                                         , low_cpu_mem_usage=True
                                         , device_map='auto'
                                         )
tokenizer = AutoTokenizer.from_pretrained(model_dir)

prompt = "Hello, I am conscious and"
inputs = tokenizer(prompt, return_tensors="pt")

output_ids = model.generate(input_ids=inputs.input_ids.cuda(),
                            attention_mask=inputs.attention_mask.cuda(),
                            max_new_tokens=256,
                            decoding_kwargs={'use_lookahead': True}
                            )
response = tokenizer.decode(output_ids[0].tolist())
print(f'{response=}')

To use lookahead with other models, run the scripts in examples/. Every supported model has an example there, which can also be used for correctness evaluation.

python [model name]_example.py

To evaluate the speedup of lookahead, run the scripts in benchmarks/.

Customize Model

To support a custom model, usually only a few lines need to be added. Here is an example for Llama:

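from typing import List, Optional, Tuple, Union

import torch
from transformers.modeling_outputs import BaseModelOutputWithPast

from pia.lookahead.common.pretrained_model import LookaheadPreTrainedModel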
class LlamaPreTrainedModel(LookaheadPreTrainedModel):
    '''
    other code
    '''

class LlamaModel(LlamaPreTrainedModel):

    '''
    other code
    '''

    def forward(
            self,
            input_ids: torch.LongTensor = None,
            attention_mask: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.LongTensor] = None,
            past_key_values: Optional[List[torch.FloatTensor]] = None,
            inputs_embeds: Optional[torch.FloatTensor] = None,
            use_cache: Optional[bool] = None,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:

        '''
        other code
        '''

        """
        NOTE: adaptation for lookahead.
        Lookahead always passes a rank-4 attention_mask, so a minimal adaptation can branch on the mask's rank.
        With lookahead: derive position_ids from the attention mask and set the masked (zero) elements to the dtype's minimum value (effectively -inf).
        """
        if attention_mask is not None and len(attention_mask.shape) == 4:
            # with lookahead
            position_ids = torch.sum(attention_mask, dim=-1).squeeze(1) - 1
            attention_mask = (1.0-attention_mask.to(inputs_embeds.dtype)) * torch.finfo(inputs_embeds.dtype).min
        else:
            # without lookahead, reuse the original code lines
            if position_ids is None:
                device = input_ids.device if input_ids is not None else inputs_embeds.device
                position_ids = torch.arange(
                    past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
                )
                position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
            else:
                position_ids = position_ids.view(-1, seq_length).long()

            if attention_mask is None:
                attention_mask = torch.ones(
                    (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
                )
            attention_mask = self._prepare_decoder_attention_mask(
                attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
            )

Note that the above adaptation cannot be used for batch inference, because different samples may generate different numbers of tokens per step. The adaptation for batch inference can be found in models/modeling_glm_batch.py or models/modeling_llama_batch.py. Flash-attention-enhanced batch inference is under development.
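The two tensor lines in the lookahead branch of `forward` can be checked on a small case. Below is a plain-Python mirror of them for one sample, so no torch is required. The helper names are hypothetical. The sketch assumes that each mask row spans the past tokens followed by the current draft tokens, which is what makes `sum - 1` an absolute position. The example uses 2 past tokens and a 4-node draft tree in which nodes 1 and 3 are siblings.

```python
from typing import List, Sequence

# torch.finfo(torch.float16).min, hard-coded so the sketch needs no torch install.
FP16_MIN = -65504.0


def lookahead_position_ids(mask: Sequence[Sequence[int]]) -> List[int]:
    """Mirror of `torch.sum(attention_mask, dim=-1) - 1` for one sample and head.

    Each row is the 0/1 mask of one input token over [past tokens + draft tokens];
    the number of visible tokens minus one is that token's position.
    """
    return [sum(row) - 1 for row in mask]


def additive_mask(mask: Sequence[Sequence[int]], min_value: float = FP16_MIN) -> List[List[float]]:
    """Mirror of `(1.0 - mask) * finfo.min`: (-)0.0 where visible, min_value where hidden."""
    return [[(1.0 - m) * min_value for m in row] for row in mask]


# 2 past tokens, then draft nodes 0..3 with parents [-1, 0, 1, 0].
mask = [
    [1, 1, 1, 0, 0, 0],
    [1, 1, 1, 1, 0, 0],
    [1, 1, 1, 1, 1, 0],
    [1, 1, 1, 0, 0, 1],
]
print(lookahead_position_ids(mask))  # [2, 3, 4, 3]
```

The sibling nodes 1 and 3 both get position 3, which is the position each would have in its own sequential decode. Visible entries become -0.0 and add nothing to the attention scores.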

Tests

Tests can be run with:

cd pia/lookahead
pytest tests/ -s

Citations

@misc{zhao2023lookahead,
  title={Lookahead: An Inference Acceleration Framework for Large Language Model with Lossless Generation Accuracy},
  author={Yao Zhao and Zhitian Xie and Chenyi Zhuang and Jinjie Gu},
  year={2023},
  eprint={2312.12728},
  archivePrefix={arXiv},
  primaryClass={cs.IR}
}
