
TLDR: Twin Learning for Dimensionality Reduction

TLDR (Twin Learning for Dimensionality Reduction) is an unsupervised dimensionality reduction method that combines neighborhood embedding learning with the simplicity and effectiveness of recent self-supervised learning losses.

Inspired by manifold learning, TLDR uses nearest neighbors as a way to build pairs from a training set and a redundancy reduction loss to learn an encoder that produces representations invariant across such pairs. Similar to other neighborhood embeddings, TLDR effectively and unsupervisedly learns low-dimensional spaces where local neighborhoods of the input space are preserved; unlike other manifold learning methods, it simply consists of an offline nearest neighbor computation step and a straightforward learning process that does not require mining negative samples to contrast, eigendecompositions, or cumbersome optimization solvers.
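
The offline nearest-neighbor step can be sketched in a few lines of numpy (a brute-force illustration; at scale the library uses FAISS, and `knn_pairs` is a hypothetical helper name, not part of the TLDR API):

```python
import numpy as np

def knn_pairs(X, k):
    # Brute-force pairwise squared distances; FAISS replaces this at scale.
    d2 = ((X[:, None, :] - X[None, :, :]) ** 2).sum(-1)
    np.fill_diagonal(d2, np.inf)           # exclude self-matches
    nn = np.argsort(d2, axis=1)[:, :k]     # indices of the k nearest neighbors
    anchors = np.repeat(np.arange(len(X)), k)
    return np.stack([anchors, nn.ravel()], axis=1)  # (N*k, 2) training pairs
```

Each row of the returned array is an (anchor, neighbor) index pair whose proximity the encoder is trained to preserve.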

More details and evaluation can be found in our TMLR paper.

Overview of TLDR: Given a set of feature vectors in a generic input space, we use nearest neighbors to define a set of feature pairs whose proximity we want to preserve. We then learn a dimensionality-reduction function (the encoder) by encouraging neighbors in the input space to have similar representations. We learn it jointly with an auxiliary projector that produces high-dimensional representations, where we compute the Barlow Twins loss over the (d′ × d′) cross-correlation matrix averaged over the batch.
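
The loss in the figure can be sketched in numpy (an illustrative reimplementation, not the library's code; `lambd` is the standard off-diagonal weight from the Barlow Twins formulation):

```python
import numpy as np

def barlow_twins_loss(za, zb, lambd=5e-3):
    # Standardize each dimension to zero mean / unit variance over the batch.
    za = (za - za.mean(0)) / za.std(0)
    zb = (zb - zb.mean(0)) / zb.std(0)
    n = za.shape[0]
    c = za.T @ zb / n                      # (d' x d') cross-correlation matrix
    on_diag = ((1.0 - np.diag(c)) ** 2).sum()        # push correlations to 1
    off_diag = (c ** 2).sum() - (np.diag(c) ** 2).sum()  # decorrelate dims
    return on_diag + lambd * off_diag
```

Identical, already-decorrelated views drive both terms toward zero, which is why training needs no negative pairs.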

Installing the TLDR library

Requirements:

  • Python 3.6 or greater
  • PyTorch 1.8 or greater
  • numpy
  • FAISS
  • rich

In order to install the TLDR library, one should first make sure that FAISS and PyTorch are installed. We recommend using a new conda environment:

conda create --name ENV_NAME python=3.6.8
conda activate ENV_NAME
conda install -c pytorch faiss-gpu cudatoolkit=10.2
conda install pytorch torchvision torchaudio cudatoolkit=10.2 -c pytorch

After ensuring that both FAISS and PyTorch are installed, you can install TLDR using the two commands below:

git clone [email protected]:naver/tldr.git
python3 -m pip install -e tldr

Using the TLDR library

The TLDR library can be used to learn dimensionality reduction models using an API and functionality that mimics similar methods in the scikit-learn library, i.e. you can learn a dimensionality reduction on your training data using fit() and you can project new data using transform().

To illustrate the different functionalities we present a dummy example on randomly generated data. Let's import the library and generate some random training data (we will use 100K training examples with a dimensionality of 2048):

import numpy as np
from tldr import TLDR

# Generate random data
X = np.random.rand(100000, 2048)  # replace with training (N x D) array

Instantiating a TLDR model

When instantiating a TLDR model one has to specify the output dimension (n_components), the number of nearest neighbors to use (n_neighbors) as well as the encoder and projector architectures that are specified as strings.

For this example we will learn a dimensionality reduction to 32 components, we will use the 10 nearest neighbors to sample positive pairs, and we will use a linear encoder and a multi-layer perceptron with one hidden layer of 2048 dimensions as a projector:

tldr = TLDR(n_components=32, n_neighbors=10, encoder='linear', projector='mlp-1-2048', device='cuda', verbose=2)

For a more detailed list of optional arguments, please refer to the function documentation below; the architecture specification string format is described in the Architecture Specification Strings section below.

Learning and applying the TLDR model

We learn the parameters of the dimensionality reduction model by using the fit() method:

tldr.fit(X, epochs=100, batch_size=1024, output_folder='data/', print_every=50)

By default, fit() first collects the k nearest neighbors for each training data point using FAISS and then optimizes the Barlow Twins loss using the batch size and number of epochs provided. Note that, apart from the dimensionality reduction function (the encoder), a projector function that is part of the training process is also learned (see also the Figure above); the projector is by default discarded after training.

Once the model has been trained we can use transform() to project the training data to the new learned space:

Z = tldr.transform(X, l2_norm=True)  # Returns (N x n_components) matrix

The optional l2_norm=True argument of transform() further applies L2 normalization to all features after projection.
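
The effect of l2_norm=True is plain row-wise L2 normalization, equivalent to the following sketch (Z here is a random stand-in for the array returned by transform()):

```python
import numpy as np

Z = np.random.rand(5, 32)  # stand-in for the (N x n_components) output of transform()
# Divide each row by its L2 norm so every projected feature has unit length.
Z_norm = Z / np.linalg.norm(Z, axis=1, keepdims=True)
```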

Again, we refer the user to the functions' documentation below for argument details.

Saving/loading the model

The TLDR model and the array of nearest neighbors per training datapoint can be saved using the save() and save_knn() functions, respectively:

tldr.save("data/inference_model.pth")
tldr.save_knn("data/knn.npy")

Note that by default the projector weights will not be saved. To also save the projector (e.g. for subsequent fine-tuning of the model) one must set the retain_projector=True argument when calling fit().

One can use the load() method to load a pre-trained model from disk. Using the init=True argument when loading also loads the hyper-parameters of the model:

X = np.random.rand(5000, 2048)
tldr = TLDR()
tldr.load("data/inference_model.pth", init=True)  # Loads both model parameters and weights
Z = tldr.transform(X, l2_norm=True)  # Returns (N x n_components) matrix

You can find this full example in scripts/dummy_example.py.

Documentation

TLDR(n_components, encoder, projector, n_neighbors=5, device='cpu', pin_memory=False)

Description of selected arguments (see code for full list):

  • n_components: output dimension
  • encoder: encoder network architecture specification string--see formatting guide (Default: 'linear').
  • projector: projector network architecture specification string--see formatting guide (Default: 'mlp-1-2048').
  • n_neighbors: number of nearest neighbors used to sample training pairs (Default: 5).
  • device: selects the device ['cpu', 'cuda'] (Default: cpu).
  • pin_memory: pin all data to the memory of the device (Default: False).
  • random_state: sets the random seed (Default: None).
  • knn_approximation: amount of approximation to use during the k-NN computation; accepted values are [None, "low", "medium", "high"] (Default: None). With no approximation, exact neighbors are computed; setting the approximation to low, medium or high uses product quantization and creates the FAISS index using the index_factory with an "IVF1,PQ[X]" string, where X = {32, 16, 8} for {"low", "medium", "high"}, respectively. The PQ parameters are learned using 10% of the training data.

from tldr import TLDR

tldr = TLDR(n_components=128, encoder='linear', projector='mlp-2-2048', n_neighbors=3, device='cuda')
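
The knn_approximation mapping described above can be sketched as follows (a hypothetical helper, not the library's API; "Flat" as the exact-search index string is an assumption):

```python
def faiss_index_spec(knn_approximation=None):
    # Maps the approximation level to the FAISS index_factory string
    # described above; None means exact search (assumed "Flat" index).
    bits = {"low": 32, "medium": 16, "high": 8}
    if knn_approximation is None:
        return "Flat"
    return "IVF1,PQ{}".format(bits[knn_approximation])
```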

fit(X, epochs=100, batch_size=1024, knn_graph=None, output_folder=None, snapshot_freq=None)

Parameters:

  • X: NxD training data array containing N training samples of dimension D.
  • epochs: number of training epochs (Default: 100).
  • batch_size: size of the training mini batch (Default: 1024).
  • knn_graph: Nxn_neighbors array containing the indices of nearest neighbors of each sample; if None it will be computed (Default: None).
  • output_folder: folder where the final model (and also the snapshots if snapshot_freq > 1) will be saved (Default: None).
  • snapshot_freq: number of epochs to save a new snapshot (Default: None).
  • print_every: prints useful training information every given number of steps (Default: 0).
  • retain_projector: flag so that the projector parameters are retained after training (Default: False).

from tldr import TLDR
import numpy as np

tldr = TLDR(n_components=32, encoder='linear', projector='mlp-2-2048')
X = np.random.rand(10000, 2048)
tldr.fit(X, epochs=50, batch_size=512, output_folder='data/', snapshot_freq=5, print_every=50)

transform(X, l2_norm=False)

Parameters:

  • X: NxD array containing N samples of dimension D.
  • l2_norm: l2 normalizes the features after projection (Default: False).

Output:

  • Z: Nxn_components array

tldr.fit(X, epochs=100)
Z = tldr.transform(X, l2_norm=True)

save(path) and load(path)

  • save() saves to disk both model parameters and weights.
  • load() loads the weights of the model. If init=True it initializes the model with the hyper-parameters found in the file.
tldr = TLDR(n_components=32, encoder='linear', projector='mlp-2-2048')
tldr.fit(X, epochs=50, batch_size=512)
tldr.save("data/model.pth")  # Saves weights and params

tldr = TLDR()
tldr.load("data/model.pth", init=True)  # Initializes the model with the params in the file and loads the weights

remove_projector()

Removes the projector head from the model. Useful for reducing the size of the model before saving it to disk. Note that you'll need the projection head if you want to resume training.

compute_knn(), save_knn() and load_knn()

compute_knn() precomputes the array of nearest neighbors for the training data, while save_knn() and load_knn() save it to and load it from disk, so the (potentially expensive) neighbor computation can be reused across runs:

tldr = TLDR(n_components=128, encoder='linear', projector='mlp-2-2048')
tldr.compute_knn(X)
tldr.fit(X, epochs=100)
tldr.save_knn("knn.npy")

tldr = TLDR(n_components=128, encoder='linear', projector='mlp-2-2048')
tldr.load_knn("knn.npy")
tldr.fit(X, epochs=100)

Architecture Specification Strings

You can specify the network configuration using a string with the following format:

'[NETWORK_TYPE]-[NUM_HIDDEN_LAYERS]-[NUM_DIMENSIONS_PER_LAYER]'

  • NETWORK_TYPE: three network types currently available:
    • linear: a linear function parametrized by a weight matrix W of size input_dim x n_components.
    • flinear: a factorized linear model implemented as a sequence of blocks, each composed of a linear layer followed by a batch normalization layer.
    • mlp: a multi-layer perceptron (MLP) with batch normalization and rectified linear units (ReLUs) as non-linearities.
  • NUM_HIDDEN_LAYERS: selects the number of hidden (i.e. intermediate) layers for the factorized linear model and the MLP.
  • NUM_DIMENSIONS_PER_LAYER: selects the dimensionality of the hidden layers.

For example, linear will use a single linear layer; flinear-1-512 will use a factorized linear model with one hidden layer of 512 dimensions; and mlp-2-4096 will select an MLP composed of two hidden layers of 4096 dimensions each.
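
The string format above can be parsed as in this illustrative sketch (`parse_arch_spec` is a hypothetical helper, not the library's parser):

```python
def parse_arch_spec(spec):
    # 'linear' -> no hidden layers; 'mlp-2-4096' -> ('mlp', 2 hidden layers, 4096 dims each)
    parts = spec.split('-')
    if len(parts) == 1:
        return (parts[0], 0, None)
    network_type, num_hidden, dim = parts
    return (network_type, int(num_hidden), int(dim))
```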

Citation

Please consider citing the following paper in your publications if this helps your research.

@article{kalantidis2022tldr,
 title = {TLDR: Twin Learning for Dimensionality Reduction},
 author = {Kalantidis, Y. and Lassance, C. and Almaz\'an, J. and Larlus, D.},
 journal={Transactions of Machine Learning Research},
 year={2022},
 url={https://openreview.net/forum?id=86fhqdBUbx},
}

Contributors

This code has been developed by Jon Almazan, Carlos Lassance, Yannis Kalantidis and Diane Larlus at NAVER Labs Europe.
