DrummerNet
This is supplementary material for "Deep Unsupervised Drum Transcription" by Keunwoo Choi and Kyunghyun Cho, ISMIR 2019 (Delft, Netherlands).
Paper on arXiv | Blog post | Poster
- What we provide: PyTorch implementation of the paper
- What we do not provide:
  - pre-trained model
  - drum stems that we used for training
Installation
If you're using conda and want to run DrummerNet on CPU, make sure mkl is installed, because we'll need its FFT module.
conda install -c anaconda mkl
Then,
pip install -r requirements.txt
Using conda, it would be something like this, but customize it yourself!
conda install -c pytorch pytorch torchvision
Python 3 is required.
Preparation
Wav files for Drum Synthesizer
- data_drum_sources: folder for isolated drum sources. 12 kits x 11 drum components are included. If you want to add more drum sources:
  - Add files and update globals.py accordingly.
    # These names are matched with file names in data_drum_sources
    DRUM_NAMES = ["KD_KD", "SD_SD", "HH_CHH", "HH_OHH", "HH_PHH", "TT_HIT", "TT_MHT", "TT_HFT", "CY_RDC", "CY_CRC", "OT_TMB"]
    N_DRUM_VSTS = 12
  - Note that, as shown in inst_src_sec.get_instset_drum(), the last drum kit will be used at test time only.
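The train/test split described in the note above can be sketched as follows. This is only an illustration: it assumes kits are indexed 0 to N_DRUM_VSTS - 1 and that the last index is the held-out kit. The helper name split_kits is hypothetical and not part of the repository; the actual logic lives in inst_src_sec.get_instset_drum().

```python
# Values copied from globals.py as shown above.
DRUM_NAMES = ["KD_KD", "SD_SD", "HH_CHH", "HH_OHH", "HH_PHH",
              "TT_HIT", "TT_MHT", "TT_HFT", "CY_RDC", "CY_CRC", "OT_TMB"]
N_DRUM_VSTS = 12


def split_kits(n_kits):
    """Hypothetical helper: hold out the last kit index for testing."""
    if n_kits < 2:
        raise ValueError("need at least two kits to hold one out")
    kit_ids = list(range(n_kits))
    return kit_ids[:-1], kit_ids[-1:]


train_kits, test_kits = split_kits(N_DRUM_VSTS)
print(f"{len(DRUM_NAMES)} components, "
      f"{len(train_kits)} training kits, test kit index {test_kits}")
```

With 12 kits this leaves 11 kits per drum component for training, which agrees with the n_vsts values printed in the training log.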
Training files
Unfortunately, we cannot provide the drum stems that we used to train the network in the paper.
- /data_drumstems: nearly blank folder, a placeholder for training data. I put one wav file and files.txt there as a minimum working example.
- Mark Cartwright's and Richard Vogl's papers/code provide a way to synthesize large-scale drum stems.
Evaluation files, e.g., SMT
- It is not part of the code; you have to download and process it yourself.
- First, download the SMT dataset (320.7MB).
- Unzip it. Let's call the unzipped folder PATH_UNZIP.
- Then run $ python3 drummernet/eval_import_smt.py PATH_UNZIP. E.g.,
  $ cd drummernet
  $ python3 eval_import_smt.py ~/Downloads/SMT_DRUMS/
  Processing annotations...
  Processing audio file - copying it...
  all done! check out if everything's fine at data_evals/SMT_DRUMS
- data_evals: blank, placeholder for evaluation datasets
Training
- If you prepared evaluation files:
python3 main.py --eval true -ld spectrum --exp_name temp_exp --metrics mae
- Otherwise:
python3 main.py --eval false -ld spectrum --exp_name temp_exp --metrics mae
If everything's fine, you'll see..
$ cd drummernet
$ python3 main.py --eval False -ld spectrum --exp_name temp_exp --metrics mae
Add arguments..
Namespace(activation='elu', batch_size=32, compare_after_hpss=False, conv_bias=False, eval=False, exp_name='temp_exp', kernel_size=3, l1_reg_lambda=0.003, learning_rate=0.0004, loss_domains=['spectrum'], metrics=['mae'], n_cqt_bins=12, n_layer_dec=6, n_layer_enc=10, n_mels=None, num_channel=50, recurrenter='three', resume=False, resume_num='', scale_r=2, source_norm='sqrsum', sparsemax_lst=64, sparsemax_type='multiply')
| With a sampling rate of 16000 Hz,
| the deepest encoded signal: 1 sample == 64 ms.
| At predicting impulses, which is done at u_conv3, 1 sample == 1 ms.
| and sparsemax_lst=64 samples at the same, at=`r` level
n_notes: 11, n_vsts:{'KD_KD': 11, 'SD_SD': 11, 'HH_CHH': 11, 'HH_OHH': 11, 'HH_PHH': 11, 'TT_HIT': 11, 'TT_MHT': 11, 'TT_HFT': 11, 'CY_RDC': 11, 'CY_CRC': 11, 'OT_TMB': 11}
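The time resolutions reported in the log above can be reproduced with a little arithmetic. The sketch below assumes that each encoder stage halves the time resolution (the printed pooling layers use stride 2) and that each of the n_layer_dec decoder stages doubles it again; the second assumption is not confirmed by the log itself.

```python
SAMPLING_RATE = 16000  # Hz, from the log above
N_LAYER_ENC = 10       # n_layer_enc in the printed Namespace
N_LAYER_DEC = 6        # n_layer_dec in the printed Namespace


def ms_per_step(n_halvings, sr=SAMPLING_RATE):
    """Duration in ms of one time step after halving the rate n_halvings times."""
    return (2 ** n_halvings) * 1000.0 / sr


# Deepest encoder output: 2**10 = 1024 input samples per step.
encoder_ms = ms_per_step(N_LAYER_ENC)
# Impulse prediction (assumed 6 doublings above the bottleneck): 2**4 = 16 samples.
impulse_ms = ms_per_step(N_LAYER_ENC - N_LAYER_DEC)

print(f"encoder: {encoder_ms} ms/step, impulses: {impulse_ms} ms/step")
```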
then you'll see the model details.
DrummerHalfUNet(
(unet): ValidAutoUnet(
(d_conv0): Conv1d(1, 50, kernel_size=(3,), stride=(1,), bias=False)
(d_convs): ModuleList(
(0): Conv1d(50, 50, kernel_size=(3,), stride=(1,), bias=False)
(1): Conv1d(50, 50, kernel_size=(3,), stride=(1,), bias=False)
(2): Conv1d(50, 50, kernel_size=(3,), stride=(1,), bias=False)
(3): Conv1d(50, 50, kernel_size=(3,), stride=(1,), bias=False)
(4): Conv1d(50, 50, kernel_size=(3,), stride=(1,), bias=False)
(5): Conv1d(50, 50, kernel_size=(3,), stride=(1,), bias=False)
(6): Conv1d(50, 50, kernel_size=(3,), stride=(1,), bias=False)
(7): Conv1d(50, 50, kernel_size=(3,), stride=(1,), bias=False)
(8): Conv1d(50, 50, kernel_size=(3,), stride=(1,), bias=False)
(9): Conv1d(50, 50, kernel_size=(3,), stride=(1,), bias=False)
)
(pools): ModuleList(
(0): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(1): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(2): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(3): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(4): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(5): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(6): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(7): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(8): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(9): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(encode_conv): Conv1d(50, 50, kernel_size=(3,), stride=(1,), bias=False)
(u_convs): ModuleList(
(0): Conv1d(50, 50, kernel_size=(3,), stride=(1,), bias=False)
(1): Conv1d(100, 50, kernel_size=(3,), stride=(1,), bias=False)
(2): Conv1d(100, 50, kernel_size=(3,), stride=(1,), bias=False)
(3): Conv1d(100, 50, kernel_size=(3,), stride=(1,), bias=False)
(4): Conv1d(100, 50, kernel_size=(3,), stride=(1,), bias=False)
(5): Conv1d(100, 50, kernel_size=(3,), stride=(1,), bias=False)
)
(last_conv): Conv1d(100, 100, kernel_size=(3,), stride=(1,))
)
(recurrenter): Recurrenter(
(midi_x2h): GRU(100, 11, batch_first=True, bidirectional=True)
(midi_h2hh): GRU(22, 11, batch_first=True)
(midi_hh2y): GRU(1, 1, bias=False, batch_first=True)
)
(double_sparsemax): MultiplySparsemax(
(sparsemax_inst): Sparsemax()
(sparsemax_time): Sparsemax()
)
(zero_inserter): ZeroInserter()
(synthesizer): FastDrumSynthesizer()
(mixer): Mixer()
)
NUM_PARAM overall: 203869
unet: 195250
recurrenter: 8619
sparsemaxs: 0
synthesizer: 0
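As a sanity check, the parameter counts above can be recomputed by hand from the printed layer shapes. The sketch below uses the standard parameter formulas for 1D convolution and GRU layers (a PyTorch GRU has three gates and, when bias is enabled, two bias vectors per gate); the helper names are made up for this example.

```python
def conv1d_params(c_in, c_out, kernel, bias):
    """Weights (c_in * c_out * kernel) plus an optional bias per output channel."""
    return c_in * c_out * kernel + (c_out if bias else 0)


def gru_params(n_in, n_hidden, bias=True, bidirectional=False):
    """Three gates, each with input and hidden weights, plus two bias sets."""
    per_direction = 3 * n_hidden * (n_in + n_hidden)
    if bias:
        per_direction += 2 * 3 * n_hidden
    return per_direction * (2 if bidirectional else 1)


unet = (
    conv1d_params(1, 50, 3, bias=False)           # d_conv0
    + 10 * conv1d_params(50, 50, 3, bias=False)   # d_convs[0..9]
    + conv1d_params(50, 50, 3, bias=False)        # encode_conv
    + conv1d_params(50, 50, 3, bias=False)        # u_convs[0]
    + 5 * conv1d_params(100, 50, 3, bias=False)   # u_convs[1..5]
    + conv1d_params(100, 100, 3, bias=True)       # last_conv
)
recurrenter = (
    gru_params(100, 11, bidirectional=True)       # midi_x2h
    + gru_params(22, 11)                          # midi_h2hh
    + gru_params(1, 1, bias=False)                # midi_hh2y
)

print(f"unet: {unet}, recurrenter: {recurrenter}, overall: {unet + recurrenter}")
```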
..as well as training details..
PseudoCQT init with fmin:32, 12, bins, 12 bins/oct, win_len: 16384, n_fft:16384, hop_length:64
PseudoCQT init with fmin:65, 12, bins, 12 bins/oct, win_len: 8192, n_fft:8192, hop_length:64
PseudoCQT init with fmin:130, 12, bins, 12 bins/oct, win_len: 4096, n_fft:4096, hop_length:64
PseudoCQT init with fmin:261, 12, bins, 12 bins/oct, win_len: 2048, n_fft:2048, hop_length:64
PseudoCQT init with fmin:523, 12, bins, 12 bins/oct, win_len: 1024, n_fft:1024, hop_length:64
PseudoCQT init with fmin:1046, 12, bins, 12 bins/oct, win_len: 512, n_fft:512, hop_length:64
PseudoCQT init with fmin:2093, 12, bins, 12 bins/oct, win_len: 256, n_fft:256, hop_length:64
PseudoCQT init with fmin:4000, 12, bins, 12 bins/oct, win_len: 128, n_fft:128, hop_length:64
item check-points after this..: [128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072, 262144, 524288, 1048576, 2097152, 4194304]
total 8388480 n_items to train!
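The checkpoint list above is a doubling sequence, and the reported total equals the sum of its entries. The sketch below only reproduces those two numbers; how main.py actually schedules training between checkpoints is not shown here and is not assumed.

```python
FIRST_CHECKPOINT = 128
LAST_CHECKPOINT = 4194304

# Build the doubling sequence 128, 256, ..., 4194304.
checkpoints = []
n_items = FIRST_CHECKPOINT
while n_items <= LAST_CHECKPOINT:
    checkpoints.append(n_items)
    n_items *= 2

total_items = sum(checkpoints)
print(f"{len(checkpoints)} checkpoints, {total_items} items in total")
```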
..then the training will start..
c1mae:5.53 c2mae:4.39 c3mae:2.95 c4mae:3.19 c5mae:2.22 c6mae:1.90 c7mae:2.14 c8mae:2.26: 100%|███████████████████████████████████| 1/1 [00:25<00:00, 25.03s/it]
Troubleshooting
Install MKL for pytorch FFT
In case you face this error,
RuntimeError: fft: ATen not compiled with MKL support
As stated here, this is an issue with the MKL library installation. A quick solution is to use conda. Otherwise, you should install Intel MKL manually.
In some cases, if PyTorch was once built without MKL, it might not be able to find an MKL that was installed later. Try removing the pip/conda cache, or just make a new environment.
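To check whether the current environment is affected before starting a training run, something like the following can help. It is a minimal sketch that assumes your PyTorch version exposes torch.backends.mkl.is_available(); the function name mkl_status is made up for this example.

```python
def mkl_status():
    """Report whether the installed PyTorch build was compiled with MKL."""
    try:
        import torch
    except ImportError:
        return "torch not installed"
    try:
        available = torch.backends.mkl.is_available()
    except AttributeError:
        # Assumption: very old builds may not expose this backend module.
        return "unknown"
    return "available" if available else "missing"


print("MKL:", mkl_status())
```

If this prints "missing", reinstalling PyTorch in a fresh environment as described above is likely the quickest fix.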
Requirement detail
These are the exact versions I used for the dependencies.
Python==3.7.3
Cython==0.29.6
numpy==1.16.2
librosa==0.6.2
torch==1.0.0
torchvision==0.2.1
madmom==0.16.1
matplotlib==2.2.0
tqdm==4.31.1
mir_eval==0.5
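To compare an environment against these pins, a small script along these lines may be useful. It is a sketch, not part of the repository: it relies on importlib.metadata, which is only in the standard library from Python 3.8 on, and simply reports packages it cannot find as None.

```python
PINNED = """\
Cython==0.29.6
numpy==1.16.2
librosa==0.6.2
torch==1.0.0
torchvision==0.2.1
madmom==0.16.1
matplotlib==2.2.0
tqdm==4.31.1
mir_eval==0.5
"""


def parse_pins(text):
    """Turn 'name==version' lines into a {name: version} dict."""
    pins = {}
    for line in text.splitlines():
        line = line.strip()
        if not line or line.startswith("#"):
            continue
        name, _, version = line.partition("==")
        pins[name.strip()] = version.strip()
    return pins


def installed_version(name):
    """Return the installed version of a package, or None if it cannot be found."""
    try:
        from importlib.metadata import PackageNotFoundError, version
    except ImportError:  # Python < 3.8 has no importlib.metadata
        return None
    try:
        return version(name)
    except PackageNotFoundError:
        return None


pins = parse_pins(PINNED)
for name, wanted in pins.items():
    found = installed_version(name)
    status = "ok" if found == wanted else "differs"
    print(f"{name}: pinned {wanted}, installed {found} ({status})")
```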
Citation
@inproceedings{choi2019deep,
title={Deep Unsupervised Drum Transcription},
author={Choi, Keunwoo and Cho, Kyunghyun},
booktitle={Proceedings of the International Society for Music Information Retrieval Conference (ISMIR), Delft, Netherlands},
year={2019}
}