Dynamical Variational Autoencoders: A Comprehensive Review
This repository contains the code for:
Dynamical Variational Autoencoders: A Comprehensive Review, Foundations and Trends in Machine Learning, 2021.
Laurent Girin, Simon Leglaive, Xiaoyu Bie, Julien Diard, Thomas Hueber, Xavier Alameda-Pineda
[arXiv] [Paper] [Project] [Tutorial]
More precisely, this repository is a re-implementation of the following models in PyTorch:
- VAE, Kingma et al., ICLR 2014
- DKF, Krishnan et al., AAAI 2017
- KVAE, Fraccaro et al., NeurIPS 2017
- STORN, Bayer et al., arXiv 2014
- VRNN, Chung et al., NeurIPS 2015
- SRNN, Fraccaro et al., NeurIPS 2016
- RVAE, Leglaive et al., ICASSP 2020
- DSAE, Li et al., ICML 2018
For the results reported in our Interspeech 2021 paper, please visit the interspeech branch.
We do not report results for KVAE because we did not manage to make it work in our experiments; we still provide the code for research purposes.
Pretrained models
All pre-trained DVAE models can be downloaded here.
Prerequisites
The PESQ values reported in our paper are narrow-band PESQ values computed with the pypesq package. If you want wide-band PESQ values, please use the pesq package instead.
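The sketch below shows both options. It assumes that `pypesq` and `pesq` are installed and that `ref` and `deg` are 1-D NumPy arrays sampled at `fs` Hz. The helper name `compute_pesq` is hypothetical. The two packages take their arguments in a different order, and wide-band PESQ is only defined at 16 kHz.

```python
import numpy as np


def compute_pesq(ref: np.ndarray, deg: np.ndarray, fs: int, wideband: bool = False) -> float:
    """Hypothetical helper: narrow-band PESQ via pypesq, wide-band PESQ via pesq."""
    if fs not in (8000, 16000):
        raise ValueError("PESQ is only defined for fs = 8000 or 16000 Hz")
    if wideband:
        if fs != 16000:
            raise ValueError("wide-band PESQ requires fs = 16000 Hz")
        from pesq import pesq  # `pesq` package: pesq(fs, ref, deg, mode)
        return float(pesq(fs, ref, deg, "wb"))
    from pypesq import pesq as nb_pesq  # `pypesq` package: pesq(ref, deg, fs)
    return float(nb_pesq(ref, deg, fs))
```

The imports are done lazily, so only the package you actually use needs to be installed.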
Dataset
In this version, the DVAE models support two different datasets:
- WSJ0, a speech audio dataset; we use the CHiME2-WSJ0 subset from the CHiME Challenge
- Human3.6M, a 3D human motion dataset (available under license here); the exponential map version can be downloaded here
If you want to use our models on other datasets, you can modify or rewrite the dataloader and make minor changes to the training steps. Note that DVAE models expect input data of shape
(seq_len, batch_size, x_dim)
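Your own data loader may yield batches in the batch-first layout `(batch_size, seq_len, x_dim)`. Swapping the first two axes gives the layout above. Below is a minimal NumPy sketch, where the function name `to_time_major` is hypothetical. For a PyTorch tensor the equivalent is `x.permute(1, 0, 2)` or `x.transpose(0, 1)`.

```python
import numpy as np


def to_time_major(batch: np.ndarray) -> np.ndarray:
    """Convert (batch_size, seq_len, x_dim) to (seq_len, batch_size, x_dim)."""
    if batch.ndim != 3:
        raise ValueError(f"expected a 3-D array, got shape {batch.shape}")
    # Swap the batch and time axes; the feature axis stays last.
    return np.ascontiguousarray(batch.transpose(1, 0, 2))


# Example with hypothetical sizes: 4 sequences of 50 frames, 513 features per frame.
batch = np.random.rand(4, 50, 513).astype(np.float32)
x = to_time_major(batch)
print(x.shape)  # (50, 4, 513)
```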
Train
We provide configuration examples for all the above models in ./config
# Train a DVAE model (examples)
python train_model.py --cfg ./config/speech/cfg_rvae_Causal.ini
python train_model.py --cfg ./config/motion/cfg_srnn.ini
# Train DVAE with scheduled sampling, with a pretrained model
python train_model.py --ss --cfg ./config/speech/cfg_srnn_ss.ini --use_pretrain --pretrain_dict /PATH_PRETRAIN_DIR
# Train DVAE with scheduled sampling, without a pretrained model
python train_model.py --ss --cfg ./config/speech/cfg_srnn_ss.ini
# Resume training
python train_model.py --cfg ./config/speech/cfg_rvae_Causal.ini --reload --model_dir /PATH_RELOAD_DIR
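The `--ss` option trains with scheduled sampling. The sketch below is a generic illustration of that technique, not a transcription of this repository's implementation. It unrolls an autoregressive model: at each step, the model's own previous output replaces the previous ground-truth frame with probability `p_model`. The `step` callable and the linear schedule are hypothetical.

```python
import random
from typing import Callable, List, Sequence


def unroll_with_scheduled_sampling(
    x: Sequence[float],
    step: Callable[[float], float],
    p_model: float,
    rng: random.Random,
) -> List[float]:
    """Predict x[1:], feeding back either the true previous frame or the model's own.

    With probability p_model the model's previous prediction is used as input;
    otherwise the ground-truth previous frame is used (teacher forcing).
    """
    preds: List[float] = []
    for t in range(1, len(x)):
        if t > 1 and rng.random() < p_model:
            prev = preds[-1]  # model's own prediction of x[t - 1]
        else:
            prev = x[t - 1]  # ground-truth frame (teacher forcing)
        preds.append(step(prev))
    return preds  # preds[t - 1] is the prediction of x[t]


def linear_schedule(epoch: int, n_epochs: int) -> float:
    """Hypothetical schedule: p_model rises linearly from 0 to 1 over training."""
    return min(1.0, epoch / max(1, n_epochs - 1))
```

With `p_model = 0` this reduces to teacher forcing. With `p_model = 1` the model runs entirely on its own outputs, which is how it operates at generation time.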
Evaluation
# Evaluation on speech data
python eval_wsj.py --cfg PATH_TO_CONFIG --saved_dict PATH_TO_PRETRAINED_DICT
python eval_wsj.py --ss --cfg PATH_TO_CONFIG --saved_dict PATH_TO_PRETRAINED_DICT # scheduled sampling
# Evaluation on human motion data
python eval_h36m.py --cfg PATH_TO_CONFIG --saved_dict PATH_TO_PRETRAINED_DICT
python eval_h36m.py --ss --cfg PATH_TO_CONFIG --saved_dict PATH_TO_PRETRAINED_DICT # scheduled sampling
Bibtex
If you find this code useful, please star the project and consider citing:
@article{dvae2021,
title={Dynamical Variational Autoencoders: A Comprehensive Review},
author={Girin, Laurent and Leglaive, Simon and Bie, Xiaoyu and Diard, Julien and Hueber, Thomas and Alameda-Pineda, Xavier},
journal={Foundations and Trends® in Machine Learning},
year = {2021},
volume = {15},
doi = {10.1561/2200000089},
issn = {1935-8237},
number = {1-2},
pages = {1-175}
}
@inproceedings{bie21_interspeech,
author={Xiaoyu Bie and Laurent Girin and Simon Leglaive and Thomas Hueber and Xavier Alameda-Pineda},
title={{A Benchmark of Dynamical Variational Autoencoders Applied to Speech Spectrogram Modeling}},
year=2021,
booktitle={Proc. Interspeech 2021},
pages={46--50},
doi={10.21437/Interspeech.2021-256}
}
Main results
For speech data, using:
- training dataset: wsj0_si_tr_s
- validation dataset: wsj0_si_dt_05
- test dataset: wsj0_si_et_05
DVAE | SI-SDR (dB) | PESQ | ESTOI |
---|---|---|---|
VAE | 5.3 | 2.97 | 0.83 |
DKF | 9.3 | 3.53 | 0.91 |
STORN | 6.9 | 3.42 | 0.90 |
VRNN | 10.0 | 3.61 | 0.92 |
SRNN | 11.0 | 3.68 | 0.93 |
RVAE-Causal | 9.0 | 3.49 | 0.90 |
RVAE-NonCausal | 8.9 | 3.58 | 0.91 |
DSAE | 9.2 | 3.55 | 0.91 |
SRNN-TF-GM | -1.0 | 1.93 | 0.64 |
SRNN-GM | 7.8 | 3.37 | 0.88 |
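SI-SDR (scale-invariant signal-to-distortion ratio) projects the estimate $\hat{s}$ onto the reference $s$. It then compares the energy of that projection with the energy of the residual. With $\alpha = \langle \hat{s}, s\rangle / \|s\|^2$, the measure is $10\log_{10}(\|\alpha s\|^2 / \|\alpha s - \hat{s}\|^2)$. Below is a minimal NumPy sketch for one pair of time-domain signals. The zero-mean step and the small `eps` guard are common conventions assumed here, and this is not necessarily the exact code used to produce the table above.

```python
import numpy as np


def si_sdr(reference: np.ndarray, estimate: np.ndarray, eps: float = 1e-8) -> float:
    """Scale-invariant SDR in dB between two 1-D signals of equal length."""
    reference = np.asarray(reference, dtype=np.float64)
    estimate = np.asarray(estimate, dtype=np.float64)
    # Remove the DC offset so the measure ignores constant shifts.
    reference = reference - reference.mean()
    estimate = estimate - estimate.mean()
    # Optimal scaling of the reference towards the estimate.
    alpha = np.dot(estimate, reference) / (np.dot(reference, reference) + eps)
    target = alpha * reference
    residual = estimate - target
    return float(10 * np.log10((np.sum(target**2) + eps) / (np.sum(residual**2) + eps)))
```

Because of the projection, rescaling the estimate does not change the score, so the metric only penalizes distortion, not gain.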
For human motion data, using:
- training dataset: S1, S6, S7, S8, S9
- validation dataset: S5
- test dataset: S11
DVAE | MPJPE (mm) |
---|---|
VAE | 48.69 |
DKF | 42.21 |
STORN | 9.47 |
VRNN | 9.22 |
SRNN | 7.86 |
RVAE-Causal | 31.09 |
RVAE-NonCausal | 28.59 |
DSAE | 28.61 |
SRNN-TF-GM | 221.87 |
SRNN-GM | 43.98 |
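MPJPE (mean per-joint position error) is the Euclidean distance between predicted and ground-truth 3D joint positions, averaged over joints and frames. The sketch below assumes both sequences are already 3D joint positions in millimetres with shape `(seq_len, n_joints, 3)`. Converting the exponential maps to positions (forward kinematics on the Human3.6M skeleton) is not shown, and the array sizes in the example are hypothetical.

```python
import numpy as np


def mpjpe(pred: np.ndarray, target: np.ndarray) -> float:
    """Mean per-joint position error, in the same unit as the inputs (e.g. mm)."""
    if pred.shape != target.shape or pred.shape[-1] != 3:
        raise ValueError("expected matching shapes ending in 3 (x, y, z)")
    # Euclidean distance per joint and frame, then average over everything.
    return float(np.linalg.norm(pred - target, axis=-1).mean())


# Example: every joint is off by (3, 4, 0) mm, so the error is 5 mm everywhere.
gt = np.zeros((10, 32, 3))
pred = gt + np.array([3.0, 4.0, 0.0])
print(mpjpe(pred, gt))  # 5.0
```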
More results can be found in Chapter 13 (Experiments) of our article.
Contact
For any further questions, feel free to email me at xiaoyu[dot]bie[at]inria[dot]fr