Soundstorm - Pytorch
Implementation of SoundStorm, Efficient Parallel Audio Generation from Google Deepmind, in Pytorch.
They basically applied MaskGiT to the residual vector quantized codes from Soundstream. The transformer architecture they chose is the Conformer, which fits well with the audio domain.
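As a rough illustration of the MaskGiT-style decoding mentioned above, the sketch below computes how many positions might remain masked after each iteration under a cosine schedule. The function name `cosine_mask_counts` and the floor rounding are assumptions made for this example; this is not the repository's implementation.

```python
import math
from typing import List

def cosine_mask_counts(seq_len: int, steps: int) -> List[int]:
    """Hypothetical helper: tokens still masked after each decoding step."""
    counts = []
    for step in range(1, steps + 1):
        progress = step / steps                         # fraction of steps done, in (0, 1]
        mask_ratio = math.cos(progress * math.pi / 2)   # falls from ~1 to ~0
        # floor is an assumption here; it guarantees nothing is left masked at the end
        counts.append(math.floor(mask_ratio * seq_len))
    return counts

# 18 steps over a 1024-long sequence, matching the numbers in the usage example
counts = cosine_mask_counts(seq_len=1024, steps=18)
print(counts)
```

Under this schedule only a few positions are committed in the early steps and most are filled in near the end, which is the usual motivation given for a cosine schedule over a linear one.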
Appreciation
-
Stability and 🤗 Huggingface for their generous sponsorships to work on and open source cutting edge artificial intelligence research
-
Lucas Newman for numerous contributions, including the initial training code, acoustic prompting logic, and per-level quantizer decoding!
-
🤗 Accelerate for providing a simple and powerful solution for training
-
Einops for the indispensable abstraction that makes building neural networks fun, easy, and uplifting
-
Steven Hillis for submitting the correct masking strategy and for verifying that the repository works! 🙏
-
Lucas Newman for basically training a small working Soundstorm with models across multiple repositories, showing it all works end-to-end. Models include SoundStream, Text-to-Semantic T5, and finally the SoundStorm transformer here.
Install
$ pip install soundstorm-pytorch
Usage
import torch
from soundstorm_pytorch import SoundStorm, ConformerWrapper
conformer = ConformerWrapper(
codebook_size = 1024,
num_quantizers = 12,
conformer = dict(
dim = 512,
depth = 2
),
)
model = SoundStorm(
conformer,
steps = 18, # 18 steps, as in original maskgit paper
schedule = 'cosine' # currently the best schedule is cosine
)
# get your pre-encoded codebook ids from the soundstream from a lot of raw audio
codes = torch.randint(0, 1024, (2, 1024, 12)) # (batch, seq, num residual VQ)
# do the below in a loop for a ton of data
loss, _ = model(codes)
loss.backward()
# model can now generate in 18 steps. ~2 seconds sounds reasonable
generated = model.generate(1024, batch_size = 2) # (2, 1024)
To directly train on raw audio, you need to pass in your pretrained SoundStream into SoundStorm. You can train your own SoundStream at audiolm-pytorch.
import torch
from soundstorm_pytorch import SoundStorm, ConformerWrapper, Conformer, SoundStream
conformer = ConformerWrapper(
codebook_size = 1024,
num_quantizers = 12,
conformer = dict(
dim = 512,
depth = 2
),
)
soundstream = SoundStream(
codebook_size = 1024,
rq_num_quantizers = 12,
attn_window_size = 128,
attn_depth = 2
)
model = SoundStorm(
conformer,
soundstream = soundstream # pass in the soundstream
)
# find as much audio as you'd like the model to learn
audio = torch.randn(2, 10080)
# course it through the model and take a gazillion tiny steps
loss, _ = model(audio)
loss.backward()
# and now you can generate state-of-the-art speech
# generate 30 seconds of audio
# (the sequence length is calculated from the seconds, based on the sampling frequency and cumulative downsamples of the soundstream passed in above)
generated_audio = model.generate(seconds = 30, batch_size = 2)
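To make the seconds-to-length conversion concrete, here is a small worked example of the arithmetic. The 16 kHz sampling rate and the 320x cumulative downsampling factor are hypothetical values chosen for illustration. In practice both come from the SoundStream you pass in, and the library may round differently than this sketch does.

```python
def num_code_frames(seconds: float, sample_rate_hz: int, cumulative_downsample: int) -> int:
    """Hypothetical helper: how many codec frames cover `seconds` of audio."""
    num_samples = int(seconds * sample_rate_hz)      # raw waveform samples
    return num_samples // cumulative_downsample      # one frame per downsampled step

# assumed values, not read from any real SoundStream
frames = num_code_frames(seconds=30, sample_rate_hz=16_000, cumulative_downsample=320)
print(frames)  # 30 * 16000 / 320 = 1500
```

With these assumed values, 30 seconds of audio corresponds to a sequence of 1500 positions for the transformer to fill in.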
Complete text-to-speech will rely on a trained TextToSemantic encoder / decoder transformer. You will then load the weights and pass it into the SoundStorm as spear_tts_text_to_semantic.
This is a work-in-progress, as spear-tts-pytorch only has the model architecture complete, and not the pretraining + pseudo-labeling + backtranslation logic.
from spear_tts_pytorch import TextToSemantic
text_to_semantic = TextToSemantic(
dim = 512,
source_depth = 12,
target_depth = 12,
num_text_token_ids = 50000,
num_semantic_token_ids = 20000,
use_openai_tokenizer = True
)
# load the trained text-to-semantic transformer
text_to_semantic.load('/path/to/trained/model.pt')
# pass it into the soundstorm, along with the conformer and soundstream defined in the previous example
model = SoundStorm(
conformer,
soundstream = soundstream,
spear_tts_text_to_semantic = text_to_semantic
).cuda()
# and now you can generate state-of-the-art speech
generated_speech = model.generate(
texts = [
'the rain in spain stays mainly in the plain',
'the quick brown fox jumps over the lazy dog'
]
) # (2, n) - raw waveform decoded from soundstream
Todo
-
integrate soundstream
-
when generating, length can be defined in seconds (takes into account sampling freq etc)
-
make sure grouped rvq is supported. concat embeddings rather than sum across group dimension
-
just copy conformer over and redo shaw's relative positional embedding with rotary embedding. nobody uses shaw anymore.
-
default flash attention to true
-
remove batchnorm, and just use layernorm, but after the swish (as in normformer paper)
-
trainer with accelerate - thanks to @lucasnewman
-
allow for variable length sequence training and generation, by passing in mask at forward and generate
-
option to return list of audio files when generating
-
turn it into a command line tool
-
add cross attention and adaptive layernorm conditioning
Citations
@misc{borsos2023soundstorm,
title = {SoundStorm: Efficient Parallel Audio Generation},
author = {Zalán Borsos and Matt Sharifi and Damien Vincent and Eugene Kharitonov and Neil Zeghidour and Marco Tagliasacchi},
year = {2023},
eprint = {2305.09636},
archivePrefix = {arXiv},
primaryClass = {cs.SD}
}
@inproceedings{dao2022flashattention,
title = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
author = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
booktitle = {Advances in Neural Information Processing Systems},
year = {2022}
}
@article{Chang2022MaskGITMG,
title = {MaskGIT: Masked Generative Image Transformer},
author = {Huiwen Chang and Han Zhang and Lu Jiang and Ce Liu and William T. Freeman},
journal = {2022 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
year = {2022},
pages = {11305-11315}
}
@article{Lezama2022ImprovedMI,
title = {Improved Masked Image Generation with Token-Critic},
author = {Jos{\'e} Lezama and Huiwen Chang and Lu Jiang and Irfan Essa},
journal = {ArXiv},
year = {2022},
volume = {abs/2209.04439}
}
@inproceedings{Nijkamp2021SCRIPTSP,
title = {SCRIPT: Self-Critic PreTraining of Transformers},
author = {Erik Nijkamp and Bo Pang and Ying Nian Wu and Caiming Xiong},
booktitle = {North American Chapter of the Association for Computational Linguistics},
year = {2021}
}
@inproceedings{rogozhnikov2022einops,
title = {Einops: Clear and Reliable Tensor Manipulations with Einstein-like Notation},
author = {Alex Rogozhnikov},
booktitle = {International Conference on Learning Representations},
year = {2022},
url = {https://openreview.net/forum?id=oapKSVM2bcj}
}
@misc{su2021roformer,
title = {RoFormer: Enhanced Transformer with Rotary Position Embedding},
author = {Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu},
year = {2021},
eprint = {2104.09864},
archivePrefix = {arXiv},
primaryClass = {cs.CL}
}