Marlin

This is Marlin, a Mixed Auto-Regressive Linear kernel (and the name of one of the planet's fastest fish), an extremely optimized FP16xINT4 matmul kernel aimed at LLM inference that can deliver close to ideal (4x) speedups up to batchsizes of 16-32 tokens (in contrast to the 1-2 tokens of prior work with comparable speedup). This makes Marlin well suited for larger-scale serving, speculative decoding or advanced multi-inference schemes such as CoT-Majority.

Techniques:

Most modern GPUs feature FLOP-to-byte ratios of around 100-200. Hence, as long as we perform fewer than 25-50 (tensor core) multiply-accumulates per 4-bit quantized weight, it should (theoretically) be possible to maintain a near-ideal 4x speedup over FP16 weights. This means that the full performance benefits of weight-only quantization should, in principle, extend to batchsizes 4-8x larger than what is currently achieved by existing kernels. However, actually realizing this in practice is very challenging, since we essentially need to fully utilize all available GPU resources (global memory, L2 cache, shared memory, tensor cores, vector cores) simultaneously.
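
As a rough back-of-the-envelope version of this argument, consider the sketch below; the hardware numbers are illustrative assumptions (roughly A10-class, about 125 FP16 tensor-core TFLOP/s and 600 GB/s of memory bandwidth), not measurements, and the exact ratio varies by device:

# Illustrative roofline estimate; the hardware numbers are assumptions (~A10-class).
tensor_core_flops = 125e12  # FP16 tensor-core throughput, FLOP/s
mem_bandwidth = 600e9       # global memory bandwidth, B/s

flop_per_byte = tensor_core_flops / mem_bandwidth  # ~208 FLOP per byte loaded
# A 4-bit weight occupies half a byte; one multiply-accumulate costs 2 FLOPs.
macs_per_weight = flop_per_byte * 0.5 / 2          # ~52 MACs per 4-bit weight
print(f"weight loading stays the bottleneck up to ~{macs_per_weight:.0f} MACs per 4-bit weight")

Since each weight takes part in one multiply-accumulate per token in the batch, this MAC budget translates directly into the batchsize range over which near-ideal speedups remain possible. Marlin accomplishes this through numerous techniques and optimizations, briefly sketched below: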

  • We organize computation in such a way that all activations are essentially always fetched from L2 cache and are further reused several times within registers to make sure that repeated loading from shared memory does not become a bottleneck either.
  • We execute global weight loads asynchronously with respect to all compute operations as well as activation loads, using a cache policy that allows immediate eviction, in order not to unnecessarily pollute the L2 cache with values that are never reused.
  • We perform shared memory loads, whose footprint is quite significant due to relatively large activations, via double buffering to overlap them with computation and global loads.
  • We carefully order dequantization and tensor core instructions to ensure that both GPU pipelines are well saturated and do not bottleneck each other.
  • In general, both quantized weights and group scales are reshuffled offline into a layout that gives ideal access patterns during execution, allowing, for instance, weights to be dequantized directly into tensor core organization.
  • We have multiple warps in a threadblock compute partial results of the same output tile, in order to achieve higher warp counts, maximizing compute and latency hiding, without increasing the output tile size, which would make good partitioning on realistic matrices difficult.
  • All loads use maximum vector length for peak efficiency and we also perform several layout transformations to guarantee that all shared memory reads and writes are conflict-free, in particular for matrix loading instructions, and that global reduction happens at minimal memory overhead.
  • We set up and unroll loops such that the majority of memory offsets are static, minimizing runtime index calculations.
  • We implement a "striped" partitioning scheme where the segment of tiles processed by each SM may (partially) span over multiple column "slices". This leads to good SM utilization on most matrix shapes, while minimizing required global reduction steps (a conceptual sketch of this work assignment follows this list).
  • Global reduction happens directly in the output buffer (temporarily downcasting FP32 accumulators to FP16) which is kept in L2 cache; reduction operations are generally optimized to avoid any unnecessary reads or writes as well.
  • Overall, the kernel's PTX assembly was extensively analyzed in NSight-Compute, and the CUDA code contains several constructions that look redundant or slightly suboptimal but compile to faster PTX.
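
To build intuition for the striped partitioning mentioned above, the following plain-Python sketch illustrates one possible flavor of such a work assignment; it is a conceptual illustration only (the slice-major ordering, names and SM count are simplifying assumptions, not the kernel's actual scheduling code):

def striped_partition(k_tiles, n_slices, num_sms):
    # Conceptual sketch: order work items slice-major (all k-tiles of column
    # slice 0, then slice 1, ...) and hand each SM one contiguous stripe.
    # A stripe may span a slice boundary, which keeps SM utilization high on
    # awkward matrix shapes; any slice whose k-tiles end up split across
    # several SMs requires a global reduction of the partial results.
    total = k_tiles * n_slices
    base, rem = divmod(total, num_sms)
    stripes, start = [], 0
    for sm in range(num_sms):
        length = base + (1 if sm < rem else 0)
        stripes.append([(t // k_tiles, t % k_tiles)  # (slice, k-tile) pairs
                        for t in range(start, start + length)])
        start += length
    return stripes

# Example: 16 k-tiles per column slice, 10 slices, 72 SMs (an A10-like count).
for sm, stripe in enumerate(striped_partition(16, 10, 72)[:6]):
    print(f"SM {sm}: {stripe}")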

Benchmarks:

We first compare the performance of Marlin with other popular 4-bit inference kernels, on a large matrix that can be ideally partitioned on an NVIDIA A10 GPU. This allows all kernels to reach pretty much their best possible performance. All kernels are executed at groupsize 128 (however, we note that scale formats are not 100% identical).

While existing kernels achieve relatively close to the optimal 3.87x (note the 0.125 bits storage overhead of the group scales) speedup at batchsize 1, their performance degrades quickly as the number of inputs is increased. In contrast, Marlin delivers essentially ideal speedups at all batchsizes, enabling the maximum possible 3.87x speedup up to batchsizes around 16-32.
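
For reference, the 3.87x figure follows directly from the storage footprint at groupsize 128, assuming one FP16 scale shared by each group of 128 weights:

# Effective storage per weight at groupsize 128: 4 bits plus a shared FP16 scale.
bits_per_weight = 4 + 16 / 128   # = 4.125 bits
print(16 / bits_per_weight)      # ≈ 3.879x ideal speedup over 16-bit weights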

Due to its striped partitioning scheme, Marlin also delivers strong performance on real (smaller) matrices and various GPUs. This is demonstrated by the results below, where we benchmark, at batchsize 16, the overall runtime across all linear layers in the Transformer blocks of popular open-source models.

Finally, we also study what performance can be sustained over longer periods of time, at locked base GPU clock. Interestingly, we find that reduced clock speeds significantly harm the relative speedups of prior kernels, but have no effect on Marlin's virtually optimal performance (relative to the lower clock setting).

Requirements:

  • CUDA >= 11.8 (in particular for the nvcc compiler, whose version should match the one used by torch)
  • NVIDIA GPU with compute capability >= 8.0 (Ampere or Ada, Marlin is not yet optimized for Hopper)
  • torch>=2.0.0
  • numpy

Usage:

If all requirements are met, it should be possible to install Marlin by calling

pip install .

in the root folder of this repository.

Afterwards, the easiest way to use the Marlin kernel is via a marlin.Layer, a torch module representing a Marlin-quantized layer. It allows converting a "fake-quantized" (dequantized values stored in FP16) torch.nn.Linear layer into the compressed Marlin format via marlin.Layer.pack(linear, scales). Alternatively, the kernel can also be called directly through marlin.mul(..), provided that weights and scales have already been appropriately preprocessed (see marlin.Layer.pack(...)). The kernel itself can be found in the self-contained marlin/marlin_cuda_kernel.cu file, which has no dependencies beyond base CUDA and should thus be easy to integrate into other lower-level frameworks.
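
For example, a minimal usage sketch might look as follows; the shapes, groupsize and scale layout are placeholder assumptions for illustration, the exact marlin.Layer constructor arguments shown here are an assumption, and in practice the fake-quantized weights and scales would come from a quantization algorithm such as GPTQ:

import torch
import marlin

# Hypothetical layer shape and groupsize, for illustration only.
infeatures, outfeatures, groupsize = 4096, 4096, 128

# Stand-ins for a "fake-quantized" torch.nn.Linear layer (weights already
# rounded to the 4-bit grid but stored in FP16) and its per-group FP16 scales.
linear = torch.nn.Linear(infeatures, outfeatures, bias=False).half()
scales = torch.ones(infeatures // groupsize, outfeatures, dtype=torch.half)

# Convert to the compressed Marlin format and use it like any torch module.
layer = marlin.Layer(infeatures, outfeatures, groupsize=groupsize)
layer.pack(linear, scales)
layer = layer.cuda()

x = torch.randn(16, infeatures, dtype=torch.half, device='cuda')  # batchsize 16
y = layer(x)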

Correctness tests can be executed via python test.py and benchmarks via python bench.py. Please note that in order to reproduce our "sustainable performance" benchmarks, the GPU clocks need to be locked to their respective base values using:

sudo nvidia-smi --lock-gpu-clocks=BASE_GPU_CLOCK --lock-memory-clocks=BASE_MEM_CLOCK

Additionally, if ECC is enabled (e.g., on an A10), then the maximum achievable memory bandwidth will be 10-15% lower than the official spec sheet value, as every memory request carries checksum overhead. This can be disabled via

sudo nvidia-smi -e 0

which we do in our A10 benchmarks.

GPTQ Example:

In the gptq subfolder, we also provide a slightly improved version of the GPTQ algorithm, with better clipping of the group quantization grid and non-uniform calibration sample lengths, that can produce Marlin-compatible 4-bit versions of Llama2 models. Additionally, there is a script to evaluate such compressed models (using Marlin kernels) in the popular LM evaluation harness. Here are corresponding sample commands (the marlin, transformers and datasets packages must be installed):

% Compress Llama2 model and export model in Marlin format.
python llama.py LLAMA2_CHECKPOINT --wbits 4 --save checkpoint.pt
% Perform perplexity evaluation of uncompressed model.
python llama.py LLAMA2_CHECKPOINT
% Evaluate compressed model (with Marlin kernels) in the eval harness.
python eval.py --model hf --model_args pretrained=LLAMA2_CHECKPOINT --tasks mmlu \
  --marlin_checkpoint checkpoint.marlin.g128
% Evaluate full precision baseline.
python eval.py --model hf --model_args pretrained=LLAMA2_CHECKPOINT --tasks mmlu 

We measure the following WikiText and Red-Pajama perplexities, as well as MMLU zero-shot accuracy, for 4-bit (group=128) Marlin models:

Llama2    Wiki2 (FP16)    Wiki2 (INT4)    RedPaj (FP16)    RedPaj (INT4)    MMLU (FP16)    MMLU (INT4)
7B        5.12            5.27            6.14             6.30             41.80          40.07
13B       4.57            4.67            5.67             5.79             52.10          51.13
70B       3.12            3.21            4.74             4.81             65.43          64.81

We note that this GPTQ example is currently intended mostly as a demonstration of how to produce accurate Marlin models and as an end-to-end validation of kernel correctness (rather than to be a flexible compression tool).

Cite:

If you found this work useful, please consider citing:

@misc{frantar2024marlin,
  author = {Frantar, Elias and Alistarh, Dan},
  title = {Marlin: a fast 4-bit inference kernel for medium batchsizes},
  year = {2024},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/IST-DASLab/marlin}}
}
