• Stars: 135
• Rank: 269,297 (Top 6%)
• Language: Python
• License: MIT License
• Created: over 4 years ago
• Updated: 4 months ago


Repository Details

A PyTorch implementation of Meta-TasNet from "Meta-learning Extractors for Music Source Separation".

[Figure: overall architecture]

Meta-Learning for Music Source Separation

David Samuel,* Aditya Ganeshan & Jason Naradowsky
*Part of this work was done during an internship at PFN.

Interactive demo | Paper


We propose a hierarchical meta-learning-inspired model for music source separation in which a generator model is used to predict the weights of individual extractor models. This enables efficient parameter-sharing, while still allowing for instrument-specific parameterization. The resulting models are shown to be more effective than those trained independently or in a multi-task setting, and achieve performance comparable with state-of-the-art methods.


Brief Introduction to Music Source Separation

Given a mixed source signal, the task of a source separation algorithm is to divide the signal into its original components. We evaluate our method on music separation, specifically on the MUSDB18 dataset, where the sources are contemporary songs and the goal is to divide them into four stems:

      🥁   drums
      🎙️   vocals
      🎸   bass
      🎷   other accompaniments

Music source separation can not only be used as a preprocessing step for other MIR problems (such as sound source identification), but it can also be used more creatively: we can create backing tracks to any song for musical practice or just for fun (karaoke), we can build "smart" equalizers that can produce a new remix, or we can isolate a single instrument to better study its intricacies (guitar players, for example, can more easily determine the exact chords).


[Figure] Illustration of a separated audio signal (projected onto log-scaled spectrograms). The top spectrogram shows the mixed audio, which is transformed into the four separated components at the bottom. Note that we use the spectrograms only to illustrate the task; our model operates directly on the audio waveforms.


Generating Extractor Models

The key idea is to utilize a tiered architecture where a generator network "supervises" the training of the individual extractors by generating some of their parameters directly. This allows the generator to develop a dense representation of how instruments relate to each other as it pertains to the task, and to utilize their commonalities when generating each extractor.
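
To make this concrete, here is a minimal, purely illustrative sketch of the pattern in PyTorch: one generator network maps learned instrument embeddings to the convolution weights of per-instrument extractors. All class names, sizes, and parameters below are hypothetical, not the repository's actual implementation.

import torch
import torch.nn as nn
import torch.nn.functional as F

class WeightGenerator(nn.Module):
    """Maps a learned instrument embedding to the kernel of a 1-D conv layer."""
    def __init__(self, embed_dim: int, channels: int, kernel_size: int):
        super().__init__()
        self.channels, self.kernel_size = channels, kernel_size
        self.proj = nn.Linear(embed_dim, channels * channels * kernel_size)

    def forward(self, embedding: torch.Tensor) -> torch.Tensor:
        # One convolution weight tensor per instrument embedding.
        w = self.proj(embedding)
        return w.view(self.channels, self.channels, self.kernel_size)

# Four instrument embeddings (drums, vocals, bass, other) share one generator,
# so commonalities between instruments live in the generator's parameters.
embeddings = nn.Embedding(4, 32)
generator = WeightGenerator(embed_dim=32, channels=16, kernel_size=3)

x = torch.randn(1, 16, 1000)           # a latent representation of the mixture
for instrument in range(4):
    w = generator(embeddings.weight[instrument])
    y = F.conv1d(x, w, padding=1)      # instrument-specific extractor convolution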

Our model is based on Conv-TasNet, a time-domain approach to speech separation comprising three parts:

  1. an encoder, which applies a 1-D convolutional transform to a segment of the mixture waveform to produce a high-dimensional representation;
  2. a masking function, which calculates a multiplicative mask that identifies the targeted source in the learned representation;
  3. a decoder (a 1-D transposed convolutional layer), which reconstructs the separated waveform for the target source.

The masking network is of particular interest, as it contains the source-specific masking information; the encoder and decoder are source-agnostic and remain fixed across all sources.
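
Below is a minimal sketch of this three-part pipeline in PyTorch. The layer sizes are illustrative, and the simple pointwise-convolution-plus-sigmoid masker stands in for Conv-TasNet's TCN masking network.

import torch
import torch.nn as nn

B, L, N, stride = 1, 16, 256, 8          # illustrative sizes: kernel L, N basis filters

encoder = nn.Conv1d(1, N, kernel_size=L, stride=stride, bias=False)
masker = nn.Sequential(                   # stand-in for the TCN masking network
    nn.Conv1d(N, N, kernel_size=1), nn.Sigmoid()
)
decoder = nn.ConvTranspose1d(N, 1, kernel_size=L, stride=stride, bias=False)

mixture = torch.randn(B, 1, 32000)        # ~1 s of 32 kHz audio
rep = encoder(mixture)                    # (B, N, T'): learned representation
mask = masker(rep)                        # multiplicative mask for one source
source = decoder(rep * mask)              # reconstructed waveform of that source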


Multi-stage Architecture

Although the data is sampled at 44.1 kHz, we find that models trained at lower sampling rates are more effective despite the loss in resolution. We therefore propose a multi-stage architecture that leverages this strength while still fundamentally predicting high-resolution audio, using three stages with 8, 16, and 32 kHz sampling rates.


[Figure] Illustration of the multi-stage architecture. The resolution of the estimated signal is progressively enhanced by utilizing information from previous stages. The encoders increase the stride s to preserve the same time dimension T'. Note that the masking TCN is still generated (not included in the illustration).
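
The staging logic can be sketched roughly as follows (a hypothetical sketch, assuming torchaudio for resampling; run_stage is a placeholder for a generated extractor stage, and the real model conditions each stage on information from the previous stages in a more elaborate way):

import torch
import torchaudio

rates = [8000, 16000, 32000]

def run_stage(mixture, previous_estimate):
    # Placeholder for a generated extractor stage operating at this rate.
    return mixture

mixture_32k = torch.randn(1, 1, 32000 * 4)  # 4 s of 32 kHz audio
estimate = None
for i, rate in enumerate(rates):
    mixture = torchaudio.functional.resample(mixture_32k, 32000, rate)
    if estimate is not None:
        # Upsample the previous stage's estimate to the new rate before reuse.
        estimate = torchaudio.functional.resample(estimate, rates[i - 1], rate)
    estimate = run_stage(mixture, estimate)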


Results

  • signal-to-distortion ratio (SDR) evaluated with BSSEval v4
  • results are in dB, higher is better (median of frames, median of tracks)
  • methods annotated with “*” use the audio directly, without a spectrogram sidestep
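
As a reference point for this evaluation protocol, here is a hedged sketch of computing SDR with BSSEval v4 via the museval package (random arrays stand in for real reference and estimated stems; this is not the repository's evaluate.py):

import numpy as np
import museval

# museval convention: (n_sources, n_samples, n_channels)
references = np.random.randn(4, 44100 * 5, 2)
estimates = references + 0.1 * np.random.randn(4, 44100 * 5, 2)

sdr, isr, sir, sar = museval.evaluate(references, estimates)  # BSSEval v4
print(np.nanmedian(sdr, axis=1))  # per-source median over (default 1 s) frames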


How to Run

  1. First, download the MUSDB18 dataset and run the data generator to resample the music stems and save them as NumPy arrays: python3 data_generator.py --musdb_path path/to/the/downloaded/dataset. (A rough sketch of this preprocessing appears after this list.)

  2. After creating the dataset, you can start the training by running python3 train.py. Please note that this configuration was trained on 2 Nvidia V100 GPUs, so you need roughly 64 GB of GPU memory to train with the default batch size.

  3. Finally, you can evaluate the model by running python3 evaluate.py --model_dir directory --musdb_path path/to/the/downloaded/dataset.
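
For intuition, here is a rough sketch of the kind of preprocessing step 1 performs, assuming the musdb package and SciPy for resampling; the actual data_generator.py may use different arguments, stem handling, and output layout:

import musdb
import numpy as np
from scipy.signal import resample_poly

# up/down ratios from 44.1 kHz to each stage's rate (rate / 44100, reduced)
RATIOS = {8000: (80, 441), 16000: (160, 441), 32000: (320, 441)}

mus = musdb.DB(root="path/to/the/downloaded/dataset", subsets="train")
for track in mus:
    for name, target in track.targets.items():
        audio = target.audio.T.astype(np.float32)   # (channels, samples) at 44.1 kHz
        for rate, (up, down) in RATIOS.items():
            resampled = resample_poly(audio, up, down, axis=-1)
            np.save(f"{track.name}_{name}_{rate}.npy", resampled)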


Interactive Demo

You can try an interactive demo of the pretrained model in a Google Colab notebook.


Pretrained Model

A model pretrained on the MUSDB18 dataset can be downloaded from here. After downloading, load the model with the following Python lines. An example of using the pretrained model for separation can be found in the aforementioned Google Colab notebook.

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

state = torch.load("best_model.pt")              # load the checkpoint
network = MultiTasNet(state["args"]).to(device)  # initialize the model (MultiTasNet is defined in this repository)
network.load_state_dict(state["state_dict"])     # load the pretrained weights

Cite

@inproceedings{meta-tasnet:2020,
    title={Meta-learning Extractors for Music Source Separation},
    author={David Samuel and Aditya Ganeshan and Jason Naradowsky},
    booktitle={IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
    pages={816--820},
    year={2020},
}

License

MIT License

More Repositories

1. sngan_projection: GANs with spectral normalization and projection discriminator (Python, 1,079 stars)
2. chainer-gan-lib: Chainer implementation of recent GAN variants (Python, 407 stars)
3. xfeat: Flexible Feature Engineering & Exploration Library using GPUs and Optuna (Python, 369 stars)
4. chainer-gogh (Python, 302 stars)
5. menoh: fast DNN inference library with multiple programming language support (C++, 279 stars)
6. pfhedge: PyTorch-based framework for Deep Hedging (Python, 249 stars)
7. contextual_augmentation: contextual augmentation, a text data augmentation using a bidirectional language model (Python, 193 stars)
8. distilled-feature-fields (Python, 178 stars)
9. nips17-adversarial-attack: submission to the Kaggle NIPS'17 competition on adversarial examples (non-targeted adversarial attack track) (Python, 146 stars)
10. FSCS: Fast Soft Color Segmentation (Python, 134 stars)
11. chainer-pix2pix: Chainer implementation of pix2pix (Python, 131 stars)
12. k8s-cluster-simulator: Kubernetes cluster simulator for evaluating schedulers (Go, 124 stars)
13. chainer-compiler: experimental toolchain to compile and run Chainer models (Python, 112 stars)
14. graph-nvp: GraphNVP, an invertible flow model for generating molecular graphs (Python, 91 stars)
15. autogbt-alt: an experimental Python package that reimplements AutoGBT using LightGBM and Optuna (Python, 82 stars)
16. chainer-stylegan: Chainer implementation of the style-based generator (Python, 79 stars)
17. tgan: implementation of Temporal Generative Adversarial Nets with Singular Value Clipping (Python, 78 stars)
18. tgan2: official implementation of "Train Sparsely, Generate Densely: Memory-efficient Unsupervised Training of High-resolution Temporal GAN" (Python, 76 stars)
19. deep-table (Python, 76 stars)
20. git-ghost: synchronize your working directory efficiently to a remote place without committing the changes (Go, 73 stars)
21. chainer-graph-cnn: Chainer implementation of "Convolutional Neural Networks on Graphs with Fast Localized Spectral Filtering" (https://arxiv.org/abs/1606.09375) (Python, 67 stars)
22. meta-fuse-csi-plugin: a CSI plugin for all FUSE implementations (Go, 67 stars)
23. surface-aligned-nerf (Python, 65 stars)
24. bayesgrad: BayesGrad, explaining predictions of graph convolutional networks (Jupyter Notebook, 62 stars)
25. japanese-lm-fin-harness: Japanese Language Model Financial Evaluation Harness (Shell, 62 stars)
26. torch-dftd: PyTorch implementation of DFT-D2 & DFT-D3 (Python, 60 stars)
27. alertmanager-to-github: receives webhook requests from Alertmanager and creates GitHub issues (Go, 50 stars)
28. kaggle-lyft-motion-prediction-4th-place-solution: Kaggle Lyft Motion Prediction for Autonomous Vehicles 4th place solution (Python, 48 stars)
29. einconv (Python, 46 stars)
30. chainer-segnet: SegNet implementation & experiments in Chainer (Python, 42 stars)
31. TabCSDI: code for the NeurIPS 2022 Table Representation Learning Workshop paper "Diffusion models for missing value imputation in tabular data" (Python, 41 stars)
32. gcp-workload-identity-federation-webhook: webhook for mutating pods that require GCP Workload Identity Federation access from a Kubernetes cluster (Go, 39 stars)
33. chainer-trt: Chainer x TensorRT (C++, 34 stars)
34. hyperbolic_wrapped_distribution (Python, 32 stars)
35. picking-instruction: PFN Picking Instructions for Commodities Dataset (PFN-PIC), including images, bounding boxes, and text instructions (31 stars)
36. pftaskqueue: lightweight task queue tool (Go, 30 stars)
37. multi-stage-blended-diffusion (Python, 30 stars)
38. NoTransactionBandNetwork: minimal implementation and experiments of "No-Transaction Band Network: A Neural Network Architecture for Efficient Deep Hedging" (Jupyter Notebook, 30 stars)
39. capg: implementation of clipped action policy gradient (CAPG) with PPO and TRPO (Python, 29 stars)
40. charge_transfer_nnp: graph neural network potential with charge transfer (Python, 28 stars)
41. node-operation-controller: Kubernetes controller for automated Node operations (Go, 26 stars)
42. menoh-ruby: Ruby binding for the Menoh DNN inference library (C, 26 stars)
43. optuna-book (Jupyter Notebook, 26 stars)
44. chainer-ADDA: Adversarial Discriminative Domain Adaptation in Chainer (Python, 24 stars)
45. RJT-RL: de novo molecular design using a Reversible Junction Tree and Reinforcement Learning (Python, 23 stars)
46. superpixel-align: official implementation of the "Minimizing Supervision for Free-space Segmentation" paper (Jupyter Notebook, 23 stars)
47. label-efficient-brain-tumor-segmentation (Python, 21 stars)
48. chainer-disentanglement-lib: unsupervised disentangled representation learning in Chainer (Python, 21 stars)
49. vat_nmt: implementation of "Effective Adversarial Regularization for Neural Machine Translation", ACL 2019 (Python, 21 stars)
50. allreduce-proto: prototype implementation of the AllReduce collective communication routine (C++, 20 stars)
51. Chainer-DeepFill (Python, 19 stars)
52. pfneumonia: repository for the RSNA pneumonia open-source solution (Python, 18 stars)
53. chainer-LSGAN: Least Squares Generative Adversarial Network implemented in Chainer (Python, 18 stars)
54. KDD-Cup-AutoML-5: KDD Cup 2019 AutoML Track 5th place solution (Python, 18 stars)
55. step-wise-chemical-synthesis-prediction: a GGNN-GWM based step-wise framework for chemical synthesis prediction (Python, 17 stars)
56. ATPG4SV: a prototype concolic testing engine for SystemVerilog, developed as part of a PFN summer internship 2018 (OCaml, 16 stars)
57. menoh-sharp: C# binding for the Menoh DNN inference library (C#, 15 stars)
58. BMI219-2017-ProteinFolding: UCSF BMI219 Deep Learning (2017) coding example (prediction of protein folding with RNN and CNN) (Python, 15 stars)
59. go-menoh: Go binding for the Menoh DNN inference library (Go, 14 stars)
60. hierarchical-molecular-learning: implementation of "Semi-supervised learning of hierarchical representations of molecules using neural message passing" (arXiv:1711.10168) (Python, 14 stars)
61. kaggle-alaska2-3rd-place-solution: 3rd place solution for ALASKA2 Image Steganalysis on Kaggle (Python, 13 stars)
62. menoh-rs: Rust binding for Menoh (Rust, 13 stars)
63. menoh-haskell: Haskell binding for the Menoh DNN inference library (Jupyter Notebook, 12 stars)
64. chainer-differentiable-mpc: differentiable MPC in Chainer, developed as part of a PFN summer internship 2019 (Python, 12 stars)
65. asdf-clusterctl: clusterctl plugin for the asdf version manager (Shell, 12 stars)
66. GenerRNA (Python, 11 stars)
67. Deep_visuo-tactile_learning_ICRA2019 (11 stars)
68. menoh-java: building a deep neural network (DNN) application in Java (Java, 11 stars)
69. treewidth-prediction: prediction of treewidth using a graph neural network, developed as part of a PFN summer internship 2019 (Jupyter Notebook, 10 stars)
70. pml: an ML-like programming language with type-based probabilistic behavior specification, developed as part of a PFN summer internship 2018 (C++, 10 stars)
71. optuna-hands-on (Jupyter Notebook, 10 stars)
72. kaggle-hpa-2021-7th-place-solution: 7th place solution for Human Protein Atlas - Single Cell Classification on Kaggle (Python, 9 stars)
73. chainer-ev3 (Jupyter Notebook, 9 stars)
74. chainer-robotcar-text (8 stars)
75. tabret (Python, 8 stars)
76. pfmt-bench-fin-ja: Preferred Multi-turn Benchmark for Finance in Japanese (Python, 8 stars)
77. batch-metaheuristics (Python, 7 stars)
78. rp-safe-rl (Python, 7 stars)
79. limited-gp (C++, 6 stars)
80. recompute (Python, 6 stars)
81. piekd: official implementation of Periodic Intra-Ensemble Knowledge Distillation (PIEKD) (Python, 6 stars)
82. differentiable-ray-sampling (Jupyter Notebook, 6 stars)
83. BMI219-2017-DeepQSAR: UCSF BMI219 Deep Learning (2017) coding example (QSAR with deep multitask learning) (Python, 5 stars)
84. chainer-capsnet: CapsNet implemented in Chainer (Python, 5 stars)
85. node-menoh: Node.js binding for the Menoh DNN inference library (JavaScript, 5 stars)
86. chainer-formulanet: Chainer implementation of FormulaNet (Python, 5 stars)
87. plamo-examples (5 stars)
88. pocket_detection: pocket detection (Python, 5 stars)
89. ssdrl (Python, 4 stars)
90. cg-transfer (Python, 4 stars)
91. head_model (Python, 4 stars)
92. Finance_data_augmentation_ICAIF2022 (Jupyter Notebook, 3 stars)
93. echainer: elastic Chainer prototype (Python, 3 stars)
94. robust_estimation: repository for robust estimation research (Python, 3 stars)
95. Invisible_marker_IROS2020 (3 stars)
96. ex_matgl: MatGL-based neural network potential that computes excited-state energies and forces (Python, 3 stars)
97. timesfm_fin (Python, 3 stars)
98. nms-comp: Neural Multi-scale Compression (Python, 2 stars)
99. transport-control-socket (C++, 2 stars)
100. unsupervised_segmental_empirical_ODM (Python, 2 stars)