• Stars: 112
• Rank: 312,240 (Top 7%)
• Language: Python
• Created: over 4 years ago
• Updated: over 1 year ago

Repository Details

🛒 Simple recommender with matrix factorization, graph, and NLP. Beating the regular collaborative filtering baseline.

recsys-nlp-graph

Undocumented code for a personal project on a simple recsys via matrix factorization (Part 1), and NLP and graph techniques (Part 2). Shared as a follow-along for a meet-up.

Associated articles:

Talk and Slides:

Data

Electronics and books data from the Amazon dataset (May 1996 – July 2014) was used. Here's what an example JSON entry looks like.

{
  "asin": "0000031852",
  "title": "Girls Ballet Tutu Zebra Hot Pink",
  "price": 3.17,
  "imUrl": "http://ecx.images-amazon.com/images/I/51fAmVkTbyL._SY300_.jpg",
  "related": {
    "also_bought": ["B00JHONN1S", "B002BZX8Z6", "B00D2K1M3O", ..., "B007R2RM8W"],
    "also_viewed": ["B002BZX8Z6", "B00JHONN1S", "B008F0SU0Y", ..., "B00BFXLZ8M"],
    "bought_together": ["B002BZX8Z6"]
  },
  "salesRank": {
    "Toys & Games": 211836
  },
  "brand": "Coxlures",
  "categories": [
    ["Sports & Outdoors", "Other Sports", "Dance"]
  ]
}

Comparing Matrix Factorization to Skip-gram (Node2Vec)

Overall results for Electronics dataset

Model                                       | All Products (AUC-ROC) | Seen Products Only (AUC-ROC) | Runtime (min)
PyTorch Matrix Factorization                | 0.7951                 | -                            | 45
Node2Vec                                    | NA                     | NA                           | NA
Gensim Word2Vec                             | 0.9082                 | 0.9735                       | 2.58
PyTorch Word2Vec                            | 0.9554                 | 0.9855                       | 23.63
PyTorch Word2Vec with Side Info             | NA                     | NA                           | NA
PyTorch Matrix Factorization With Sequences | 0.9320                 | -                            | 70.39
Alibaba Paper*                              | 0.9327                 | -                            | -

Overall results for Books dataset

Model                                       | All Products (AUC-ROC) | Seen Products Only (AUC-ROC) | Runtime (min)
PyTorch Matrix Factorization                | 0.4996                 | -                            | 1353.12
Gensim Word2Vec                             | 0.9701                 | 0.9892                       | 16.24
PyTorch Word2Vec                            | 0.9775                 | -                            | 122.66
PyTorch Word2Vec with Side Info             | NA                     | NA                           | NA
PyTorch Matrix Factorization With Sequences | 0.7196                 | -                            | 1393.08

*Billion-scale Commodity Embedding for E-commerce Recommendation in Alibaba

1. Matrix Factorization (iteratively pair by pair)

At a high level, for each pair:

  • Get the embedding for each product
  • Multiply embeddings and sum the resulting vector (this is the prediction)
  • Reduce the difference between predicted score and actual score (via gradient descent and a loss function like mean squared error or BCE)

Here's some pseudo-code on how it would work.

import torch
import torch.nn.functional as F

for product_pair, label in train_set:
    # Get embedding for each product
    product1, product2 = product_pair
    product1_emb = embedding(product1)
    product2_emb = embedding(product2)

    # Predict product-pair score (interaction term and sum)
    prediction = torch.sigmoid(torch.sum(product1_emb * product2_emb, dim=1))
    l2_reg = reg_lambda * torch.sum(embedding.weight ** 2)

    # Minimize binary cross-entropy loss (plus L2 regularization)
    loss = F.binary_cross_entropy(prediction, label)
    loss += l2_reg

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

For the training schedule, we run it over 5 epochs with cosine annealing. In each epoch, the learning rate starts high (0.01) and drops rapidly to a minimum value near zero, before being reset for the next epoch.
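
A minimal sketch of such a schedule using PyTorch's built-in scheduler; the dummy model and step count are illustrative stand-ins, not the repo's actual setup:

import torch

# Illustrative stand-ins for the real model and dataloader
model = torch.nn.Embedding(10, 8)
steps_per_epoch = 1_000

optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
# Cosine annealing with warm restarts: lr starts at 0.01, decays towards
# eta_min within each epoch, then resets at the next epoch boundary
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer, T_0=steps_per_epoch, eta_min=1e-6
)

for epoch in range(5):
    for step in range(steps_per_epoch):
        # ... forward pass, loss.backward(), optimizer.step() would go here ...
        scheduler.step()  # anneal per batch; restarts every steps_per_epoch steps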

One epoch seems sufficient to achieve close to optimal ROC-AUC.

However, if we look at the precision-recall curves below, we see that at around 0.5 we hit the "cliff of death". If we estimate the threshold slightly too low, precision drops from close to 1.0 to 0.5; slightly too high and recall is poor.
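
To make the trade-off concrete, here's a hedged sketch of inspecting the precision-recall curve with scikit-learn; the scores and labels below are made up for illustration:

import numpy as np
from sklearn.metrics import precision_recall_curve

# Made-up validation scores and labels, for illustration only
y_true = np.array([1, 1, 1, 0, 0, 1, 0, 0, 1, 0])
y_score = np.array([0.95, 0.90, 0.85, 0.55, 0.52, 0.51, 0.48, 0.30, 0.60, 0.20])

precision, recall, thresholds = precision_recall_curve(y_true, y_score)

# Inspect precision/recall at each candidate threshold to see how sharply
# precision drops (the "cliff") if the threshold is set slightly too low
for p, r, t in zip(precision, recall, thresholds):
    print(f"threshold={t:.2f}  precision={p:.2f}  recall={r:.2f}")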

2. Matrix Factorization with Bias

Adding bias reduces the steepness of the curves where they intersect, making it more production-friendly. (Though AUC-ROC decreases slightly, this implementation is preferable.)
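
A hedged sketch of what adding per-product bias terms to the factorization model could look like; the module and parameter names here are my own, not the repo's:

import torch
import torch.nn as nn

class MFWithBias(nn.Module):
    def __init__(self, n_products, emb_dim):
        super().__init__()
        self.embedding = nn.Embedding(n_products, emb_dim)
        self.bias = nn.Embedding(n_products, 1)  # one bias term per product

    def forward(self, product1, product2):
        # Interaction term (dot product) plus the two product biases
        dot = torch.sum(self.embedding(product1) * self.embedding(product2), dim=1)
        logits = dot + self.bias(product1).squeeze(1) + self.bias(product2).squeeze(1)
        return torch.sigmoid(logits)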

3. Node2Vec

I tried using the implementation of Node2Vec here, but it was too memory-intensive and slow. It didn't run to completion, even on a 64GB instance.

Digging deeper, I found that its approach to generating sequences was traversing the graph. If you allowed networkx to use multiple threads, it would spawn multiple processes to create sequences and cache them temporarily in memory. In short, it was very memory-hungry. Overall, this didn't work for the datasets I had.
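
For reference, the kind of random-walk sequence generation involved looks roughly like this simplified, memory-light sketch over a plain adjacency dict (a first-order walk; Node2Vec adds a p/q bias). This is not the implementation I tried, just an illustration:

import random

def random_walks(adjacency, walk_length=10, walks_per_node=10):
    """Yield random-walk sequences one at a time instead of caching them all."""
    nodes = list(adjacency)
    for _ in range(walks_per_node):
        random.shuffle(nodes)
        for start in nodes:
            walk = [start]
            while len(walk) < walk_length:
                neighbors = adjacency.get(walk[-1])
                if not neighbors:
                    break
                walk.append(random.choice(neighbors))
            yield walk

# Toy product graph for illustration
graph = {"A": ["B", "C"], "B": ["A"], "C": ["A", "B"]}
for walk in random_walks(graph, walk_length=5, walks_per_node=1):
    print(walk)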

4. gensim.word2vec

Gensim has an implementation of w2v that takes in a list of sequences and can be multi-threaded. It was very easy to use and the fastest to complete five epochs.
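
Usage is roughly as follows (a hedged sketch with gensim 4.x parameter names; the sequences and hyperparameters are illustrative, not the repo's exact settings):

from gensim.models import Word2Vec

# Toy product sequences; in practice these come from the graph random walks
sequences = [
    ["B001T9NUFS", "B003AVEU6G", "B007ZN5Y56"],
    ["B002BZX8Z6", "B00JHONN1S", "B008F0SU0Y"],
]

model = Word2Vec(
    sentences=sequences,
    vector_size=128,   # embedding dimension
    window=5,          # context window
    sg=1,              # skip-gram (rather than CBOW)
    negative=5,        # negative samples per positive pair
    workers=4,         # multi-threaded training
    epochs=5,
    min_count=1,
)

embedding = model.wv["B001T9NUFS"]  # look up a product embedding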

But the precision-recall curve shows a sharp cliff around threshold == 0.73. This is due to out-of-vocabulary products in our validation datasets (which don't have embeddings).

If we only evaluate in-vocabulary items, performance improves significantly.
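
A small sketch (the names are hypothetical) of restricting evaluation to products that have embeddings, i.e., those in the word2vec vocabulary:

# `model` is a trained gensim Word2Vec; `val_pairs` is a list of
# (product1, product2, label) tuples from the validation set
seen_pairs = [
    (p1, p2, label)
    for p1, p2, label in val_pairs
    if p1 in model.wv and p2 in model.wv  # drop out-of-vocabulary products
]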

5. PyTorch word2vec

We implement Skip-gram in PyTorch. Here's some simplified code showing how it looks.

import torch
import torch.nn as nn
import torch.nn.functional as F

class SkipGram(nn.Module):
    def __init__(self, emb_size, emb_dim):
        super().__init__()
        self.center_embeddings = nn.Embedding(emb_size, emb_dim, sparse=True)
        self.context_embeddings = nn.Embedding(emb_size, emb_dim, sparse=True)

    def forward(self, center, context, neg_context):
        # Look up embeddings for center, context, and negative-context products
        emb_center = self.center_embeddings(center)
        emb_context = self.context_embeddings(context)
        emb_neg_context = self.context_embeddings(neg_context)

        # Get score for positive pairs (negative log-sigmoid of the dot product)
        score = torch.sum(emb_center * emb_context, dim=1)
        score = -F.logsigmoid(score)

        # Get score for negative pairs
        neg_score = torch.bmm(emb_neg_context, emb_center.unsqueeze(2)).squeeze()
        neg_score = -torch.sum(F.logsigmoid(-neg_score), dim=1)

        # Return combined score
        return torch.mean(score + neg_score)
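
A quick illustrative example of running one training step with this module; the batch here is random dummy data, and the optimizer choice (SparseAdam, which handles the sparse embedding gradients) is an assumption rather than the repo's exact setup:

import torch

model = SkipGram(emb_size=1000, emb_dim=128)
optimizer = torch.optim.SparseAdam(model.parameters(), lr=0.01)

# Dummy batch: 32 (center, context) pairs with 5 negative samples each
center = torch.randint(0, 1000, (32,))
context = torch.randint(0, 1000, (32,))
neg_context = torch.randint(0, 1000, (32, 5))

loss = model(center, context, neg_context)
optimizer.zero_grad()
loss.backward()
optimizer.step()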

It performed better than gensim when considering all products.

If considering only seen products, it's still an improvement, but less dramatic.

When examining the learning curves, it seems that a single epoch is sufficient. In contrast to the learning curves from matrix factorization (implementation 1), the AUC-ROC doesn't drop drastically with each learning rate reset.

6. PyTorch word2vec with side info

Why did we build the skip-gram model from scratch? Because we wanted to extend it with side information (e.g., brand, category, price).

B001T9NUFS -> B003AVEU6G -> B007ZN5Y56 ... -> B007ZN5Y56
Television    Sound bar     Lamp              Standing Fan
Sony          Sony          Phillips          Dyson
500 – 600     200 – 300     50 – 75           300 – 400

Perhaps by learning on these we can create better embeddings?
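
The extension is roughly as follows (a hypothetical sketch, not the exact implementation): represent the center product as a combination, e.g., the mean, of its product, brand, and category embeddings.

import torch
import torch.nn as nn

class CenterWithSideInfo(nn.Module):
    """Combine product, brand, and category embeddings into one center vector."""

    def __init__(self, n_products, n_brands, n_categories, emb_dim):
        super().__init__()
        self.product_emb = nn.Embedding(n_products, emb_dim)
        self.brand_emb = nn.Embedding(n_brands, emb_dim)
        self.category_emb = nn.Embedding(n_categories, emb_dim)

    def forward(self, product, brand, category):
        # Average the three embeddings; a learned weighted sum is another option
        stacked = torch.stack(
            [self.product_emb(product), self.brand_emb(brand), self.category_emb(category)]
        )
        return stacked.mean(dim=0)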

Unfortunately, it didn't work out. Here's how the learning curve looks.

One possible reason for this non-result is the sparsity of the metadata. Of the 418,749 electronic products, we only had metadata for 162,023 (39%). Of these, brand was 51% empty.

7. Sequences + Matrix Factorization

Why did the w2v approach do so much better than matrix factorization? Was it due to the skip-gram model, or due to the training data format (i.e., sequences)?

To understand this better, I tried the previous matrix factorization with bias implementation (AUC-ROC = 0.7951) with the new sequences and dataloader. It worked very well.
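
Concretely, the sequences feed the matrix factorization model as (center, context) product pairs. A hypothetical sketch of that dataloader logic (the window size and names are illustrative):

def pairs_from_sequence(sequence, window=5):
    """Turn one product sequence into (center, context) training pairs."""
    pairs = []
    for i, center in enumerate(sequence):
        for j in range(max(0, i - window), min(len(sequence), i + window + 1)):
            if i != j:
                pairs.append((center, sequence[j]))
    return pairs

print(pairs_from_sequence(["B001", "B002", "B003"], window=1))
# [('B001', 'B002'), ('B002', 'B001'), ('B002', 'B003'), ('B003', 'B002')]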

Oddly though, the matrix factorization approach still exhibits the effect of "forgetting" as the learning rate resets with each epoch (Fig 9.), though not as pronounced as Figure 3 in the previous post.

I wonder if this is due to using the same embeddings for both center and context.

More Repositories

1. applied-ml: 📚 Papers & tech blogs by companies sharing their work on data science & machine learning in production. (24,324 stars)
2. open-llms: 📋 A list of open LLMs available for commercial use. (10,867 stars)
3. ml-surveys: 📋 Survey papers summarizing advances in deep learning, NLP, CV, graphs, reinforcement learning, recommendations, etc. (2,630 stars)
4. ml-design-docs: 📝 Design doc template & examples for machine learning systems (requirements, methodology, implementation, etc.). (395 stars)
5. 1-on-1s: 🌱 1-on-1 questions and resources from my time as a manager. (310 stars)
6. testing-ml: 🔍 Minimal examples of machine learning tests for implementation, behaviour, and performance. (Python, 199 stars)
7. obsidian-copilot: 🤖 A prototype assistant for writing and thinking. (Python, 186 stars)
8. applyingml: 📌 Papers, guides, and mentor interviews on applying machine learning for ApplyingML.com, the ghost knowledge of machine learning. (JavaScript, 160 stars)
9. papermill-mlflow: 🧪 Simple data science experimentation & tracking with jupyter, papermill, and mlflow. (Jupyter Notebook, 152 stars)
10. python-collab-template: 🛠 Python project template with unit tests, code coverage, linting, type checking, Makefile wrapper, and GitHub Actions. (Python, 129 stars)
11. llm-paper-notes: Notes from the Latent Space paper club. Follow along or start your own! (73 stars)
12. fastapi-html: Sample repository demonstrating how to use FastAPI to serve HTML web apps. (Python, 62 stars)
13. eugeneyan: (Python, 38 stars)
14. poc-docker-template: Simple template showing how to set up docker for reproducible data science with Jupyter notebooks. (Jupyter Notebook, 21 stars)
15. text-to-image: (Jupyter Notebook, 13 stars)
16. nocode-ml: 😎 End-to-end machine learning; "no code" required! (12 stars)
17. discord-llm: Experimenting with LLMs to Research, Reflect, and Plan (LLM assistants, retrieval, and Discord integration). (Jupyter Notebook, 11 stars)
18. learning-typescript: (JavaScript, 10 stars)
19. design-patterns: (Java, 7 stars)
20. deep-rl: Repository for deep reinforcement learning with OpenAI. (Python, 6 stars)
21. testing-pipelines: (Python, 6 stars)
22. kaggle_springleaf: Code for Kaggle Springleaf Email Prediction Challenge. (Python, 5 stars)
23. Computational-Thinking-and-Data-Science: edX: Introduction to Computational Thinking and Data Science (Oct 2014). (Python, 5 stars)
24. ama: Ask Me Anything. (4 stars)
25. Mining-Massive-Datasets: Coursera: Mining Massive Datasets (Sep 2014). (R, 4 stars)
26. Time-Series-Analysis: Simple forecasting with Regression Model. (R, 3 stars)
27. raspberry-llm: Calling LLM APIs on a Raspberry Pi for lulz. (Python, 3 stars)
28. Statistical-Inference: Lab assignments for the facilitation of Johns Hopkins University's Coursera MOOC on Statistical Inference. (R, 3 stars)
29. kaggle_titanic: Code for Kaggle Titanic Challenge (and other learning). (HTML, 3 stars)
30. Statistical-Learning: Stanford OpenX: Introduction to Statistical Learning. (HTML, 3 stars)
31. Data-Analysis-and-Statistical-Inference-Project: Coursera: Data Analysis & Statistical Inference Project (Feb 2014). (R, 2 stars)
32. neural_networks_and_deep_learning: (2 stars)
33. Twitter-SMA: Twitter Streaming and Analysis with Python and R. (R, 2 stars)
34. scratch: (Jupyter Notebook, 2 stars)
35. Getting-and-Cleaning-Data: Coursera: Getting and Cleaning Data (May 2014). (R, 2 stars)
36. Computer-Science-and-Programming-In-Python: edX: Introduction to Computer Science and Programming in Python (July 2014). (Python, 1 star)
37. Misc: (R, 1 star)
38. datagene: (Jupyter Notebook, 1 star)
39. Interactive-Programming-in-Python: Coursera: Interactive Programming in Python (Apr 2014). (Python, 1 star)
40. R-Programming: Coursera: R Programming (May 2014). (R, 1 star)
41. Visualizations: Random Visualizations. (R, 1 star)
42. json-to-utterances: (Jupyter Notebook, 1 star)
43. DKSG-HOME: Sharing my R script used in the DKSG DataLearn for home. (R, 1 star)
44. eugeneyan-comments: (1 star)
45. kaggle_otto: Code for Kaggle Otto Production Classification Challenge. (R, 1 star)
46. Demand-Forecasting: Prototyping various forecasting techniques. (R, 1 star)
47. Machine-Learning: Coursera: Machine Learning (Aug 2014). (MATLAB, 1 star)