Stars: 218 · Rank: 181,805 (top 4%) · Language: Python · License: BSD 3-Clause · Created and last updated about 6 years ago


Differentiable RANSAC: Learning Robust Line Fitting


Introduction

This code illustrates the principles of differentiable RANSAC (DSAC) on a simple toy problem of fitting lines to noisy, synthetic images.

Input and desired output.

Left: Input image. Right: Ground truth line.

We solve this task by training a CNN which predicts a set of 2D points within the image. We fit our desired line to these points using RANSAC.
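To make the fitting step concrete, here is a minimal least-squares line fit to a set of 2D points. This is a self-contained sketch with a hypothetical helper name (fit_line), not the repository's code; in RANSAC, fits like this are applied to subsets of the predicted points.

```python
# Minimal sketch (not the repository's code): fit a line y = slope*x + intercept
# to a set of 2D points by ordinary least squares.

def fit_line(points):
    """Least-squares line fit; points is a list of (x, y) tuples."""
    n = len(points)
    mx = sum(p[0] for p in points) / n          # mean of x
    my = sum(p[1] for p in points) / n          # mean of y
    sxx = sum((p[0] - mx) ** 2 for p in points)             # var term
    sxy = sum((p[0] - mx) * (p[1] - my) for p in points)    # covar term
    slope = sxy / sxx
    intercept = my - slope * mx
    return slope, intercept
```

For points lying exactly on y = 2x + 1, e.g. fit_line([(0, 1), (1, 3), (2, 5)]), the fit recovers slope 2 and intercept 1.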

DSAC line fitting.

Left: Input image. Center: Points predicted by a CNN. Right: Line (blue) fitted to the predictions.

Ideally, the CNN would place all its point predictions on the image line segment. But because RANSAC is robust to outlier points, the CNN may choose to allow some erroneous predictions in favor of overall accuracy.

We train the CNN end-to-end from scratch to minimize the deviation between the (robustly) fitted line and the ground truth line.

Running the Code

Execute python main.py to start a training run with the default settings. Running python main.py -h lists all available parameter options.

The code generates training data on the fly, and trains two CNNs in parallel. The first CNN predicts a set of 2D points to which the output line is fitted using DSAC. The second CNN is a baseline where the line parameters are predicted directly, i.e. without DSAC.

At a specified interval during training, both CNNs are evaluated on a fixed validation set, and a visualization of their predictions is stored in an image such as the following:

Training output.

Left: Validation inputs. Center: DSAC estimates with dots marking the CNN prediction, blue the fitted line and green the ground truth line. Green borders mark accurate line predictions, red boxes mark inaccurate line predictions (there is a threshold parameter). Right: Predictions of the baseline CNN (direct prediction of line parameters).

Dependencies

This code requires the following packages and was tested with the package versions given in parentheses.

pytorch (0.5.0), torchvision (0.2.1), scikit-image (0.14.0)

Training Speed

Depending on your system specification, one training iteration (with the default batch size of 32) can take more than one second. This might seem excessive for a simple toy problem, but the code is designed for educational clarity rather than speed. The whole DSAC portion of training runs in native Python on a single CPU core, and backpropagation relies solely on standard PyTorch autograd. In a production setting, one would write a C++/CUDA extension encapsulating DSAC for a large runtime boost. See, for example, our camera localization pipelines which utilize DSAC here and here.

How Does It Work?

Vanilla RANSAC works by creating a set of model hypotheses (line hypotheses in our case), scoring them e.g. by inlier counting, and selecting the best one.
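The vanilla variant can be sketched in a few lines of plain Python. Function and parameter names here (ransac_line, threshold) are illustrative, not the repository's:

```python
import random

# Illustrative vanilla RANSAC for line fitting: sample hypotheses from random
# point pairs, score each by hard inlier counting, keep the best.

def point_line_dist(p, slope, intercept):
    # perpendicular distance of point p = (x, y) to the line y = slope*x + intercept
    return abs(slope * p[0] - p[1] + intercept) / (slope ** 2 + 1) ** 0.5

def ransac_line(points, hypotheses=64, threshold=0.5, rng=random):
    best, best_inliers = None, -1
    for _ in range(hypotheses):
        (x1, y1), (x2, y2) = rng.sample(points, 2)  # minimal set: two points
        if x1 == x2:
            continue  # skip vertical lines for simplicity
        slope = (y2 - y1) / (x2 - x1)
        intercept = y1 - slope * x1
        inliers = sum(point_line_dist(p, slope, intercept) < threshold
                      for p in points)
        if inliers > best_inliers:
            best, best_inliers = (slope, intercept), inliers
    return best, best_inliers
```

With a majority of points on one line, the hypothesis drawn from two inliers wins, and gross outliers do not perturb the result.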

DSAC is based on the idea of making hypothesis selection a probabilistic action. The probability of selecting a hypothesis increases with its score (e.g. inlier count). Training the CNN aims at minimizing the expected loss of the selected hypothesis.
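A minimal numeric sketch of this idea, assuming a softmax over hypothesis scores (the repository's exact weighting may differ):

```python
import math

# Sketch of DSAC's probabilistic selection: hypothesis scores become a
# selection distribution via softmax, and training minimizes the loss
# expectation under that distribution.

def selection_probs(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]  # numerically stable softmax
    z = sum(exps)
    return [e / z for e in exps]

def expected_loss(losses, scores):
    return sum(p * l for p, l in zip(selection_probs(scores), losses))
```

Raising the score of a low-loss hypothesis shifts probability mass toward it and lowers the expected loss, which is exactly the signal the CNN is trained on.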

Training output.

More details and a formal description can be found in the papers referenced at the end of this document.

In a nutshell, the training process works like this:

  1. CNN predicts 2D points
  2. sample line hypotheses by choosing random pairs of points
  3. score hypotheses by soft inlier counting, and calculate selection probabilities
  4. refine hypotheses by re-fitting them to their soft inliers
  5. calculate expected loss of refined hypotheses w.r.t. selection probabilities
  6. backprop, update the CNN, repeat
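The steps above can be sketched as follows. This is a simplified mock-up in plain Python with illustrative parameter names (beta, threshold); the repository implements the same steps with PyTorch tensors so that gradients flow through scoring, refinement, and the expected loss.

```python
import math
import random

# Simplified sketch of one DSAC forward pass (steps 2-5 above).

def soft_inlier_score(points, slope, intercept, threshold=0.5, beta=5.0):
    # Step 3: soft inlier count, a sigmoid of the residual replaces the
    # hard threshold so the score is differentiable.
    score, weights = 0.0, []
    for x, y in points:
        dist = abs(slope * x - y + intercept) / math.sqrt(slope ** 2 + 1)
        w = 1.0 / (1.0 + math.exp(beta * (dist - threshold)))
        score += w
        weights.append(w)
    return score, weights

def refine(points, weights):
    # Step 4: weighted least-squares refit of y = slope*x + intercept
    # to the soft inliers.
    sw = sum(weights)
    mx = sum(w * x for (x, _), w in zip(points, weights)) / sw
    my = sum(w * y for (_, y), w in zip(points, weights)) / sw
    sxx = sum(w * (x - mx) ** 2 for (x, _), w in zip(points, weights))
    sxy = sum(w * (x - mx) * (y - my) for (x, y), w in zip(points, weights))
    slope = sxy / sxx
    return slope, my - slope * mx

def dsac_expected_loss(points, gt_line, loss_fn, n_hyps=32, rng=random):
    scores, losses = [], []
    for _ in range(n_hyps):
        (x1, y1), (x2, y2) = rng.sample(points, 2)   # step 2: sample hypothesis
        if x1 == x2:
            continue
        slope = (y2 - y1) / (x2 - x1)
        intercept = y1 - slope * x1
        score, w = soft_inlier_score(points, slope, intercept)
        scores.append(score)
        losses.append(loss_fn(refine(points, w), gt_line))
    # Step 5: expectation of the loss under the softmax selection distribution.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    z = sum(exps)
    return sum(e / z * l for e, l in zip(exps, losses))
```

When the predicted points lie exactly on the ground truth line, every hypothesis and its refinement coincide with that line, so the expected loss is (numerically) zero, which is the fixed point training drives the CNN toward.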

Loss Function

For this toy problem, we are interested in visually well-aligned lines rather than the nominal error in line parameters. We therefore measure the maximum distance between the predicted line and the ground truth line within the image, and minimize this distance as our loss function.

Loss function.

Red arrows mark the error between ground truth line (green) and estimated line (blue) that we try to minimize.
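Assuming lines parameterized as y = slope*x + intercept and measuring the gap vertically (a simplification; the repository's line_loss.py may measure it differently), the loss has a closed form: the gap is linear in x, so its maximum over the image width is attained at one of the two image borders.

```python
# Sketch of a maximum-distance loss between a predicted and a ground truth
# line inside an image of width w. Names are illustrative.

def max_line_distance(pred, gt, w):
    ds = pred[0] - gt[0]   # slope difference
    di = pred[1] - gt[1]   # intercept difference
    # |ds*x + di| is linear in x, so check the borders x = 0 and x = w.
    return max(abs(di), abs(ds * w + di))
```

For example, two parallel lines offset by 1 have loss 1 regardless of image width, while a slope error grows linearly with the width.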

Code Structure

main.py Main script that handles the training loop.

dsac.py Encapsulates robust, differentiable line fitting with DSAC (sampling hypotheses, scoring, refinement, expected loss).

line_dataset.py Generates random, noisy input images with associated ground truth parameters. Also includes functions for visualizing predictions.

line_loss.py Loss function used to compare predicted and ground truth lines.

line_nn.py Definition of the CNN architecture which supports prediction of 2D points or direct regression of line parameters.

Publications

The following paper introduced DSAC for camera localization (paper link).

@inproceedings{brachmann2017dsac,
  title={{DSAC}-{Differentiable RANSAC} for Camera Localization},
  author={Brachmann, Eric and Krull, Alexander and Nowozin, Sebastian and Shotton, Jamie and Michel, Frank and Gumhold, Stefan and Rother, Carsten},
  booktitle={CVPR},
  year={2017}
}

This code uses a soft inlier count instead of a learned scoring function, as suggested in the following paper (paper link).

@inproceedings{brachmann2018lessmore,
  title={Learning Less is More-{6D} Camera Localization via {3D} Surface Regression},
  author={Brachmann, Eric and Rother, Carsten},
  booktitle={CVPR},
  year={2018}
}

Please cite one of these papers if you use DSAC or parts of this code in your own work.
