TTA wrapper
Test-time augmentation (TTA) wrapper for Keras image segmentation and classification models.
Description
How it works?
The wrapper adds augmentation layers to your Keras model like this:
Input
| # input image; shape 1, H, W, C
/ / / \ \ \ # duplicate image for augmentation; shape N, H, W, C
| | | | | | # apply augmentations (flips, rotation, shifts)
your Keras model
| | | | | | # reverse transformations
\ \ \ / / / # merge predictions (mean, max, gmean)
| # output mask; shape 1, H, W, C
Output
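The diagram above can be sketched in plain NumPy. This is only an illustration of the idea, not the wrapper's actual implementation: `tta_predict`, `TRANSFORMS` and `predict_fn` are hypothetical names, and `predict_fn` stands in for `model.predict`.

```python
import numpy as np

# Each entry pairs an augmentation with the transformation that undoes it.
# All functions act on batches shaped (N, H, W, C).
# Hypothetical table for illustration; the wrapper builds its own layers.
TRANSFORMS = [
    (lambda a: a, lambda a: a),                          # identity
    (lambda a: a[:, :, ::-1], lambda a: a[:, :, ::-1]),  # horizontal flip
    (lambda a: a[:, ::-1], lambda a: a[:, ::-1]),        # vertical flip
    (lambda a: np.rot90(a, 1, axes=(1, 2)),
     lambda a: np.rot90(a, -1, axes=(1, 2))),            # 90 degree rotation
]


def tta_predict(predict_fn, image, merge="mean"):
    """Run predict_fn on augmented copies of one image and merge the results.

    image has shape (1, H, W, C) with H == W (needed for the rotation).
    predict_fn maps a batch (N, H, W, C) to predictions (N, H, W, C').
    """
    # Duplicate the image and apply one augmentation per copy.
    batch = np.concatenate([fwd(image) for fwd, _ in TRANSFORMS], axis=0)

    # Single forward pass over all augmented copies.
    preds = predict_fn(batch)

    # Reverse each transformation so all predictions are aligned again.
    restored = np.concatenate(
        [inv(preds[i:i + 1]) for i, (_, inv) in enumerate(TRANSFORMS)],
        axis=0,
    )

    # Merge the aligned predictions into one output of shape (1, H, W, C').
    if merge == "mean":
        return restored.mean(axis=0, keepdims=True)
    if merge == "max":
        return restored.max(axis=0, keepdims=True)
    if merge == "gmean":
        # Geometric mean; assumes non-negative predictions.
        return np.prod(restored, axis=0, keepdims=True) ** (1.0 / len(restored))
    raise ValueError("merge must be 'mean', 'max' or 'gmean'")
```

For a model whose output does not depend on orientation, every reversed prediction is identical, so all three merge modes return the plain prediction. With a real model the copies differ slightly, and merging averages out that variation.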
Arguments
- h_flip - bool, horizontal flip augmentation
- v_flip - bool, vertical flip augmentation
- rotation - list, allowable angles - 90, 180, 270
- h_shift - list of int, horizontal shift augmentation in pixels
- v_shift - list of int, vertical shift augmentation in pixels
- add - list of int/float, additive factor (aug_image = image + factor)
- mul - list of int/float, multiplicative factor (aug_image = image * factor)
- contrast - list of int/float, contrast adjustment factor (aug_image = (image - mean) * factor + mean)
- merge - one of 'mean', 'gmean' and 'max' - mode of merging augmented predictions together
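The pixel-value formulas listed above are easy to check by hand. Here is a small NumPy sketch of them. The function names are made up for illustration. Two details are assumptions, since the list does not specify them: `mean` is taken over the whole image, and the shift wraps around the image border.

```python
import numpy as np


def add_aug(image, factor):
    # aug_image = image + factor
    return image + factor


def mul_aug(image, factor):
    # aug_image = image * factor
    return image * factor


def contrast_aug(image, factor):
    # aug_image = (image - mean) * factor + mean
    # Assumption: mean is computed over all pixels and channels of the image.
    mean = image.mean()
    return (image - mean) * factor + mean


def h_shift_aug(image, shift):
    # Assumption: circular shift along the width axis of a (N, H, W, C) batch.
    # A shift by -shift undoes it exactly.
    return np.roll(image, shift, axis=2)
```

Note that the contrast formula keeps the image mean unchanged and a factor of 1 leaves the image as it is.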
Constraints
- model has to have 1 input and 1 output
- inference batch_size == 1
- image height == width if rotation augmentation is used
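A quick way to catch violations of the input constraints before calling the wrapped model is a small check like the one below. `check_tta_input` is a hypothetical helper, not part of the package. The square-image rule follows from the fact that a 90 degree rotation swaps height and width, so a rotated non-square copy cannot be stacked with the original.

```python
import numpy as np


def check_tta_input(batch, use_rotation):
    """Raise ValueError if batch breaks the constraints listed above.

    Hypothetical helper for illustration; batch is expected as (N, H, W, C).
    """
    if batch.ndim != 4:
        raise ValueError("expected a 4D batch shaped (N, H, W, C)")
    n, h, w, _ = batch.shape
    if n != 1:
        raise ValueError("batch_size must be 1, got %d" % n)
    if use_rotation and h != w:
        raise ValueError("rotation needs height == width, got %dx%d" % (h, w))


# A 90 degree rotation swaps the spatial axes of a non-square image.
rotated_shape = np.rot90(np.zeros((1, 2, 3, 1)), 1, axes=(1, 2)).shape
```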
Installation
- PyPI package:
$ pip install tta-wrapper
- Latest version:
$ pip install git+https://github.com/qubvel/tta_wrapper/
Example
from keras.models import load_model
from tta_wrapper import tta_segmentation
model = load_model('path/to/model.h5')
tta_model = tta_segmentation(model, h_flip=True, rotation=(90, 270),
                             h_shift=(-5, 5), merge='mean')
# x: a single input image, shape (1, H, W, C); H == W because rotation is used
y = tta_model.predict(x)