  • Language: Python
  • License: MIT License
  • Created almost 6 years ago
  • Updated over 1 year ago


A PyTorch package for non-negative matrix factorization.

Non-negative Matrix Factorization in PyTorch


PyTorch is not only a good deep learning framework but also a fast tool for matrix operations and convolutions on large data. A great example is PyTorchWavelets.

In this package I implement NMF, PLCA and their deconvolutional variants in PyTorch on top of torch.nn.Module, so the models can be moved freely between CPU and GPU devices and take advantage of parallel CUDA computation. We also use the computational graph from torch.autograd to derive the update coefficients, which reduces the amount of code and keeps it easy to maintain.
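To illustrate how autograd can stand in for hand-derived update rules, here is a minimal sketch assuming a plain two-factor model V ≈ WH; it is not torchnmf's actual code, and the function name `beta_mu_step` is hypothetical. The gradient of the beta-divergence with respect to the reconstruction WH splits into a positive part (WH)^(β-1) and a negative part V·(WH)^(β-2). Each part is pushed back to the factor being updated with `torch.autograd.grad`, and the factor is multiplied by the ratio negative/positive.

```python
import torch


def beta_mu_step(V, W, H, beta=1.0, eps=1e-8):
    """One multiplicative update of W, then H, for beta-divergence NMF.

    Hypothetical helper for illustration. W and H are updated in place.
    """
    for factor in (W, H):
        factor.requires_grad_(True)
        WH = W @ H
        WH_d = WH.detach().clamp_min(eps)
        pos = WH_d.pow(beta - 1)       # positive part of dD/d(WH)
        neg = V * WH_d.pow(beta - 2)   # negative part of dD/d(WH)
        # Vector-Jacobian products route each part back to `factor`;
        # since WH is linear in each factor, no update formula is hand-coded.
        grad_pos, = torch.autograd.grad(WH, factor, grad_outputs=pos,
                                        retain_graph=True)
        grad_neg, = torch.autograd.grad(WH, factor, grad_outputs=neg)
        factor.requires_grad_(False)
        with torch.no_grad():
            factor.mul_(grad_neg / (grad_pos + eps))
    return W, H
```

With beta=2 this reduces to the classic Lee-Seung Euclidean updates, and with beta=1 to the KL updates. Changing beta only changes `pos` and `neg`, so the rest of the graph stays the same.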

Modules

NMF

Basic NMF and NMFD modules that minimize the beta-divergence using multiplicative update rules.
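For reference, the beta-divergence family covers Itakura-Saito (β = 0), generalized Kullback-Leibler (β = 1) and half the squared Euclidean distance (β = 2). The following is a minimal sketch of the element-wise definition summed over all entries. The helper name `beta_divergence` and the small `eps` clamp that avoids log(0) are assumptions for illustration, not part of torchnmf.

```python
import torch


def beta_divergence(V, V_hat, beta, eps=1e-12):
    """Sum of the element-wise beta-divergence d_beta(V | V_hat).

    Hypothetical helper; inputs are clamped to eps to avoid log(0) and 0**negative.
    """
    V = V.clamp_min(eps)
    V_hat = V_hat.clamp_min(eps)
    if beta == 0:  # Itakura-Saito
        ratio = V / V_hat
        return (ratio - torch.log(ratio) - 1).sum()
    if beta == 1:  # generalized Kullback-Leibler
        return (V * torch.log(V / V_hat) - V + V_hat).sum()
    # General case; beta == 2 gives 0.5 * ||V - V_hat||^2
    return ((V.pow(beta) + (beta - 1) * V_hat.pow(beta)
             - beta * V * V_hat.pow(beta - 1)) / (beta * (beta - 1))).sum()
```

The β = 0 and β = 1 branches are the limits of the general formula, which is undefined at exactly those values.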

The interface is similar to sklearn.decomposition.NMF with some extra options.

  • NMF: Original NMF algorithm.
  • NMFD: 1-D deconvolutional NMF algorithm.
  • NMF2D: 2-D deconvolutional NMF algorithm.
  • NMF3D: 3-D deconvolutional NMF algorithm.
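To make the deconvolutional idea concrete: in 1-D deconvolutional NMF each component k has a time-varying template W[:, k, t] of length T. The reconstruction sums time-shifted copies of the templates weighted by the activations, V[f, n] = Σ_k Σ_t W[f, k, t] H[k, n - t]. Below is a minimal sketch using `torch.nn.functional.conv1d`, assuming a full linear convolution (output length N + T - 1). torchnmf's exact shape conventions may differ, and `nmfd_reconstruct` is a hypothetical name.

```python
import torch
import torch.nn.functional


def nmfd_reconstruct(W, H):
    """Hypothetical NMFD reconstruction.

    W: (F, K, T) templates, H: (K, N) activations -> (F, N + T - 1).
    """
    T = W.shape[-1]
    # conv1d computes a cross-correlation, so flip the kernel in time and
    # zero-pad T - 1 frames on both sides to get a full linear convolution.
    out = torch.nn.functional.conv1d(H.unsqueeze(0), W.flip(-1), padding=T - 1)
    return out.squeeze(0)
```

With T = 1 this reduces to the ordinary product W[:, :, 0] @ H. Because all shifts are handled by a single convolution kernel, the computation maps well onto the GPU. The 2-D and 3-D variants could plausibly be expressed the same way with `conv2d` and `conv3d`.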

PLCA

Basic PLCA and SIPLCA modules that use the EM algorithm to minimize the KL divergence between the target distribution and the estimated distribution.

  • PLCA: Original PLCA (Probabilistic Latent Component Analysis) algorithm.
  • SIPLCA: Shift-Invariant PLCA algorithm (similar to NMFD).
  • SIPLCA2: 2-D deconvolutional SIPLCA algorithm.
  • SIPLCA3: 3-D deconvolutional SIPLCA algorithm.
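PLCA treats the normalized input as a joint distribution and models it as P(f, n) = Σ_z P(z) P(f|z) P(n|z). EM alternates between computing the posterior P(z | f, n) and re-estimating each factor from the posterior-weighted data. The following is a minimal, vectorized sketch of the plain (non-shift-invariant) case. The function name `plca_em`, the random initialization and the `eps` guard are assumptions for illustration, not torchnmf's implementation.

```python
import torch


def plca_em(V, rank, n_iter=100, eps=1e-12, seed=0):
    """Hypothetical plain PLCA via EM.

    Returns W = P(f|z) as (F, K), Z = P(z) as (K,), and H = P(n|z) as (K, N).
    """
    g = torch.Generator().manual_seed(seed)
    n_rows, n_cols = V.shape
    V = V / V.sum()  # treat the input as a joint distribution P(f, n)
    W = torch.rand(n_rows, rank, generator=g, dtype=V.dtype)
    H = torch.rand(rank, n_cols, generator=g, dtype=V.dtype)
    Z = torch.full((rank,), 1.0 / rank, dtype=V.dtype)
    W = W / W.sum(0, keepdim=True)
    H = H / H.sum(1, keepdim=True)
    for _ in range(n_iter):
        P = (W * Z) @ H          # current model P(f, n)
        Q = V / (P + eps)
        # E- and M-steps fused: sum_n V(f,n) P(z|f,n) and sum_f V(f,n) P(z|f,n)
        W_new = W * (Q @ H.T) * Z
        H_new = H * (W.T @ Q) * Z[:, None]
        Z = W_new.sum(0)         # total posterior mass per component
        W = W_new / (Z + eps)
        H = H_new / (H_new.sum(1, keepdim=True) + eps)
        Z = Z / Z.sum()
    return W, Z, H
```

Each iteration never decreases the log-likelihood Σ V log P. That is equivalent to never increasing the KL divergence from the normalized target to the model.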

Usage

Here is a short example of decomposing a spectrogram using deconvolutional NMF:

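import torch
import librosa
from torchnmf.nmf import NMFD
from torchnmf.metrics import kl_div

# use the GPU when available; the same code also runs on the CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# librosa.ex() fetches a bundled example recording (librosa >= 0.8);
# 'vibeace' is the track formerly returned by librosa.util.example_audio_file()
y, sr = librosa.load(librosa.ex('vibeace'))
y = torch.from_numpy(y)
windowsize = 2048
S = torch.stft(y, windowsize,
               window=torch.hann_window(windowsize),
               return_complex=True).abs().to(device)
S = S.unsqueeze(0)  # add a leading batch dimension

R = 8    # number of components
T = 400  # size of the convolution window (in frames)

net = NMFD(S.shape, rank=R, T=T).to(device)
net.fit(S)  # fit to the target spectrogram S; very fast on a GPU
with torch.no_grad():
    V = net()  # reconstruction of S
print(kl_div(V, S))  # KL divergence between the reconstruction and S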

A more detailed version can be found here. See our documentation to learn more about using this package.

Comparison with sklearn

The bar chart shows the time cost per iteration for different beta-divergences. The PyTorch-based NMF has a much more constant processing time across beta values, which is an advantage when beta is not 0, 1, or 2. This is because our implementation uses the same computational graph regardless of which beta-divergence we are minimizing. It runs even faster when the computation is done on a GPU. The test was conducted on an Acer E5 laptop with an i5-7200U CPU and a GTX 950M GPU.

Installation

pip install torchnmf

Requirements

  • PyTorch
  • tqdm

Tips

  • If you notice a significant slowdown when running on the CPU, flush denormal numbers by calling torch.set_flush_denormal(True).
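The tip above amounts to a single call. `torch.set_flush_denormal` returns a bool indicating whether the setting is supported on the current CPU (per the PyTorch docs, x86 CPUs with SSE3), so it is worth checking the result.

```python
import torch

# Flush denormal (subnormal) floats to zero to avoid slow CPU code paths.
# The return value reports whether this CPU supports the setting.
supported = torch.set_flush_denormal(True)
print("flush denormal supported:", supported)
```

Call it once at the start of a script, before fitting; it affects all subsequent CPU floating-point work in the process.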

TODO

  • Support sparse matrix target (only on NMF module).
  • Regularization.
  • NNDSVD initialization.
  • 2/3-D deconvolutional module.
  • PLCA.
  • Documentation.
  • ipynb examples.
  • Refactor PLCA module.
