• Stars: 432
• Rank: 100,650 (top 2%)
• Language: Python
• License: Apache License 2.0
• Created: about 3 years ago
• Updated: 11 months ago

Repository Details

A PyTorch implementation of Perceiver, Perceiver IO and Perceiver AR with PyTorch Lightning scripts for distributed training

Perceiver, Perceiver IO and Perceiver AR

This repository is a PyTorch implementation of Perceiver, Perceiver IO and Perceiver AR, with PyTorch Lightning interfaces for model training and Hugging Face 🤗 interfaces for inference.

• Perceiver: General Perception with Iterative Attention (paper, video)
• Perceiver IO: A General Architecture for Structured Inputs & Outputs (paper, blog post)
• General-purpose, long-context autoregressive modeling with Perceiver AR (paper, blog post)

Overview

At the core of the perceiver-io library are backend models: lightweight PyTorch implementations of Perceiver, Perceiver IO and Perceiver AR. They can be wrapped into PyTorch Lightning modules for training (Lightning interface) and 🤗 modules for inference (Hugging Face interface). See library design for details.

(Figure: perceiver-io library design)

The command line interface for training is implemented with Lightning CLI. Training datasets are 🤗 datasets wrapped into PyTorch Lightning data modules. For NLP tasks, perceiver-io supports all 🤗 fast tokenizers and the 🤗 Perceiver UTF-8 bytes tokenizer.
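
As a minimal sketch of what this means in practice, both kinds of tokenizers can be loaded with the standard 🤗 transformers AutoTokenizer API (the model names below are examples; deepmind/language-perceiver hosts the Perceiver UTF-8 bytes tokenizer):

from transformers import AutoTokenizer

# Any 🤗 fast tokenizer can be used for NLP tasks
fast_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# The 🤗 Perceiver tokenizer operates directly on raw UTF-8 bytes
bytes_tokenizer = AutoTokenizer.from_pretrained("deepmind/language-perceiver")

print(fast_tokenizer("Perceiver IO").input_ids)
print(bytes_tokenizer("Perceiver IO").input_ids)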

Documentation

Installation

Via pip

pip install perceiver-io[text,vision,audio]

From sources

Installation from sources requires Miniconda and Poetry (1.2.0 or higher).
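
If Poetry is not yet available on your system, it can be installed with its official installer (a sketch; see the Poetry and Miniconda documentation for installation alternatives):

curl -sSL https://install.python-poetry.org | python3 -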

Create and activate the perceiver-io conda environment:

conda env create -f environment.yml
conda activate perceiver-io

Install main and test dependencies, including all extras:

# Without dependencies required for examples
poetry install --all-extras

If you want to run the examples locally, additionally use --with examples:

poetry install --all-extras --with examples

Docker image

docker pull ghcr.io/krasserm/perceiver-io:latest

See Docker image for details.
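
As a hedged sketch of how the image can be used (this assumes perceiver-io and its scripts are available in the image's default Python environment, and that the NVIDIA Container Toolkit is installed for GPU access):

# print the help of a training script inside the container (assumed environment)
docker run --rm -it --gpus all ghcr.io/krasserm/perceiver-io:latest \
  python -m perceiver.scripts.vision.image_classifier fit --help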

Getting started

Inference

Optical flow

Compute the optical flow between consecutive frames of an input video and write the rendered results to an output video:

from urllib.request import urlretrieve
from transformers import pipeline

from perceiver.data.vision import video_utils
from perceiver.model.vision import optical_flow  # register auto-classes and pipeline

urlretrieve(
    url="https://martin-krasser.com/perceiver/flow/sintel_clip_cave_dragon_fight.mp4",
    filename="sintel_clip_cave_dragon_fight.mp4",
)

# Create optical flow pipeline
optical_flow_pipeline = pipeline("optical-flow", model="krasserm/perceiver-io-optical-flow", device="cuda:0")

# load consecutive video frame pairs
frame_pairs = video_utils.read_video_frame_pairs("sintel_clip_cave_dragon_fight.mp4")

# create and render optical flow for all frame pairs
optical_flows = optical_flow_pipeline(frame_pairs, render=True, device="cuda:0")

# create video with rendered optical flows
video_utils.write_video("sintel_clip_cave_dragon_fight_output.mp4", optical_flows, fps=24)

Here is a side-by-side comparison of the input and output video:

(Video: optical-flow side-by-side comparison)
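
Individual rendered flow frames can also be written to image files. The following sketch assumes each element of optical_flows is an H x W x 3 uint8 NumPy array (an assumption; inspect the pipeline output in your environment):

from PIL import Image

# Assumption: rendered flow frames are returned as uint8 RGB arrays
Image.fromarray(optical_flows[0]).save("optical_flow_frame_0.png")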

Symbolic audio generation

Create audio sequences by generating symbolic (MIDI) audio data and converting the generated audio symbols into WAV output using fluidsynth (Note: fluidsynth must be installed in order for the following example to work):

from transformers import pipeline
from pretty_midi import PrettyMIDI
from perceiver.model.audio import symbolic  # auto-class registration

repo_id = "krasserm/perceiver-ar-sam-giant-midi"

prompt = PrettyMIDI("prompt.mid")
audio_generator = pipeline("symbolic-audio-generation", model=repo_id)

output = audio_generator(prompt, max_new_tokens=64, num_latents=1, do_sample=True, top_p=0.95, temperature=1.0, render=True)

with open("generated_audio.wav", "wb") as f:
    f.write(output["generated_audio_wav"])

Examples of generated audio sequences are available on the 🤗 hub.

See inference examples for more examples.

Training

Train a small Perceiver IO image classifier (907K parameters) on MNIST from the command line. The classifier cross-attends to individual pixels of input images with repeated cross-attention. See image classification training example for more details.

python -m perceiver.scripts.vision.image_classifier fit \
  --model.num_latents=32 \
  --model.num_latent_channels=128 \
  --model.encoder.num_frequency_bands=32 \
  --model.encoder.num_cross_attention_layers=2 \
  --model.encoder.num_self_attention_blocks=3 \
  --model.encoder.num_self_attention_layers_per_block=3 \
  --model.encoder.first_self_attention_block_shared=false \
  --model.encoder.dropout=0.1 \
  --model.encoder.init_scale=0.1 \
  --model.decoder.num_output_query_channels=128 \
  --model.decoder.dropout=0.1 \
  --model.decoder.init_scale=0.1 \
  --data=MNISTDataModule \
  --data.batch_size=64 \
  --optimizer=AdamW \
  --optimizer.lr=1e-3 \
  --lr_scheduler.warmup_steps=500 \
  --trainer.accelerator=gpu \
  --trainer.devices=1 \
  --trainer.max_epochs=30 \
  --trainer.logger=TensorBoardLogger \
  --trainer.logger.save_dir=logs \
  --trainer.logger.name=logs

Model construction describes how to implement model-specific command line interfaces with the Lightning CLI. Training checkpoints are written to the logs/img_clf/version_0/checkpoints directory. Assuming a checkpoint with filename epoch=025-val_loss=0.065.ckpt exists, it can be converted to a perceiver-io 🤗 model with

from perceiver.model.vision.image_classifier import convert_mnist_classifier_checkpoint

convert_mnist_classifier_checkpoint(
    save_dir="example/mnist-classifier",
    ckpt_url="logs/img_clf/version_0/checkpoints/epoch=025-val_loss=0.065.ckpt",
)

so that it can be used in a 🤗 image classification pipeline:

from datasets import load_dataset
from transformers import pipeline

mnist_dataset = load_dataset("mnist", split="test")[:9]

images = mnist_dataset["image"]
labels = mnist_dataset["label"]

classifier = pipeline("image-classification", model="example/mnist-classifier")
predictions = [pred[0]["label"] for pred in classifier(images)]

print(f"Labels:      {labels}")
print(f"Predictions: {predictions}")

Labels:      [7, 2, 1, 0, 4, 1, 4, 9, 5]
Predictions: [7, 2, 1, 0, 4, 1, 4, 9, 5]

or loaded directly:

import torch
from transformers import AutoModelForImageClassification, AutoImageProcessor

model = AutoModelForImageClassification.from_pretrained("example/mnist-classifier")
processor = AutoImageProcessor.from_pretrained("example/mnist-classifier")

inputs = processor(images, return_tensors="pt")

with torch.no_grad():
    # use perceiver-io Hugging Face model
    output_1 = model(**inputs).logits

with torch.no_grad():
    # or use perceiver-io backend model directly  
    output_2 = model.backend_model(inputs.pixel_values)

print(f"Predictions: {output_1.argmax(dim=-1).numpy().tolist()}")
print(f"Predictions: {output_2.argmax(dim=-1).numpy().tolist()}")

Predictions: [7, 2, 1, 0, 4, 1, 4, 9, 5]
Predictions: [7, 2, 1, 0, 4, 1, 4, 9, 5]
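
Because the converted classifier is a regular 🤗 transformers model, it can also be shared on the 🤗 hub with the standard push_to_hub API (a sketch; your-username/mnist-classifier is a placeholder repository name and a prior huggingface-cli login is required):

# push model and image processor to a (placeholder) hub repository
model.push_to_hub("your-username/mnist-classifier")
processor.push_to_hub("your-username/mnist-classifier")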

See training examples for more examples.

Articles

Articles referencing this repository:

Other implementations

More Repositories

1. bayesian-machine-learning: Notebooks about Bayesian methods for machine learning (Jupyter Notebook, 1,808 stars)
2. super-resolution: Tensorflow 2.x based implementation of EDSR, WDSR and SRGAN for single image super-resolution (Python, 1,496 stars)
3. face-recognition: Deep face recognition with Keras, Dlib and OpenCV (Jupyter Notebook, 377 stars)
4. machine-learning-notebooks: Stanford Machine Learning course exercises implemented with scikit-learn (Jupyter Notebook, 341 stars)
5. fairseq-image-captioning: Transformer-based image captioning extension for pytorch/fairseq (Python, 313 stars)
6. streamz: A combinator library for integrating Functional Streams for Scala (FS2), Akka Streams and Apache Camel (Scala, 283 stars)
7. akka-analytics: Large-scale event processing with Akka Persistence and Apache Spark (Scala, 274 stars)
8. akka-persistence-cassandra: A replicated Akka Persistence journal backed by Apache Cassandra (Scala, 224 stars)
9. akka-persistence-kafka: A replicated Akka Persistence journal backed by Apache Kafka (Scala, 201 stars)
10. akka-stream-eventsourcing: Event sourcing for Akka Streams (Scala, 101 stars)
11. grails-jaxrs: JAX-RS Plugin for Grails (Groovy, 50 stars)
12. scalaz-camel: A Scala(z)-based DSL for Apache Camel (Scala, 50 stars)
13. ipf: Open eHealth Integration Platform (Java, 35 stars)
14. akka-persistence-testkit: Compatibility testkit for Akka Persistence storage plugins (Scala, 21 stars)
15. bot-with-plan: Separation of planning concerns in ReAct-style LLM agents; planner fine-tuning on synthetic trajectories (Python, 10 stars)
16. krasserm.github.io (Jupyter Notebook, 9 stars)
17. camelinaction-appendix-e: akka-camel examples from the book Camel in Action, Appendix E, adjusted to the most recent Akka release or development snapshot (Scala, 8 stars)
18. machine-learning-minis: Minimalistic example code for various machine learning and deep learning topics (Jupyter Notebook, 8 stars)
19. ipf-labs: eHealth Integration Framework Labs (Java, 7 stars)
20. ipf-runtime: OSGi-based runtime environment for IPF applications (Shell, 6 stars)
21. ipf-tools: eHealth Integration Framework Tools (Java, 6 stars)
22. sagemaker-tutorial: Multi-node, multi-GPU training with PyTorch Lightning on SageMaker (Python, 5 stars)
23. eventuate-crdt-example: Example application that uses Eventuate's operation-based CRDTs (Scala, 3 stars)
24. safr: Security Annotation Framework (Java, 1 star)