V-GAN in TensorFlow
This repository is a TensorFlow implementation of Retinal Vessel Segmentation in Fundoscopic Images with Generative Adversarial Networks. The referenced Keras code can be found here.
Improvements Compared to the Keras Code
- Data augmentation is changed from an offline to an online process; this resolves the memory limitation but slows down training
- Added a train_interval flag to control the number of training iterations between the generator and discriminator; for a standard GAN, train_interval is 1
- The best model is saved based on the sum of AUC_PR and AUC_ROC on validation data
- Added a sampling function to inspect generated results during training
- Measurements are plotted on TensorBoard during training
- The code is structured more modularly
Area Under the Curve (AUC), Precision and Recall (PR), Receiver Operating Characteristic (ROC)
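The model-selection criterion above (sum of AUC_ROC and AUC_PR on validation data) can be sketched in plain Python. The function and variable names here are illustrative, not taken from the repository; in practice scikit-learn's roc_auc_score and average_precision_score compute the same quantities.

```python
def auc_roc(y_true, y_prob):
    """ROC AUC via the Mann-Whitney U statistic (quadratic, for illustration only)."""
    pos = [p for p, t in zip(y_prob, y_true) if t == 1]
    neg = [p for p, t in zip(y_prob, y_true) if t == 0]
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))

def auc_pr(y_true, y_prob):
    """Average precision: mean of precision at each true positive, ranked by score."""
    ranked = sorted(zip(y_prob, y_true), reverse=True)
    tp, ap = 0, 0.0
    for i, (_, t) in enumerate(ranked, start=1):
        if t == 1:
            tp += 1
            ap += tp / i
    return ap / sum(y_true)

def model_score(y_true, y_prob):
    """Best-model criterion: AUC_ROC + AUC_PR on the validation pixels."""
    return auc_roc(y_true, y_prob) + auc_pr(y_true, y_prob)

# A perfect ranking of vessel pixels yields the maximum score of 2.0.
print(model_score([0, 0, 1, 1], [0.1, 0.2, 0.8, 0.9]))  # 2.0
```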
Package Dependencies
- tensorflow 1.6.0
- python 3.5.3
- numpy 1.14.2
- matplotlib 2.0.2
- pillow 5.0.0
- scikit-image 0.13.0
- scikit-learn 0.19.0
- scipy 0.19.0
Download Data
The original data file structure was modified for convenience by Jaemin Son.
Download the data from here and place the data folder in the same directory as the codes folder, following the Directory Hierarchy below.
Directory Hierarchy
.
├── codes
│   ├── dataset.py
│   ├── evaluation.py
│   ├── main.py
│   ├── model.py
│   ├── solver.py
│   ├── TensorFlow_utils.py
│   └── utils.py
├── data
│   ├── DRIVE
│   └── STARE
├── evaluation (get after running evaluation.py)
│   ├── DRIVE
│   └── STARE
└── results
    ├── DRIVE
    └── STARE
codes: source codes
data: original data. The file hierarchy is modified for convenience.
evaluation: quantitative and qualitative evaluation results (generated after running evaluation.py)
results: results of other methods. These image files are retrieved from here.
Training
Move to the codes folder and run main.py:
python main.py --train_interval=<int> --ratio_gan2seg=<int> --gpu_index=<int> --discriminator=[pixel|patch1|patch2|image] --batch_size=<int> --dataset=[DRIVE|STARE] --is_test=False
- Models will be saved in the './codes/{}/model_{}_{}_{}'.format(dataset, discriminator, train_interval, batch_size) folder, e.g., './codes/STARE/model_image_100_1'.
- Sampled images will be saved in the './codes/{}/sample_{}_{}_{}'.format(dataset, discriminator, train_interval, batch_size) folder, e.g., './codes/STARE/sample_image_100_1'.
Arguments
train_interval: training interval between discriminator and generator, default: 1
ratio_gan2seg: ratio of gan loss to seg loss, default: 10
gpu_index: gpu index, default: 0
discriminator: type of discriminator [pixel|patch1|patch2|image], default: image
batch_size: batch size, default: 1
dataset: dataset name [DRIVE|STARE], default: STARE
is_test: run in test mode if True, train mode if False, default: False
learning_rate: initial learning rate for Adam, default: 2e-4
beta1: momentum term of Adam, default: 0.5
iters: number of iterations, default: 50000
print_freq: print loss information frequency, default: 100
eval_freq: evaluation frequency on validation data, default: 500
sample_freq: sample generated image frequency, default: 200
checkpoint_dir: models are saved here, default: './checkpoints'
sample_dir: sampled images are saved here, default: './sample'
test_dir: test images are saved here, default: './test'
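One plausible reading of the train_interval flag is a schedule that gives each network a block of consecutive updates before handing over; with train_interval=1 this reduces to the standard GAN alternation. The exact logic lives in solver.py and may differ; this sketch only illustrates the idea.

```python
def plan_updates(iters, train_interval):
    """Return which network ('D' or 'G') trains at each iteration.

    With train_interval=1 the discriminator and generator take turns every
    iteration (a normal GAN); with a larger interval each network trains
    for that many consecutive iterations before switching. Illustrative
    only -- not the exact scheduling code from solver.py.
    """
    schedule = []
    for step in range(iters):
        phase = (step // train_interval) % 2
        schedule.append('D' if phase == 0 else 'G')
    return schedule

print(plan_updates(6, 1))  # ['D', 'G', 'D', 'G', 'D', 'G']
print(plan_updates(6, 3))  # ['D', 'D', 'D', 'G', 'G', 'G']
```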
Test
python main.py --is_test=True --discriminator=[pixel|patch1|patch2|image] --batch_size=<int> --dataset=[DRIVE|STARE]
- Outputs of inference are generated in the 'seg_result_{}_{}_{}'.format(discriminator, train_interval, batch_size) folder, e.g., './codes/STARE/seg_result_image_100_1'.
- Make sure a model has already been trained with the specified dataset, discriminator, training interval, and batch size.
Evaluation
Note: Copy predicted vessel images to the ./results/[DRIVE|STARE]/V-GAN folder
python evaluation.py
Results are generated in the evaluation folder. The hierarchy of the folder is:
.
├── DRIVE
│   ├── comparison
│   ├── measures
│   └── vessels
└── STARE
    ├── comparison
    ├── measures
    └── vessels
comparison: difference maps between V-GAN and gold standard
measures: AUC_ROC and AUC_PR curves
vessels: vessels superimposed on segmented masks
Area Under the Curve (AUC), Precision and Recall (PR), Receiver Operating Characteristic (ROC)
DRIVE Results
STARE Results
Difference Maps
DRIVE (top), STARE (bottom)
Green marks correct segmentation, while blue and red indicate false positives and false negatives, respectively
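The color coding of a difference map can be sketched with NumPy as follows; the channel layout and exact shades are illustrative, and evaluation.py may implement this differently.

```python
import numpy as np

def difference_map(pred, gt):
    """Color-code a binary vessel prediction against the gold standard.

    Green = true positive (correct segmentation), blue = false positive,
    red = false negative. Shades and channel order (RGB) are illustrative,
    not taken from evaluation.py.
    """
    h, w = gt.shape
    rgb = np.zeros((h, w, 3), dtype=np.uint8)
    rgb[(pred == 1) & (gt == 1)] = [0, 255, 0]   # green: correct vessel pixel
    rgb[(pred == 1) & (gt == 0)] = [0, 0, 255]   # blue: false positive
    rgb[(pred == 0) & (gt == 1)] = [255, 0, 0]   # red: false negative
    return rgb
```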
DRIVE Dataset
train_interval | Model | AUC_ROC | AUC_PR | Dice_coeff |
---|---|---|---|---|
1 | Pixel GAN | 0.9049 | 0.8033 | 0.3020 |
1 | Patch GAN-1 (10x10) | 0.9487 | 0.8431 | 0.7469 |
1 | Patch GAN-2 (80x80) | 0.9408 | 0.8257 | 0.7478 |
1 | Image GAN | 0.9280 | 0.8241 | 0.7839 |
100 | Pixel GAN | 0.9298 | 0.8228 | 0.7766 |
100 | Patch GAN-1 (10x10) | 0.9263 | 0.8159 | 0.7319 |
100 | Patch GAN-2 (80x80) | 0.9312 | 0.8373 | 0.7520 |
100 | Image GAN | 0.9210 | 0.7883 | 0.7168 |
10000 | Pixel GAN | 0.9353 | 0.8692 | 0.7928 |
10000 | Patch GAN-1 (10x10) | 0.9445 | 0.8680 | 0.7938 |
10000 | Patch GAN-2 (80x80) | 0.9525 | 0.8752 | 0.7957 |
10000 | Image GAN | 0.9509 | 0.8537 | 0.7546 |
STARE Dataset
train_interval | Model | AUC_ROC | AUC_PR | Dice_coeff |
---|---|---|---|---|
1 | Pixel GAN | 0.9368 | 0.8354 | 0.8063 |
1 | Patch GAN-1 (10x10) | 0.9119 | 0.7199 | 0.6607 |
1 | Patch GAN-2 (80x80) | 0.9053 | 0.7998 | 0.7902 |
1 | Image GAN | 0.9074 | 0.7452 | 0.7198 |
100 | Pixel GAN | 0.8874 | 0.7056 | 0.6616 |
100 | Patch GAN-1 (10x10) | 0.8787 | 0.6858 | 0.6432 |
100 | Patch GAN-2 (80x80) | 0.9306 | 0.8066 | 0.7321 |
100 | Image GAN | 0.9099 | 0.7785 | 0.7117 |
10000 | Pixel GAN | 0.9317 | 0.8255 | 0.8107 |
10000 | Patch GAN-1 (10x10) | 0.9318 | 0.8378 | 0.8087 |
10000 | Patch GAN-2 (80x80) | 0.9604 | 0.8600 | 0.7867 |
10000 | Image GAN | 0.9283 | 0.8395 | 0.8001 |
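Dice_coeff in the tables above is the Dice coefficient between the binarized prediction and the gold standard. A minimal NumPy sketch (not the repository's implementation):

```python
import numpy as np

def dice_coefficient(pred, gt, eps=1e-7):
    """Dice coefficient between two binary masks: 2*|P & G| / (|P| + |G|).

    Illustrative helper, not taken from evaluation.py; eps avoids division
    by zero when both masks are empty.
    """
    pred, gt = pred.astype(bool), gt.astype(bool)
    inter = np.logical_and(pred, gt).sum()
    return 2.0 * inter / (pred.sum() + gt.sum() + eps)
```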
Note:
- Setting a higher training interval between the generator and discriminator can boost performance slightly, as the paper mentions. However, the mathematical theory behind this experimental result is not clear.
- The performance of this TensorFlow implementation of V-GAN falls short of the paper's results. Incomplete fine-tuning and subtle implementation differences may be the reasons.
Architectures
- Generator:
- Discriminator(Pixel):
- Discriminator(Patch-1):
- Discriminator(Patch-2):
- Discriminator(Image):
Tensorboard
AUC_ROC, AUC_PR, Dice_Coefficient, Accuracy, Sensitivity, and Specificity on the validation dataset during training iterations
- AUC_ROC:
- AUC_PR:
- Dice_Coefficient:
- Accuracy:
- Sensitivity:
- Specificity: