  • Stars: 281
  • Rank: 143,979 (Top 3 %)
  • Language: Python
  • Created over 7 years ago
  • Updated over 7 years ago

Repository Details

An implementation of SRGAN model in Keras

Super Resolution using Generative Adversarial Networks

This is a Keras implementation of the SRGAN model proposed in the paper Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network. Note that this project is a work in progress.

A simplified view of the model can be seen as below:

Implementation Details

The SRGAN model is built in stages within models.py. Initially, only the SR-ResNet model is created, and the VGG network is then appended to it to form the pre-training model. The VGG weights are frozen, as they are never updated.

In the pre-train mode:

  1. The discriminator model is not attached to the network, so only the SR + VGG model is pretrained first.

  2. During pretraining, the model is trained with the VGG perceptual loss (via ContentVGGRegularizer) and the total variation loss (via TVRegularizer). No other loss (MSE, binary cross entropy, discriminator) is applied.

  3. The content regularizer loss is applied to the VGG Convolution 2-2 layer.

  4. After pretraining the SR + VGG model, the discriminator model is pretrained.

  5. During discriminator pretraining, the model is Generator + Discriminator. Only binary cross entropy loss is used to train the discriminator network.
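As a rough illustration of the total variation term used during pretraining, the sketch below sums the squared differences between neighbouring pixels of a single-channel image. It is plain Python on nested lists, the function name total_variation_loss is hypothetical, and the repository's TVRegularizer works on Keras tensors and may use a different exponent or scaling.

```python
def total_variation_loss(image):
    """Sum of squared differences between neighbouring pixels.

    `image` is a 2-D list of floats (one channel). A flat image scores 0,
    while noisy or grid-like images score higher.
    """
    height, width = len(image), len(image[0])
    loss = 0.0
    for y in range(height):
        for x in range(width):
            if y + 1 < height:  # vertical neighbour
                loss += (image[y + 1][x] - image[y][x]) ** 2
            if x + 1 < width:  # horizontal neighbour
                loss += (image[y][x + 1] - image[y][x]) ** 2
    return loss


flat = [[0.5, 0.5], [0.5, 0.5]]
checker = [[0.0, 1.0], [1.0, 0.0]]
print(total_variation_loss(flat))     # 0.0
print(total_variation_loss(checker))  # 4.0
```

Minimising this term pushes the generator towards smoother outputs, which is why it is paired with the content loss rather than used alone.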

In the full train mode:

  1. The discriminator model is attached to the entire network, which creates the SR + GAN + VGG model (SRGAN).
  2. The discriminator loss is added on top of the VGG content loss and the TV loss.
  3. The content regularizer loss is applied to the VGG Convolution 5-3 layer. (VGG 16 is used instead of VGG 19 for now.)
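The stages described above can be summarised as a small lookup table. This is only a reading aid: the stage keys, loss labels and the describe_stage helper are hypothetical names chosen for this sketch, not identifiers from models.py.

```python
# Which networks are connected and which losses are active in each stage,
# as described in the two lists above. All keys and labels are illustrative.
STAGES = {
    "pretrain_generator": {
        "networks": ("SR-ResNet", "VGG"),
        "losses": ("vgg_content", "total_variation"),
        "content_layer": "conv2_2",
    },
    "pretrain_discriminator": {
        "networks": ("Generator", "Discriminator"),
        "losses": ("binary_crossentropy",),
        "content_layer": None,
    },
    "full": {
        "networks": ("SR-ResNet", "Discriminator", "VGG"),
        "losses": ("vgg_content", "total_variation", "discriminator"),
        "content_layer": "conv5_3",
    },
}


def describe_stage(name):
    """Return a one-line, human-readable summary of a training stage."""
    stage = STAGES[name]
    layer = stage["content_layer"] or "none"
    return "{}: {} | losses: {} | content layer: {}".format(
        name,
        " + ".join(stage["networks"]),
        ", ".join(stage["losses"]),
        layer,
    )


for stage_name in STAGES:
    print(describe_stage(stage_name))
```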

Usage

Currently, models.py contains most of the code to train and create the models. To use different modes, uncomment the parts of the code that you need.

Note the difference between the *_network objects and *_model objects.

  • The *_network objects refer to the helper classes which create and manage the Keras models, load and save weights and set whether the model can be trained or not.
  • The *_model objects refer to the underlying Keras model.

Note: The training images must be stored in a subdirectory. If the images are at /path-to-dir/path-to-sub-dir/*.png, set coco_path = /path-to-dir. If this does not work, try adding a trailing slash: coco_path = /path-to-dir/
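To check that a directory follows this layout before starting a long run, a small helper along the following lines may be useful. It is a sketch using only the standard library; list_training_images is a hypothetical name and this is not how the repository itself loads images.

```python
import tempfile
from pathlib import Path


def list_training_images(coco_path, pattern="*.png"):
    """Return sorted image paths found exactly one directory below coco_path."""
    root = Path(coco_path)
    return sorted(p for p in root.glob("*/" + pattern) if p.is_file())


# Small demonstration on a throwaway directory tree.
with tempfile.TemporaryDirectory() as tmp:
    sub_dir = Path(tmp) / "path-to-sub-dir"
    sub_dir.mkdir()
    (sub_dir / "a.png").touch()
    (Path(tmp) / "ignored.png").touch()  # not inside a subdirectory
    print([p.name for p in list_training_images(tmp)])  # ['a.png']
```

An empty result suggests that coco_path points at the image folder itself rather than at its parent.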

To just create the pretrain model:

srgan_network = SRGANNetwork(img_width=32, img_height=32, batch_size=1)
srgan_model = srgan_network.build_srgan_pretrain_model()

# Plot the model
from keras.utils.visualize_util import plot
plot(srgan_model, to_file='SRGAN.png', show_shapes=True)

To pretrain the SR network:

srgan_network = SRGANNetwork(img_width=32, img_height=32, batch_size=1)
# coco_path is the parent directory of the image subdirectory (see the note above)
srgan_network.pre_train_srgan(coco_path, nb_epochs=1, nb_images=50000)

NOTE: Some generator initializations lead to completely solid validation images. Check the validation images from the first few iterations to make sure they are not solid.

To counteract this, a pretrained generator model has been provided, from which you can restart training. Therefore the model can continue learning without hitting a bad initialization.
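A quick automated check for this failure mode could look like the sketch below. It treats an image as solid when the spread between its smallest and largest pixel value is negligible; the function name is_solid_image and the tolerance are assumptions for illustration and are not part of the repository.

```python
def is_solid_image(image, tolerance=1e-3):
    """Return True when all pixel values lie within `tolerance` of each other.

    `image` is a 2-D list of numbers (one channel). For RGB images, run the
    check per channel or on the flattened values.
    """
    values = [value for row in image for value in row]
    return max(values) - min(values) <= tolerance


solid = [[0.5, 0.5], [0.5, 0.5]]
textured = [[0.1, 0.9], [0.4, 0.6]]
print(is_solid_image(solid))     # True
print(is_solid_image(textured))  # False
```

If the first few validation outputs are flagged, restarting from the provided pretrained generator is likely cheaper than waiting for training to recover.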

To pretrain the Discriminator network:

srgan_network = SRGANNetwork(img_width=32, img_height=32, batch_size=1)
# coco_path is the parent directory of the image subdirectory (see the note above)
srgan_network.pre_train_discriminator(coco_path, nb_epochs=1, nb_images=50000, batchsize=16)

To train the full network (does NOT work properly right now; the discriminator is not trained correctly):

srgan_network = SRGANNetwork(img_width=32, img_height=32, batch_size=1)
srgan_network.train_full_model(coco_path, nb_images=80000, nb_epochs=10)

Benchmarks

Validation is currently supported against the Set5, Set14 and BSD100 datasets. Each of the 3 datasets has a download script named download_*.py, which must be run before running benchmark_test.py.

Current scores (affected by the RGB grid artifacts and blurred restoration):

SR ResNet:

  • Set5 : Average PSNR of Set5 validation images : 22.1211430348
  • Set14 : Average PSNR of Set14 validation images : 20.3971611357
  • BSD100 : Average PSNR of BSD100 validation images : 20.9544390316
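For reference, PSNR is defined as 10 * log10(MAX^2 / MSE), where MAX is the peak pixel value and MSE is the mean squared error between the reference and the restored image. The sketch below computes it for single-channel images stored as nested lists. Whether benchmark_test.py evaluates RGB or luminance values, and which peak value it uses, is not stated here, so max_value=255.0 is an assumption.

```python
import math


def psnr(reference, restored, max_value=255.0):
    """Peak signal-to-noise ratio in dB between two equally sized 2-D images."""
    squared_error = 0.0
    count = 0
    for ref_row, out_row in zip(reference, restored):
        for ref, out in zip(ref_row, out_row):
            squared_error += (ref - out) ** 2
            count += 1
    mse = squared_error / count
    if mse == 0:
        return math.inf  # identical images
    return 10.0 * math.log10(max_value ** 2 / mse)


original = [[100.0, 100.0], [100.0, 100.0]]
shifted = [[125.5, 125.5], [125.5, 125.5]]  # every pixel off by 25.5
print(round(psnr(original, shifted), 6))  # 20.0
```

Higher values mean the restored image is closer to the reference.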

Drawbacks:

  • Keras has internal checks for batch size, so we have to bypass an internal Keras check called check_array_length(), which compares the input and output batch sizes. Because we provide the original images to Input 2, the batch size doubles, which triggers an assertion error inside Keras. For now, we rewrite the Keras fit logic in keras_training_ops.py and use the bypass fit functions.
  • For some reason, the Deconvolution networks are not learning the upscaling function properly, which causes grids to form throughout the upscaled image. This is possibly due to the large (4x) upscaling factor, although the Twitter team was able to make it work.

Plans

The codebase is currently very chaotic, since I am focusing on correct implementation before making the project better. Therefore, expect the code to drastically change over commits.

Some things I am currently trying out:

  • Training the discriminator model properly on its own.
  • Training the discriminator using soft labels and adversarial loss.
  • Properly training the SRGAN (SR ResNet + VGG + Discriminator) model.
  • Fixing the pixel grid formation when upscaling the image (with nearest neighbour upscaling).
  • Replacing the 2 deconv layers with nearest neighbour upsampling layers.
  • Improving docs & instructions.
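To make the nearest neighbour idea concrete, the sketch below repeats every pixel along both axes, so no upscaling weights have to be learned at that step. It is plain Python on nested lists and the name upsample_nearest is hypothetical. In Keras the same resize is available as the UpSampling2D layer, which would normally be followed by an ordinary convolution.

```python
def upsample_nearest(image, scale=2):
    """Repeat every pixel `scale` times along both axes of a 2-D list."""
    upscaled = []
    for row in image:
        # Widen the row first, then emit `scale` independent copies of it.
        wide_row = [value for value in row for _ in range(scale)]
        upscaled.extend(list(wide_row) for _ in range(scale))
    return upscaled


small = [[1, 2], [3, 4]]
for row in upsample_nearest(small):
    print(row)
# [1, 1, 2, 2]
# [1, 1, 2, 2]
# [3, 3, 4, 4]
# [3, 3, 4, 4]

# Two 2x steps give the 4x factor mentioned under Drawbacks.
four_times = upsample_nearest(upsample_nearest(small))
print(len(four_times), len(four_times[0]))  # 8 8
```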

Discussion

There is an ongoing discussion at keras-team/keras#3940 where I detail some of the outputs and attempts to correct the errors.

Requirements

  • Theano (master branch)
  • Keras 1.2.0 +
