• Language: Python
• License: MIT License
• Created almost 7 years ago
• Updated over 5 years ago


Implementation of triplet loss in TensorFlow

Triplet loss in TensorFlow

Author: Olivier Moindrot

This repository contains a triplet loss implementation in TensorFlow with online triplet mining. Please check the blog post for a full description.

The code structure is adapted from code I wrote for CS230 in this repository at tensorflow/vision. A set of tutorials for this code can be found here.

Requirements

We recommend using Python 3 and a virtual environment, created either with the built-in venv module or with virtualenv for Python 3:

python3 -m venv .env
source .env/bin/activate
pip install -r requirements_cpu.txt

If you are using a GPU, you will need to install tensorflow-gpu instead:

pip install -r requirements_gpu.txt

Triplet loss

Triplet loss on two positive faces (Obama) and one negative face (Macron)

The interesting part, the definition of triplet loss with triplet mining, can be found in model/triplet_loss.py.

Everything is explained in the blog post.
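Both mining strategies are built on two ingredients: the loss of a single triplet, max(d(a, p) - d(a, n) + margin, 0), and a matrix of distances between every pair of embeddings in the batch. The NumPy sketch below illustrates both; it is not the TensorFlow code from model/triplet_loss.py, and the function names are hypothetical. The squared flag mirrors the squared argument used in the calls below.

```python
import numpy as np


def pairwise_distances(embeddings, squared=False):
    """Return the (B, B) matrix of Euclidean distances between rows of `embeddings`."""
    embeddings = np.asarray(embeddings, dtype=float)
    # ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2, computed for all pairs at once
    dot = embeddings @ embeddings.T
    sq_norms = np.diag(dot)
    dist = sq_norms[:, None] - 2.0 * dot + sq_norms[None, :]
    # Floating-point error can make some entries slightly negative
    dist = np.maximum(dist, 0.0)
    if not squared:
        dist = np.sqrt(dist)
    return dist


def triplet_loss(anchor, positive, negative, margin):
    """Hinge loss of one triplet: max(d(a, p) - d(a, n) + margin, 0)."""
    d_ap = np.linalg.norm(np.asarray(anchor) - np.asarray(positive))
    d_an = np.linalg.norm(np.asarray(anchor) - np.asarray(negative))
    return max(d_ap - d_an + margin, 0.0)
```

A TensorFlow version that needs gradients also has to guard the square root at zero distance, whose gradient is infinite; this NumPy sketch has no gradients and skips that concern.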

To use the "batch all" version, you can do:

from model.triplet_loss import batch_all_triplet_loss

loss, fraction_positive = batch_all_triplet_loss(labels, embeddings, margin, squared=False)

In this case, fraction_positive is useful to plot in TensorBoard: it tracks the fraction of valid triplets that have a positive loss, i.e. the hard and semi-hard triplets.
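To make the "batch all" computation concrete, here is a NumPy sketch (the name batch_all_triplet_loss_np is hypothetical; this is not the repository's TensorFlow code). It builds a (B, B, B) tensor of d(a, p) - d(a, n) + margin, keeps only valid triplets (a and p are distinct samples with the same label, n has a different label), and averages the hinge loss over the triplets whose loss is positive. Averaging over positive triplets rather than over all valid ones is an assumption of this sketch; it prevents the many easy triplets (loss 0) from diluting the average.

```python
import numpy as np


def batch_all_triplet_loss_np(labels, embeddings, margin, squared=False):
    """Average triplet loss over all valid triplets that have a positive loss.

    Returns (mean_loss, fraction_positive).
    """
    labels = np.asarray(labels)
    embeddings = np.asarray(embeddings, dtype=float)

    # Pairwise distances, shape (B, B): dist[i, j] = ||e_i - e_j||
    diff = embeddings[:, None, :] - embeddings[None, :, :]
    dist = np.sum(diff ** 2, axis=-1)
    if not squared:
        dist = np.sqrt(dist)

    # Raw loss for every index triple: loss[a, p, n] = d(a, p) - d(a, n) + margin
    loss = dist[:, :, None] - dist[:, None, :] + margin

    # Valid triplets: a != p, label(a) == label(p), label(a) != label(n).
    # label(a) != label(n) already implies n differs from both a and p.
    same = labels[:, None] == labels[None, :]
    not_self = ~np.eye(len(labels), dtype=bool)
    valid = (same & not_self)[:, :, None] & ~same[:, None, :]

    # Zero out invalid triplets, then apply the hinge max(., 0)
    loss = np.maximum(np.where(valid, loss, 0.0), 0.0)

    num_positive = np.sum(loss > 1e-16)
    num_valid = np.sum(valid)
    fraction_positive = num_positive / (num_valid + 1e-16)
    mean_loss = np.sum(loss) / (num_positive + 1e-16)
    return mean_loss, fraction_positive
```

The (B, B, B) tensor grows cubically with the batch size, which is fine for typical batch sizes but worth keeping in mind.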

To use the "batch hard" version, you can do:

from model.triplet_loss import batch_hard_triplet_loss

loss = batch_hard_triplet_loss(labels, embeddings, margin, squared=False)
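"Batch hard" keeps a single triplet per anchor: its farthest positive and its closest negative within the batch. The NumPy sketch below shows that selection (batch_hard_triplet_loss_np is a hypothetical name, not the repository's function); filling invalid negatives with +inf is a choice made here, and the TensorFlow code may mask them differently.

```python
import numpy as np


def batch_hard_triplet_loss_np(labels, embeddings, margin, squared=False):
    """For each anchor, keep only its hardest positive and hardest negative, then average."""
    labels = np.asarray(labels)
    embeddings = np.asarray(embeddings, dtype=float)

    # Pairwise distances, shape (B, B)
    diff = embeddings[:, None, :] - embeddings[None, :, :]
    dist = np.sum(diff ** 2, axis=-1)
    if not squared:
        dist = np.sqrt(dist)

    same = labels[:, None] == labels[None, :]
    not_self = ~np.eye(len(labels), dtype=bool)

    # Hardest positive: the farthest sample with the same label (excluding the anchor).
    # Distances are >= 0, so filling invalid entries with 0 cannot change the max.
    hardest_positive = np.max(np.where(same & not_self, dist, 0.0), axis=1)

    # Hardest negative: the closest sample with a different label.
    # Invalid entries are filled with +inf so they never win the min.
    hardest_negative = np.min(np.where(~same, dist, np.inf), axis=1)

    # Hinge on the hardest triplet of each anchor, averaged over the batch
    return np.mean(np.maximum(hardest_positive - hardest_negative + margin, 0.0))
```

Because it uses only B triplets per batch, this variant is cheaper than "batch all", at the cost of focusing entirely on the most difficult examples.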

Training on MNIST

To run a new experiment called base_model, do:

python train.py --model_dir experiments/base_model

You will first need to create a configuration file named params.json (the repository includes an example). This JSON file specifies all the hyperparameters for the model. All weights and summaries will be saved in model_dir.
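Reading such a configuration file takes only a few lines of standard-library Python. The helper below is a hypothetical sketch, assuming (as the --model_dir flag suggests) that params.json sits inside the model directory; the keys in the usage comment (margin, squared) are illustrative, chosen to match the loss-function arguments above, and are not a copy of the repository's params.json.

```python
import json
import os


def load_params(model_dir):
    """Load hyperparameters from <model_dir>/params.json (hypothetical helper)."""
    json_path = os.path.join(model_dir, "params.json")
    if not os.path.isfile(json_path):
        raise FileNotFoundError(f"No params.json found in {model_dir}")
    with open(json_path) as f:
        return json.load(f)


# Usage (illustrative keys only):
# params = load_params("experiments/base_model")
# margin = params["margin"]
```

Failing early with a clear error when the file is missing is friendlier than letting training start with default values.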

Once trained, you can visualize the embeddings by running:

python visualize_embeddings.py --model_dir experiments/base_model

Then run TensorBoard on the experiment directory:

tensorboard --logdir experiments/base_model

Here is the result:

Embeddings of the MNIST test images visualized with t-SNE (perplexity 25)

Test

To run all the tests, run this from the project directory:

pytest

To run a specific test:

pytest model/tests/test_triplet_loss.py
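One way to check a vectorized loss is to compare it with a brute-force loop on small inputs. The pure-Python sketch below computes the "batch all" loss and the fraction of positive triplets by enumerating every triplet; the function and test names are hypothetical, not taken from model/tests/test_triplet_loss.py, and math.dist requires Python 3.8 or later. pytest collects any test_ function placed in a test file.

```python
import itertools
import math


def naive_batch_all_loss(labels, embeddings, margin):
    """Brute-force "batch all" loss: loop over every (anchor, positive, negative) triplet."""
    losses = []
    num_valid = 0
    for a, p, n in itertools.permutations(range(len(labels)), 3):
        if labels[a] != labels[p] or labels[a] == labels[n]:
            continue  # not a valid triplet
        num_valid += 1
        d_ap = math.dist(embeddings[a], embeddings[p])
        d_an = math.dist(embeddings[a], embeddings[n])
        losses.append(max(d_ap - d_an + margin, 0.0))
    # Average over triplets with a positive loss; report their share of valid triplets
    positive = [x for x in losses if x > 1e-16]
    mean_loss = sum(positive) / len(positive) if positive else 0.0
    fraction_positive = len(positive) / num_valid if num_valid else 0.0
    return mean_loss, fraction_positive


def test_naive_batch_all_loss():
    labels = [0, 0, 1]
    embeddings = [[0.0], [1.0], [3.0]]
    loss, fraction = naive_batch_all_loss(labels, embeddings, margin=1.5)
    assert math.isclose(loss, 0.5)
    assert math.isclose(fraction, 0.5)
```

In a real test, the same random inputs would also be fed to batch_all_triplet_loss and the two results compared within a tolerance, for example with numpy.testing.assert_allclose.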

Resources