• Stars
    228
  • Rank 175,267 (Top 4 %)
  • Language
    Python
  • License
    MIT License
  • Created almost 9 years ago
  • Updated about 7 years ago

Repository Details

Fully differentiable deep neural decision forest in TensorFlow

Fully Differentiable Deep Neural Decision Forest

This repository contains a simple modification of the deep neural decision forest [Kontschieder et al.] in TensorFlow. The modification allows joint optimization of the decision nodes and the leaf nodes, which in theory should speed up training (not yet verified).

Motivation:

Deep Neural Decision Forests (ICCV 2015) proposed an interesting way to incorporate a decision forest into a neural network.

The authors proposed modeling the terminal (leaf) nodes of a decision forest as static probability distributions and computing the routing probabilities with sigmoid functions. The final loss is the usual cross entropy between the ground truth and the weighted average of the terminal distributions, where the weights are the routing probabilities.
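To make the prediction rule concrete, here is a minimal plain-Python sketch for one depth-2 tree with three decision nodes and four leaves. The tree shape, the convention that $d_n = \sigma(f_n(x))$ is the probability of routing left at node $n$, the example numbers, and every function name are illustrative assumptions, not code from this repository; in the real model the decision logits $f_n(x)$ come from the base network.

```python
import math


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def routing_probabilities(decision_logits):
    """Probability of reaching each of the 4 leaves of a depth-2 tree.

    decision_logits[n] stands in for f_n(x); d_n = sigmoid(f_n(x)) is the
    probability of going left at node n (0 = root, 1 = left child, 2 = right child).
    """
    d = [sigmoid(z) for z in decision_logits]
    return [
        d[0] * d[1],                  # root -> left  -> left
        d[0] * (1.0 - d[1]),          # root -> left  -> right
        (1.0 - d[0]) * d[2],          # root -> right -> left
        (1.0 - d[0]) * (1.0 - d[2]),  # root -> right -> right
    ]


def predict(decision_logits, leaf_dists):
    """P(y|x) = sum_l mu_l(x) * pi_{l,y}: leaf distributions averaged with routing weights."""
    mu = routing_probabilities(decision_logits)
    n_classes = len(leaf_dists[0])
    return [sum(m * leaf[y] for m, leaf in zip(mu, leaf_dists)) for y in range(n_classes)]


def cross_entropy(probs, label: int) -> float:
    return -math.log(probs[label])


# Hypothetical 2-class example; the leaf distributions are static (not input dependent).
leaf_dists = [[0.9, 0.1], [0.6, 0.4], [0.3, 0.7], [0.05, 0.95]]
probs = predict([2.0, -1.0, 0.5], leaf_dists)
loss = cross_entropy(probs, label=1)
```

Because the four routing probabilities sum to one, `probs` is itself a valid distribution; with all decision logits at zero every leaf gets weight 0.25 and the prediction is the plain average of the leaf distributions.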

As there are two sets of trainable parameters, the authors used alternating optimization: they first fixed the terminal node probabilities and trained the base network (which produces the routing probabilities), then fixed the network and optimized the terminal nodes. Such alternating optimization is usually slower than joint optimization, since the variables held fixed at each stage slow down the optimization of the others.
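Schematically, writing $\theta$ for the base-network parameters, $\pi$ for the leaf distributions, $w$ for leaf parameters with $\pi = \mathrm{softmax}(w)$ as in the Formulation section, and $\eta$ for a step size (this notation is assumed here, not taken from the paper), the two schemes differ as follows. $\mathrm{update}_{\pi}$ is a placeholder for whichever leaf-update rule is applied while $\theta$ is held fixed.

```latex
\begin{aligned}
\text{alternating:}\quad & \theta^{(t+1)} = \theta^{(t)} - \eta\,\nabla_{\theta} L\big(\theta^{(t)}, \pi^{(t)}\big),
  \qquad \pi^{(t+1)} = \mathrm{update}_{\pi}\big(\theta^{(t+1)}, \pi^{(t)}\big) \\
\text{joint:}\quad & \big(\theta^{(t+1)}, w^{(t+1)}\big) = \big(\theta^{(t)}, w^{(t)}\big) - \eta\,\nabla_{(\theta, w)} L\big(\theta^{(t)}, w^{(t)}\big)
\end{aligned}
```

In the joint form a single gradient step moves both sets of parameters, which is only possible because $\pi$ is a differentiable function of $w$.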

However, if we parametrize the terminal nodes with a parametric probability distribution, we can train the terminal and decision nodes jointly and, in theory, speed up convergence.

This code is a proof of concept showing that

  1. one can train both the decision nodes and the leaf nodes $\pi$ jointly using a parametric formulation of the leaf (terminal) nodes, and

  2. such an idea is easy to implement in a symbolic math library.

Formulation

The leaf node probability $p \in \Delta^{n-1}$ can be parametrized by an $n$-dimensional vector $w_{leaf}$: for any $p$ in the interior of the simplex there exists $w_{leaf}$ such that $p = \mathrm{softmax}(w_{leaf})$. Thus we can also compute the gradient of the loss $L$ with respect to $w_{leaf}$ and optimize the terminal nodes jointly with the decision nodes.
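As a rough illustration, the sketch below trains a single depth-2 tree on one example by plain gradient descent, updating the decision logits and the leaf parameters in the same step. Writing $\mu_\ell$ for the probability of reaching leaf $\ell$ and $\pi_\ell = \mathrm{softmax}(w_\ell)$, the loss $L = -\log P(y \mid x)$ with $P(y \mid x) = \sum_\ell \mu_\ell \pi_{\ell,y}$ has leaf gradient $\partial L / \partial w_{\ell,k} = -\frac{\mu_\ell \pi_{\ell,y}}{P(y \mid x)}\left(\mathbb{1}[k = y] - \pi_{\ell,k}\right)$. The gradients are written out by hand only to keep the example dependency-free; in a symbolic library such as TensorFlow they would come from automatic differentiation. Treating the decision logits as free parameters instead of outputs of a base network, the tree shape, the learning rate, and all names are assumptions made for illustration.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def softmax(w):
    m = max(w)  # subtract the max for numerical stability
    e = [math.exp(v - m) for v in w]
    s = sum(e)
    return [v / s for v in e]


def loss_and_grads(f, W, y):
    """Cross-entropy L = -log P(y|x) for one depth-2 tree, plus dL/df and dL/dW.

    f: 3 decision logits (hypothetical stand-ins for the base network's f_n(x));
       d_n = sigmoid(f_n) is the probability of going left at node n.
    W: 4 rows of leaf parameters; leaf i predicts pi_i = softmax(W[i]).
    """
    d = [sigmoid(z) for z in f]
    mu = [d[0] * d[1], d[0] * (1 - d[1]), (1 - d[0]) * d[2], (1 - d[0]) * (1 - d[2])]
    pi = [softmax(row) for row in W]
    p = sum(mu[i] * pi[i][y] for i in range(4))
    loss = -math.log(p)

    # Leaf gradients: dL/dW[i][k] = -(mu_i * pi_{i,y} / p) * (1[k == y] - pi_{i,k})
    gW = [[-(mu[i] * pi[i][y] / p) * ((1.0 if k == y else 0.0) - pi[i][k])
           for k in range(len(W[i]))] for i in range(4)]

    # Decision gradients by the chain rule, starting from dL/dmu_i = -pi_{i,y} / p
    gmu = [-pi[i][y] / p for i in range(4)]
    gd = [
        gmu[0] * d[1] + gmu[1] * (1 - d[1]) - gmu[2] * d[2] - gmu[3] * (1 - d[2]),
        (gmu[0] - gmu[1]) * d[0],
        (gmu[2] - gmu[3]) * (1 - d[0]),
    ]
    gf = [gd[n] * d[n] * (1 - d[n]) for n in range(3)]  # sigmoid'(f) = d * (1 - d)
    return loss, gf, gW


def joint_step(f, W, y, lr=0.5):
    """One gradient-descent step that updates decision logits and leaf parameters together."""
    loss, gf, gW = loss_and_grads(f, W, y)
    new_f = [fi - lr * g for fi, g in zip(f, gf)]
    new_W = [[w - lr * g for w, g in zip(row, grow)] for row, grow in zip(W, gW)]
    return loss, new_f, new_W


# Hypothetical setup: 4 leaves, 2 classes, a single training example with label 1.
f = [0.0, 0.0, 0.0]
W = [[0.5, -0.5], [0.2, 0.0], [0.0, 0.2], [-0.5, 0.5]]
losses = []
for _ in range(100):
    loss, f, W = joint_step(f, W, y=1)
    losses.append(loss)
```

Every leaf gradient is scaled by $\mu_\ell$, so leaves that an example is unlikely to reach receive correspondingly small updates; in this toy run the loss recorded in `losses` goes down as both the routing and the leaf distributions shift toward label 1.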

Experiment

For this experiment I used a simple network (3 convolutional + 2 fully connected layers). On MNIST, it reaches 99.1% accuracy after 10 epochs.

Slides

SDL Reading Group Slides

Reference

[Kontschieder et al.] Deep Neural Decision Forests, ICCV 2015

More Repositories

  1. 3D-R2N2 (Python, 1,346 stars): Single/multi view image(s) to voxel reconstruction using a recurrent neural network
  2. FCGF (Python, 632 stars): Fully Convolutional Geometric Features: Fast and accurate 3D features for registration and correspondence.
  3. pytorch-custom-cuda-tutorial (Python, 512 stars): Tutorial for building a custom CUDA function for Pytorch
  4. DeepGlobalRegistration (Python, 467 stars): [CVPR 2020 Oral] A differentiable framework for 3D registration
  5. SpatioTemporalSegmentation (Python, 287 stars): 4D Spatio-Temporal Semantic Segmentation on a 3D video (a sequence of 3D scans)
  6. MakePytorchPlusPlus (Makefile, 170 stars): How and why you want to make your pytorch CUDA/CPP extension with a Makefile
  7. knn_cuda (Cuda, 141 stars): Fast K-Nearest Neighbor search with GPU
  8. open-ucn (Python, 87 stars): The first fully convolutional metric learning for geometric/semantic image correspondences.
  9. pytorch_knn_cuda (Cuda, 67 stars): K-Nearest Neighbor in Pytorch
  10. HighDimConvNets (Python, 39 stars): [CVPR 2020 Oral] High-dimensional Convolutional Networks for Geometric Pattern Recognition
  11. gesvd (C++, 28 stars): Pytorch extension for Singular Value Decomposition (SVD) with LAPACK gesvd backend
  12. SUN_RGBD (Shell, 25 stars): Reorganized SUN RGBD dataset
  13. SpatioTemporalSegmentation-ScanNet (Python, 22 stars)
  14. enriching_object_detection (C++, 21 stars)
  15. CUDA-FFT-Convolution (C++, 14 stars): CUDA FFT convolution
  16. segmentation_lecture (Python, 12 stars)
  17. python-venv-setup (Shell, 10 stars): Make python virtual environment setup on old servers less painful
  18. MinkowskiEngineBenchmark (Python, 7 stars)
  19. mini_lseg (Python, 5 stars)
  20. PybindNumpyExample (C++, 4 stars): A simple reference template for pybind11 + numpy
  21. env-setup (Shell, 3 stars): Setup my dev environment
  22. dotfiles (Vim Script, 2 stars): dot files
  23. torch_spmm (Cuda, 1 star)