TensorFlow implementation of the SOM-VAE model as described in https://arxiv.org/abs/1806.02199

SOM-VAE model

This repository contains a TensorFlow implementation of the self-organizing map variational autoencoder (SOM-VAE) as described in the paper "SOM-VAE: Interpretable Discrete Representation Learning on Time Series".

If you like the SOM-VAE, you should also check out the DPSOM (paper, code), which yields better performance on many tasks.

Getting Started

These instructions will get you a copy of the project up and running on your local machine for development and testing purposes.

Prerequisites

To install and run the model, you need a working Python 3 distribution as well as an NVIDIA GPU with CUDA and cuDNN installed.

Installing

To install and run the model, follow these steps:

  • Clone the repository: git clone https://github.com/ratschlab/SOM-VAE
  • Change into the directory: cd SOM-VAE
  • Install the requirements: pip install -r requirements.txt
  • Install the package itself: pip install .
  • Change into the code directory: cd som_vae

You should now be able to run the code, e.g. with python somvae_train.py.

Training the model

The SOM-VAE model is defined in somvae_model.py. The training script is somvae_train.py.

If you just want to train the model with default parameter settings, you can run

python somvae_train.py

This will download the MNIST data set into data/MNIST_data/ and train on it. Afterwards, it will evaluate the trained model in terms of different clustering performance measures.
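Cluster purity is one common clustering measure for this kind of model (the SOM-VAE paper reports purity and normalized mutual information). The README does not say exactly how somvae_train.py computes its measures, so the following is only a minimal sketch of purity: each cluster (e.g. each SOM node) is labeled with its most frequent true class, and purity is the fraction of samples that match their cluster's label. The function name cluster_purity is hypothetical and not part of the repository.

```python
from collections import Counter, defaultdict


def cluster_purity(cluster_ids, true_labels):
    """Hypothetical helper: fraction of samples whose true label equals
    the majority label of the cluster they were assigned to."""
    if len(cluster_ids) != len(true_labels):
        raise ValueError("cluster_ids and true_labels must have the same length")
    if not cluster_ids:
        raise ValueError("need at least one sample")
    # Group the true labels by assigned cluster (e.g. SOM node index).
    members = defaultdict(list)
    for cluster, label in zip(cluster_ids, true_labels):
        members[cluster].append(label)
    # For each cluster, count the samples carrying its most frequent label.
    correct = sum(Counter(labels).most_common(1)[0][1] for labels in members.values())
    return correct / len(true_labels)


# Toy example: 3 SOM nodes, 8 samples.
nodes = [0, 0, 0, 1, 1, 2, 2, 2]
digits = [7, 7, 1, 3, 3, 5, 5, 9]
print(cluster_purity(nodes, digits))  # 0.75
```

Purity rewards fine-grained clusterings (one sample per cluster gives purity 1.0), which is why it is usually reported together with a measure such as normalized mutual information.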

Parameters are managed with Sacred. To run the model with a different parameter setting, e.g. a latent space dimensionality of 32, call the training script as follows:

python somvae_train.py with latent_dim=32

By default, the script generates time courses of linearly interpolated MNIST digits. To train on standard MNIST instead, run

python somvae_train.py with time_series=False
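The idea behind the default time-series data can be sketched as follows, assuming each time course blends one digit image into another with a weight that grows linearly from 0 to 1 over the course. The actual generator in somvae_train.py may choose the digit pairs, the number of steps and the array layout differently; interpolate_images and its parameters are hypothetical, and images are flat lists of pixel intensities to keep the sketch dependency-free.

```python
def interpolate_images(start, end, num_steps):
    """Hypothetical helper: return num_steps frames moving linearly from
    `start` to `end`, both flat lists of pixel intensities of equal length
    (e.g. 784 values for a 28x28 MNIST digit)."""
    if len(start) != len(end):
        raise ValueError("images must have the same number of pixels")
    if num_steps < 2:
        raise ValueError("need at least two steps to include both endpoints")
    frames = []
    for step in range(num_steps):
        w = step / (num_steps - 1)  # interpolation weight: 0.0, ..., 1.0
        frames.append([(1.0 - w) * a + w * b for a, b in zip(start, end)])
    return frames


# Toy 2-pixel "images": the first pixel fades out while the second fades in.
for frame in interpolate_images([1.0, 0.0], [0.0, 1.0], num_steps=5):
    print(frame)
```

Because neighboring frames differ only slightly, such sequences give the model smooth trajectories on which temporal structure in the SOM can be learned.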

Note that for non-time-series training, you should also set the loss parameters gamma and tau to 0.

To save the model for later use, run

python somvae_train.py with save_model=True
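Why gamma and tau should be 0 without time series can be seen from the objective in the SOM-VAE paper: the per-sample loss is a reconstruction term plus a commitment term weighted by alpha and a SOM term weighted by beta, and the time-series objective adds a transition term weighted by gamma and a smoothness term weighted by tau, both computed over consecutive time points. The following minimal sketch only shows how these weights combine scalar loss terms; the term values are placeholders rather than outputs of the model, and the function name total_loss is hypothetical, using the paper's notation rather than confirmed names from somvae_model.py.

```python
def total_loss(recon, commitment, som, transitions, smoothness,
               alpha, beta, gamma, tau):
    """Hypothetical helper: weighted sum of scalar loss terms in the
    SOM-VAE paper's notation.

    recon, commitment, som: terms of the per-sample SOM-VAE loss.
    transitions, smoothness: time-series terms over consecutive time points.
    """
    som_vae = recon + alpha * commitment + beta * som
    return som_vae + gamma * transitions + tau * smoothness


# Placeholder term values, purely illustrative.
terms = dict(recon=10.0, commitment=2.0, som=4.0, transitions=3.0, smoothness=1.0)

# With gamma = tau = 0 the time-series terms drop out,
# leaving only the per-sample SOM-VAE loss.
print(total_loss(**terms, alpha=1.0, beta=0.5, gamma=0.0, tau=0.0))  # 14.0
print(total_loss(**terms, alpha=1.0, beta=0.5, gamma=0.5, tau=2.0))  # 17.5
```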

If you want to train on Fashion-MNIST instead of standard MNIST, download the data set into data/fashion/ and run

python somvae_train.py with data_set="fashion"

For more details on the model parameters and how to set them, see the documentation in the code and the Sacred documentation.
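The Sacred overrides used in this section can also be scripted, for example to sweep over several latent dimensionalities. The following sketch only builds the same kind of command line the README uses (python somvae_train.py with key=value ...) and launches it with the standard subprocess module; build_command and run_sweep are hypothetical helpers, not part of the repository. Values are written with repr() so that strings keep their quotes, on the assumption that Sacred interprets each value as a Python literal.

```python
import subprocess
import sys


def build_command(config_updates, script="somvae_train.py"):
    """Hypothetical helper: build an argv list such as
    [<python>, 'somvae_train.py', 'with', 'latent_dim=32']."""
    cmd = [sys.executable, script]
    if config_updates:
        cmd.append("with")
        # repr() turns 32 -> '32', False -> 'False' and "fashion" -> "'fashion'",
        # i.e. Python literals for Sacred to parse.
        cmd.extend(f"{key}={value!r}" for key, value in config_updates.items())
    return cmd


def run_sweep(latent_dims, **fixed_updates):
    """Hypothetical helper: train once per latent dimensionality,
    stopping at the first run that exits with an error."""
    for latent_dim in latent_dims:
        cmd = build_command({**fixed_updates, "latent_dim": latent_dim})
        print("Running:", " ".join(cmd))
        subprocess.run(cmd, check=True)


# Usage (from the som_vae directory):
# run_sweep([16, 32, 64], data_set="fashion", save_model=True)
```

Passing an argument list instead of a shell string avoids shell quoting issues such as the double quotes in data_set="fashion".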

Hyperparameter optimization

To optimize the model's hyperparameters, you additionally need to install labwatch and SMAC and uncomment the commented-out lines in somvae_train.py. Note that you also need a locally running MongoDB instance.

Train on other kinds of data

If you want to train on other types of data, you have to run the training with

python somvae_train.py with mnist=False

Moreover, you have to set the model's input_length and input_channels parameters to match the dimensionality of your data, provide a suitable data generator in somvae_train.py and possibly adjust the layer dimensions in somvae_model.py.
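What a suitable data generator looks like depends on how somvae_train.py consumes batches, which the README does not specify. As a hedged sketch, assuming each observation is already shaped to match input_length and input_channels, and that the time-series losses need consecutive samples in a batch to be consecutive time points, a generator over one long time-ordered series could cut it into contiguous batches and cycle forever. time_series_batches is hypothetical; a real generator would typically yield NumPy arrays rather than nested lists.

```python
def time_series_batches(observations, batch_size):
    """Hypothetical helper: yield contiguous batches of consecutive
    observations, cycling over the series forever.

    `observations` is a time-ordered list; each element is one time point,
    assumed to be shaped as input_length x input_channels already.
    Contiguous batches preserve temporal order within each batch.
    """
    if batch_size < 1 or batch_size > len(observations):
        raise ValueError("batch_size must be between 1 and len(observations)")
    num_batches = len(observations) // batch_size  # drop an incomplete tail
    while True:
        for b in range(num_batches):
            start = b * batch_size
            yield observations[start:start + batch_size]


# Toy series: 10 time points, each a 1x1 "observation".
series = [[[float(t)]] for t in range(10)]
batches = time_series_batches(series, batch_size=4)
first, second, third = next(batches), next(batches), next(batches)
print(first)   # time points 0-3
print(third)   # the generator wrapped around to time points 0-3 again
```

Dropping the incomplete tail keeps every batch the same size, which fixed-shape graph definitions usually require.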

To reproduce the experiments on eICU data, please use the preprocessing pipeline from this repository: https://github.com/ratschlab/variational-psom

Authors

See also the list of contributors who participated in this project.

License

This project is licensed under the MIT License. See the LICENSE file for details.
