
Keras Mixture Density Network Layer


A mixture density network (MDN) layer for Keras using TensorFlow's distributions module. This makes it simpler to experiment with neural networks that predict multiple real-valued variables, each of which may take on several equally likely values.

This layer can help build MDN-RNNs similar to those used in RoboJam, Sketch-RNN, handwriting generation, and maybe even world models. You can do a lot of cool stuff with MDNs!

One benefit of this implementation is that you can predict any number of real values. TensorFlow's Mixture, Categorical, and MultivariateNormalDiag distribution functions are used to generate the loss function (the probability density function of a mixture of multivariate normal distributions with a diagonal covariance matrix). In previous work, the loss function has often been specified by hand, which is fine for 1D or 2D prediction but becomes tedious beyond that.
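Concretely, the loss being minimized is the negative log-likelihood of the target under that mixture density. Here is a minimal NumPy sketch of the quantity involved, for a diagonal-covariance Gaussian mixture; this is illustrative only, not the library's TensorFlow implementation:

```python
import numpy as np

def mixture_nll(y, pis, mus, sigmas):
    """Negative log-likelihood of y under a diagonal-covariance Gaussian
    mixture. Shapes: y is (D,), pis is (K,), mus and sigmas are (K, D)."""
    # Per-component log-density of an independent (diagonal) Gaussian
    log_comp = -0.5 * np.sum(
        np.log(2 * np.pi * sigmas**2) + ((y - mus) / sigmas) ** 2, axis=1
    )
    # Weight the components by the categorical mixing probabilities
    mix_density = np.sum(pis * np.exp(log_comp))
    return -np.log(mix_density)

# A 2-component, 1-D mixture with equal weights and unit variances
nll = mixture_nll(np.array([0.0]),
                  np.array([0.5, 0.5]),
                  np.array([[0.0], [4.0]]),
                  np.array([[1.0], [1.0]]))
```

In the library, this computation is built from TensorFlow Probability distribution objects, so it stays differentiable and works for any output dimension.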

Two important functions are provided for training and prediction:

  • get_mixture_loss_func(output_dim, num_mixtures): This function generates a loss function with the correct output dimension and number of mixtures.
  • sample_from_output(params, output_dim, num_mixtures, temp=1.0): This function samples from the mixture distribution output by the model.
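The params vector that sample_from_output consumes is the MDN layer's flat output of length 2 × num_mixtures × output_dim + num_mixtures. As a sketch of how such a vector can be unpacked, assuming a [means | scales | mixing logits] ordering (the ordering and helper name here are illustrative assumptions, not guaranteed API):

```python
import numpy as np

def split_mixture_params(params, output_dim, num_mixes):
    """Unpack a flat MDN parameter vector into means, scales, and
    mixing weights, assuming a [mus | sigmas | pi_logits] layout."""
    kd = num_mixes * output_dim
    mus = params[:kd].reshape(num_mixes, output_dim)
    sigmas = params[kd:2 * kd].reshape(num_mixes, output_dim)
    pi_logits = params[2 * kd:]
    # Softmax turns the raw mixing logits into categorical probabilities
    pis = np.exp(pi_logits - pi_logits.max())
    pis /= pis.sum()
    return mus, sigmas, pis
```

Sampling then means drawing a component index from pis and drawing from that component's diagonal Gaussian.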

Installation

This project requires Python 3.6+, TensorFlow and TensorFlow Probability. You can easily install this package from PyPI via pip like so:

python3 -m pip install keras-mdn-layer

And finally, import the mdn module in Python: import mdn

Alternatively, you can clone or download this repository and then install via python setup.py install, or copy the mdn folder into your own project.

Examples

Some examples are provided in the notebooks directory.

There are scripts for fitting multivalued functions, a standard MDN toy problem:

Keras MDN Demo

There's also a script for generating fake kanji characters:

kanji test 1

And finally, for learning how to generate musical touch-screen performances with a temporal component:

Robojam Model Examples

How to use

The MDN layer should be the last in your network and you should use get_mixture_loss_func to generate a loss function. Here's an example of a simple network with one Dense layer followed by the MDN.

from tensorflow import keras
import mdn

N_HIDDEN = 15  # number of hidden units in the Dense layer
N_MIXES = 10  # number of mixture components
OUTPUT_DIMS = 2  # number of real-values predicted by each mixture component

model = keras.Sequential()
model.add(keras.layers.Dense(N_HIDDEN, batch_input_shape=(None, 1), activation='relu'))
model.add(mdn.MDN(OUTPUT_DIMS, N_MIXES))
model.compile(loss=mdn.get_mixture_loss_func(OUTPUT_DIMS, N_MIXES), optimizer=keras.optimizers.Adam())
model.summary()

Fit as normal:

history = model.fit(x=x_train, y=y_train)

The predictions from the network are the parameters of the mixture model, so you have to apply the sample_from_output function to generate samples.

import numpy as np

y_test = model.predict(x_test)
y_samples = np.apply_along_axis(mdn.sample_from_output, 1, y_test, OUTPUT_DIMS, N_MIXES, temp=1.0)
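The temp argument trades off diversity against fidelity when sampling: values below 1.0 sharpen the mixing distribution toward the most likely component, values above 1.0 flatten it. A sketch of the usual temperature adjustment (the helper below is illustrative, not part of the mdn module):

```python
import numpy as np

def adjust_temperature(weights, temp):
    """Sharpen (temp < 1) or flatten (temp > 1) a categorical
    distribution by rescaling its log-probabilities, then renormalizing."""
    logits = np.log(weights) / temp
    probs = np.exp(logits - logits.max())
    return probs / probs.sum()
```

With temp=1.0 the weights are unchanged; as temp approaches 0, sampling becomes nearly deterministic.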

See the notebooks directory for examples in jupyter notebooks!

Load/Save Model

Saving models is straightforward:

model.save('test_save.h5')

But loading requires custom_objects to be filled with the MDN layer, and a loss function with the appropriate parameters:

m_2 = keras.models.load_model('test_save.h5', custom_objects={'MDN': mdn.MDN, 'mdn_loss_func': mdn.get_mixture_loss_func(1, N_MIXES)})

Acknowledgements

References

  1. Christopher M. Bishop. 1994. Mixture Density Networks. Technical Report NCRG/94/004. Neural Computing Research Group, Aston University. http://publications.aston.ac.uk/373/
  2. Axel Brando. 2017. Mixture Density Networks (MDN) for distribution and uncertainty estimation. Master’s thesis. Universitat Politècnica de Catalunya.
  3. A. Graves. 2013. Generating Sequences With Recurrent Neural Networks. ArXiv e-prints (Aug. 2013). https://arxiv.org/abs/1308.0850
  4. David Ha and Douglas Eck. 2017. A Neural Representation of Sketch Drawings. ArXiv e-prints (April 2017). https://arxiv.org/abs/1704.03477
  5. Charles P. Martin and Jim Torresen. 2018. RoboJam: A Musical Mixture Density Network for Collaborative Touchscreen Interaction. In Evolutionary and Biologically Inspired Music, Sound, Art and Design: EvoMUSART ’18, A. Liapis et al. (Ed.). Lecture Notes in Computer Science, Vol. 10783. Springer International Publishing. DOI:10.1007/978-3-319-77583-8_11
