

Fast inference from transformers via speculative decoding

This repository implements speculative sampling for large language model (LLM) decoding. It uses two models during decoding: a small approximation model and a large target model. The approximation model drafts token guesses, and the target model corrects them by scoring all of the drafted tokens in a single parallel forward pass. Because each target-model pass can accept several drafted tokens at once, this is more efficient than decoding with the target model alone.

Speculative sampling was proposed independently by Google and DeepMind, so this repository implements two slightly different versions: Google's and DeepMind's.
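The core accept/reject loop that both versions share looks roughly like the sketch below. It follows the procedure described in the two papers, assuming Hugging Face-style causal LMs whose forward pass returns logits; the function name speculative_step and its brute-force re-scoring structure are mine for illustration, and it omits the KV cache optimization (see the repository source for the real implementation).

import torch

@torch.no_grad()
def speculative_step(prefix, approx_model, target_model, gamma=4):
    """One speculative-decoding step: draft gamma tokens with the small
    model, then verify them with a single parallel target-model pass."""
    x = prefix  # (1, seq) token ids
    n = prefix.shape[1]

    # 1. Draft gamma tokens autoregressively with the approximation model.
    for _ in range(gamma):
        q_logits = approx_model(x).logits            # (1, seq, vocab)
        probs = torch.softmax(q_logits[:, -1], dim=-1)
        x = torch.cat([x, torch.multinomial(probs, 1)], dim=1)

    # 2. Score every drafted position with one parallel pass of each model
    #    (re-running the approximation model here keeps the sketch simple).
    p_logits = target_model(x).logits
    q_logits = approx_model(x).logits

    # 3. Accept drafted token j with probability min(1, p(j)/q(j)); on the
    #    first rejection, resample from the normalized residual (p - q)+.
    for i in range(gamma):
        j = x[0, n + i]
        p_i = torch.softmax(p_logits[0, n + i - 1], dim=-1)
        q_i = torch.softmax(q_logits[0, n + i - 1], dim=-1)
        if torch.rand(1).item() > (p_i[j] / q_i[j]).item():
            residual = torch.clamp(p_i - q_i, min=0.0)
            j_new = torch.multinomial(residual / residual.sum(), 1)
            return torch.cat([x[:, : n + i], j_new.view(1, 1)], dim=1)

    # 4. All gamma drafts accepted: take one bonus token from the target.
    bonus = torch.multinomial(torch.softmax(p_logits[0, -1], dim=-1), 1)
    return torch.cat([x, bonus.view(1, 1)], dim=1)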

Update Logs

  • 2023.09.21: Add serving features; support more models, e.g. llama-7B and llama-1B.

  • 2023.09.19: Add KV cache optimization to Google's version (see the sketch after this list).

  • 2023.08.16: First release; implement the paper's algorithm. Support Bloom-560M and Bloomz-7B1.
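For reference, the KV cache optimization follows the standard idea: cache each layer's attention keys and values so that a decoding step only feeds the newest token through the model instead of the whole prefix. Below is a minimal, generic illustration using Hugging Face's past_key_values; it shows the underlying technique, not the repo's own implementation.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m")
model.eval()

ids = tok("The quick brown fox", return_tensors="pt").input_ids
past = None
with torch.no_grad():
    for _ in range(16):
        # With a warm cache, only the newest token is fed forward.
        inp = ids if past is None else ids[:, -1:]
        out = model(inp, past_key_values=past, use_cache=True)
        past = out.past_key_values
        next_id = out.logits[:, -1].argmax(dim=-1, keepdim=True)
        ids = torch.cat([ids, next_id], dim=1)
print(tok.decode(ids[0]))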

Usage

Inference

You need to prepare a pair of models that share the same embedding and vocabulary. The approximation model should be smaller than the target model. Some tested model pairs are shown below.
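A quick sanity check for a candidate pair is to confirm that the two tokenizers agree. This snippet is illustrative and uses the pair from the sample below:

from transformers import AutoTokenizer

approx_tok = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
target_tok = AutoTokenizer.from_pretrained("bigscience/bloomz-7b1")

# Speculative sampling compares token-level probabilities, so the two
# models must index the same vocabulary.
assert approx_tok.get_vocab() == target_tok.get_vocab(), "vocabularies differ"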

The sample below demonstrates bloomz-7b1 as the target model and bloom-560m as the approximation model.

python main.py \
    --input "The quick brown fox jumps over the lazy " \
    --target_model_name bigscience/bloomz-7b1 \
    --approx_model_name bigscience/bloom-560m

You can also pass the -v flag to see which model generated each token.


I recommend using llama2-7B and llama2-70B as the approximation and target model respectively; I observed a speedup in this case, as shown below. Note that the choice of approximation and target model is essential for the speedup, and no speedup will be observed in the following cases:

  • Both models are small: the speed difference between them is not significant.

  • The size difference between the models is too large: the target model rejects more guesses, causing frequent resampling.

The sampling logic itself is also not efficient enough yet; I noticed substantial overhead in Softmax and LayerNorm and will try to optimize it in the future. Do not hesitate to open an issue with ideas on performance improvements.

                     llama2-7b   llama2-70b   Speculative
speed (tokens/sec)     1084.86       329.83        427.02
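For a back-of-the-envelope view of why the pairing matters: Leviathan et al. (2023) show that with acceptance rate alpha and gamma drafted tokens per step, each target-model pass yields (1 - alpha^(gamma+1)) / (1 - alpha) accepted tokens in expectation. The alpha values below are invented purely for illustration:

def expected_tokens_per_target_pass(alpha: float, gamma: int) -> float:
    # E = (1 - alpha**(gamma + 1)) / (1 - alpha), Leviathan et al. (2023).
    return (1 - alpha ** (gamma + 1)) / (1 - alpha)

for alpha in (0.5, 0.7, 0.9):
    # A poorly matched pair (low alpha) amortizes far fewer tokens per
    # expensive target-model pass than a well matched one.
    print(f"alpha={alpha}: {expected_tokens_per_target_pass(alpha, gamma=4):.2f} tokens/pass")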

Serving

Start an inference server.

python serving.py

Test the serving with curl:

curl -X POST -H "Content-Type: application/json" -d '{"prompt": "Who is the president of the USA"}' http://127.0.0.1:5000/predict
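Equivalently, from Python (a minimal sketch; the exact shape of the JSON response depends on serving.py):

import requests

resp = requests.post(
    "http://127.0.0.1:5000/predict",
    json={"prompt": "Who is the president of the USA"},
    timeout=60,
)
resp.raise_for_status()
print(resp.json())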

References

@inproceedings{leviathan2023fast,
  title={Fast inference from transformers via speculative decoding},
  author={Leviathan, Yaniv and Kalman, Matan and Matias, Yossi},
  booktitle={International Conference on Machine Learning},
  pages={19274--19286},
  year={2023},
  organization={PMLR}
}

@article{chen2023accelerating,
  title={Accelerating large language model decoding with speculative sampling},
  author={Chen, Charlie and Borgeaud, Sebastian and Irving, Geoffrey and Lespiau, Jean-Baptiste and Sifre, Laurent and Jumper, John},
  journal={arXiv preprint arXiv:2302.01318},
  year={2023}
}

Limitations

Currently, only requests with batch size 1 are supported. Since this repo is built for demonstration purposes, other optimizations that are essential for efficiency, such as batching and parallelism, are not included.

More Repositories

  1. SWCaffe (C++, 39 stars): A deep learning framework customized for Sunway TaihuLight.

  2. Distributed-ResNet-Tensorflow (Python, 20 stars): A distributed ResNet on multiple machines, each with one GPU card.

  3. swGEMM (C, 14 stars): A highly efficient library for GEMM operations on Sunway TaihuLight.

  4. swDNN (Roff, 14 stars): A highly efficient library for deep neural networks on the Sunway TaihuLight supercomputer.

  5. PSTensor (C++, 9 stars): Provides a way to hack the memory management of tensors in TensorFlow and PyTorch by defining your own C++ Tensor class.

  6. PyTorchMemTracer (Python, 9 stars): Depict the GPU memory footprint during DNN training in PyTorch.

  7. ChituAttention (Python, 8 stars): Quantized attention on GPU.

  8. ColoBloom (Python, 5 stars)

  9. intel-baidu-allreduce (C++, 5 stars)

  10. DeepSpeedZeRO3Benchmark (Python, 5 stars): Fine-tuned benchmark scripts for DeepSpeed ZeRO stage 3.

  11. swDNNv1.0 (C, 4 stars): A deep learning library for Sunway TaihuLight.

  12. crack_leetcode (C++, 4 stars): Five days of problem solving, three days of mock practice! Quickly master LeetCode solution patterns!

  13. ssh-passwd-free (Shell, 3 stars): A method to set up password-free SSH for a set of IPs.

  14. TensorrtBenchmark (C++, 3 stars): Benchmark BERT using TensorRT.

  15. SMO-SVM (Perl, 3 stars): A Python implementation of libsvm.

  16. cudaMemHook (C++, 3 stars)

  17. horovod-resnet (Python, 3 stars)

  18. Communication-Efficient-DNN (Python, 3 stars)

  19. DiTKVAnalysis (Python, 2 stars): An auxiliary analysis of the characteristics of KV in DiT attention.

  20. 89757 (Python, 2 stars)

  21. DeepGlobe (Python, 2 stars)

  22. DTensor (Python, 2 stars): A study of PyTorch DTensor.

  23. MoE-Megatron-LM (Python, 2 stars)

  24. large-scale-tensorflow-benchmark (Jupyter Notebook, 2 stars): Benchmark TensorFlow on supercomputers.

  25. ProjectRun (1 star)

  26. CommTest (Python, 1 star): Tests for PyTorch async collective communication.

  27. ColossalAI_bert_inference (Python, 1 star)

  28. ckp_training (Python, 1 star)

  29. ADMM-NeuralNetwork (MATLAB, 1 star): An ADMM-based neural network, implemented by a potato.