iUNets - Fully invertible U-Nets in PyTorch
This library enables highly memory-efficient training of fully invertible U-Nets (iUNets) in 1D, 2D and 3D for use cases such as segmentation of medical images. It is based on the paper iUNets: Fully invertible U-Nets with Learnable Up- and Downsampling by Christian Etmann, Rihuan Ke & Carola-Bibiane Schönlieb.
The library can be installed via the following command:
pip install iunets
Official documentation: https://iunets.readthedocs.io.
The iUNet combines well-known reversible layers (such as additive coupling layers) with novel learnable invertible up- and downsampling operators and suitable channel splitting/concatenation, which makes the whole network bijective. Activations can therefore be reconstructed during the backward pass instead of being stored, so the memory demand is (in theory) independent of the number of layers.
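To illustrate why reversibility makes storing activations unnecessary, here is a minimal, generic sketch of an additive coupling layer (an illustration only, not the library's internal implementation): the layer's input can be recomputed exactly from its output.

import torch
import torch.nn as nn

class AdditiveCoupling(nn.Module):
    # Generic additive coupling: split the channels in half and add F(first half)
    # to the second half. The inverse subtracts F again, so the input can be
    # reconstructed from the output instead of being stored for backpropagation.
    def __init__(self, channels):
        super().__init__()
        self.F = nn.Conv2d(channels // 2, channels // 2, 3, padding=1)

    def forward(self, x):
        x1, x2 = x.chunk(2, dim=1)
        return torch.cat([x1, x2 + self.F(x1)], dim=1)

    def inverse(self, y):
        y1, y2 = y.chunk(2, dim=1)
        return torch.cat([y1, y2 - self.F(y1)], dim=1)

layer = AdditiveCoupling(64)
x = torch.randn(1, 64, 32, 32)
assert torch.allclose(layer.inverse(layer(x)), x, atol=1e-6)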
The following table exemplifies the memory savings that can be achieved by applying our memory-efficient gradient calculation, in contrast to the conventional backpropagation procedure. Details are found in the paper.
Depth | Conventional backprop | Ours (memory-efficient) | Ratio |
---|---|---|---|
5 | 3.17 GB | 0.85 GB | 26.8% |
10 | 5.90 GB | 1.09 GB | 18.4% |
20 | 11.4 GB | 1.57 GB | 13.8% |
30 | 16.8 GB | 2.06 GB | 12.2% |
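As an aside (our own illustration, not the measurement protocol behind the table above, which is described in the paper), the peak GPU memory of a training step can be inspected with PyTorch's built-in memory statistics. The helper name below is hypothetical:

import torch

# Hypothetical sketch: report the peak GPU memory allocated during one
# forward/backward pass. Requires a CUDA device.
def peak_memory_gb(model, x):
    torch.cuda.reset_peak_memory_stats()
    model(x).mean().backward()
    return torch.cuda.max_memory_allocated() / 1024**3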
If you use this code in a publication, please cite it as:
@inproceedings{etmann2020iunets,
title={iUNets: learnable invertible up- and downsampling for large-scale inverse problems},
author={Etmann, Christian and Ke, Rihuan and Sch{\"o}nlieb, Carola-Bibiane},
booktitle={2020 IEEE 30th International Workshop on Machine Learning for Signal Processing (MLSP)},
pages={1--6},
year={2020},
organization={IEEE}
}
Features
- Easily set up memory-efficient iUNets in 1D, 2D or 3D, which can be trained like any other PyTorch model and used e.g. for high-dimensional segmentation.
- Highly customizable.
- Learnable, invertible and possibly anisotropic up- and downsampling.
- Orthogonal channel mixing.
- Orthogonality of the channel mixing and of the learnable invertible up- and downsampling is enforced by efficient Lie group methods, i.e. Cayley transforms and matrix exponentials of skew-symmetric matrices (see the sketch after this list).
- Quality-of-life features such as automatic, dynamic padding and unpadding (if required for invertibility), as well as model summaries.
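As a minimal sketch of this parametrization idea (generic PyTorch, not the library's internal code), an unconstrained weight can be mapped to an orthogonal matrix via the matrix exponential or the Cayley transform of its skew-symmetric part:

import torch

W = torch.randn(8, 8, dtype=torch.float64)   # unconstrained (learnable) weight
S = W - W.T                                  # skew-symmetric part: S^T = -S
I = torch.eye(8, dtype=torch.float64)

Q_exp = torch.matrix_exp(S)                  # exp of a skew-symmetric matrix is orthogonal
Q_cay = torch.linalg.solve(I + S, I - S)     # Cayley transform: (I + S)^{-1} (I - S)

print(torch.allclose(Q_exp @ Q_exp.T, I))    # True
print(torch.allclose(Q_cay @ Q_cay.T, I))    # True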
Requirements
iUNets are powered by the following two libraries:
- PyTorch
- MemCNN (used for the memory-efficient invertible backpropagation)
Example usage
A simple 2D iUNet
An iUNet like the one depicted above can be created in just a few lines. Say we want 2 additive coupling layers per resolution, both in the downsampling branch (left) and in the upsampling branch (right).
from iunets import iUNet
model = iUNet(
    channels=(64, 128, 256, 384),  # number of channels at each resolution level
    architecture=(2, 2, 2, 2),     # modules per resolution level in each branch
    dim=2                          # 2D iUNet
)
model.print_layout()
Output:
64-64-(32/32)--------------------------------------------------------(32/32)-64-64
---------128-128-(64/64)---------------------------------(64/64)-128-128----------
---------------------256-256-(160/96)--------(160/96)-256-256---------------------
---------------------------------384-384--384-384---------------------------------
This model can now be integrated into the normal PyTorch workflow (and in particular be used as a sub-network) just like any other torch.nn.Module, and it automatically employs memory-efficient backpropagation.
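For example, a training step looks like that of any other module. A minimal sketch (the input must have channels[0] = 64 channels; spatial sizes that are not divisible by the total downsampling factor are handled by the automatic padding):

import torch

x = torch.randn(4, 64, 128, 128)  # (batch, channels, height, width)
y = model(x)                      # forward pass through the iUNet
loss = y.mean()                   # placeholder loss for illustration
loss.backward()                   # activations are reconstructed, not stored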
A fully customized 3D iUNet
While the above example shows that a basic iUNet requires very little code, the library also allows for a high degree of customization. Refer to the API documentation for more information.
from iunets import iUNet
from iunets.layers import create_standard_module
model = iUNet(
    channels=(7, 15, 35, 91),             # requested channels per resolution level
    dim=3,                                # 3D iUNet (e.g. for volumetric data)
    architecture=(2, 3, 1, 3),            # modules per resolution level in each branch
    create_module_fn=create_standard_module,
    module_kwargs={'depth': 3},           # keyword arguments passed to create_module_fn
    slice_mode='double',
    resampling_stride=[2, 2, (1, 2, 2)],  # per-level, possibly anisotropic strides
    learnable_resampling=True,            # learn the invertible up-/downsampling operators
    resampling_init='haar',
    resampling_method='cayley',           # enforce orthogonality via the Cayley transform
    disable_custom_gradient=False,        # keep memory-efficient backpropagation enabled
    revert_input_padding=True,            # unpad the output back to the input shape
    padding_mode='reflect',
    verbose=1
)
model.print_layout()
Output:
Could not exactly create an iUNet with channels=(7, 15, 35, 91) and
resampling_stride=[(2, 2, 2), (2, 2, 2), (1, 2, 2)]. Instead using closest
achievable configuration: channels=(7, 16, 32, 92).
Average relative error: 0.0508
7-7-(5/2)-------------------------------------------------(5/2)-7-7
------16-16-16-(12/4)------------------------(12/4)-16-16-16-------
------------------32-(9/23)------------(9/23)-32-------------------
------------------------92-92-92--92-92-92-------------------------