Tensorflow implementation of DeepFM for CTR prediction.

tensorflow-DeepFM

This project includes a Tensorflow implementation of DeepFM [1].

NEWS

Usage

Input Format

This implementation requires the input data in the following format:

  • Xi: [[ind1_1, ind1_2, ...], [ind2_1, ind2_2, ...], ..., [indi_1, indi_2, ..., indi_j, ...], ...]
    • indi_j is the feature index of feature field j of sample i in the dataset
  • Xv: [[val1_1, val1_2, ...], [val2_1, val2_2, ...], ..., [vali_1, vali_2, ..., vali_j, ...], ...]
    • vali_j is the feature value of feature field j of sample i in the dataset
    • vali_j can be either binary (1/0, for binary/categorical features) or float (e.g., 10.24, for numerical features)
  • y: target of each sample in the dataset (1/0 for classification, a numeric value for regression)
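To show why each sample needs both an index list (Xi) and a value list (Xv), here is a minimal NumPy sketch of the FM component as described in [1]. Each index selects a first-order weight and an embedding row, both scaled by the matching value, and the pairwise interactions use the identity sum_{i<j} <e_i, e_j> = 0.5 * ((sum_i e_i)^2 - sum_i e_i^2), summed over the embedding dimension. The names fm_part, w and V are illustrative assumptions, not this repository's API, and the deep component is omitted.

```python
import numpy as np


def fm_part(Xi, Xv, w, V, bias=0.0):
    """Minimal FM forward pass for a batch (illustrative sketch).

    Xi: int array [batch, fields], feature indices (as in Xi above)
    Xv: float array [batch, fields], feature values (as in Xv above)
    w:  float array [feature_size], first-order weights
    V:  float array [feature_size, k], embedding (latent factor) matrix
    """
    Xi = np.asarray(Xi)
    Xv = np.asarray(Xv, dtype=float)
    # first-order term: sum_j w[ind_j] * val_j
    first = (w[Xi] * Xv).sum(axis=1)
    # look up one embedding per field and scale it by the feature value
    E = V[Xi] * Xv[:, :, None]          # [batch, fields, k]
    # pairwise interactions via the square-of-sum minus sum-of-squares trick
    sum_sq = E.sum(axis=1) ** 2         # [batch, k]
    sq_sum = (E ** 2).sum(axis=1)       # [batch, k]
    second = 0.5 * (sum_sq - sq_sum).sum(axis=1)
    return bias + first + second
```

In the paper, the same scaled embeddings E are flattened and fed to the deep layers, and the FM and deep outputs are summed before the sigmoid. Because categorical fields carry the value 1, their embeddings pass through unchanged, while numerical fields scale a single shared embedding by the raw value.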

Please see example/DataReader.py for an example of how to prepare the data in the required format for DeepFM.
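As a rough illustration of the format, the sketch below turns raw records into Xi and Xv: each numerical column gets a single feature index and keeps its raw value, while each distinct (column, category) pair gets its own index with value 1. build_xi_xv is a hypothetical helper, similar in spirit to example/DataReader.py but not that file's API.

```python
def build_xi_xv(rows, numeric_cols, categorical_cols):
    """Hypothetical helper: map raw records (dicts) to (Xi, Xv, n_indices).

    Numerical columns: one index per column, value = raw number.
    Categorical columns: one index per distinct (column, category), value = 1.
    """
    feat_dict = {}
    next_idx = 0
    for col in numeric_cols:
        feat_dict[col] = next_idx
        next_idx += 1
    for col in categorical_cols:
        for row in rows:
            key = (col, row[col])
            if key not in feat_dict:
                feat_dict[key] = next_idx
                next_idx += 1

    Xi, Xv = [], []
    for row in rows:
        Xi.append([feat_dict[c] for c in numeric_cols]
                  + [feat_dict[(c, row[c])] for c in categorical_cols])
        Xv.append([float(row[c]) for c in numeric_cols]
                  + [1.0] * len(categorical_cols))
    # next_idx is the number of distinct indices (size of the embedding table)
    return Xi, Xv, next_idx
```

For example, rows [{"age": 25, "city": "NY"}, {"age": 40, "city": "SF"}] give Xi = [[0, 1], [0, 2]] and Xv = [[25.0, 1.0], [40.0, 1.0]]. Since an unseen category would raise a KeyError, build the index dictionary over all splits you intend to score, or reserve an index for unknown values.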

Init and train a model

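import tensorflow as tf
from sklearn.metrics import roc_auc_score

from DeepFM import DeepFM  # DeepFM.py at the root of this repository

# params
dfm_params = {
    "use_fm": True,
    "use_deep": True,
    "embedding_size": 8,
    "dropout_fm": [1.0, 1.0],  # keep probabilities for the FM part (1.0 = no dropout)
    "deep_layers": [32, 32],
    "dropout_deep": [0.5, 0.5, 0.5],  # keep probabilities: embedding layer + one per deep layer
    "deep_layers_activation": tf.nn.relu,
    "epoch": 30,
    "batch_size": 1024,
    "learning_rate": 0.001,
    "optimizer_type": "adam",
    "batch_norm": 1,
    "batch_norm_decay": 0.995,
    "l2_reg": 0.01,
    "verbose": True,
    "eval_metric": roc_auc_score,
    "random_seed": 2017
}

# prepare training and validation data in the required format
# (prepare(...) stands for your own preprocessing, e.g. example/DataReader.py)
Xi_train, Xv_train, y_train = prepare(...)
Xi_valid, Xv_valid, y_valid = prepare(...)

# the model also needs the input dimensions:
# feature_size = number of distinct feature indices (assuming indices start at 0)
# field_size = number of fields per sample
dfm_params["feature_size"] = 1 + max(max(row) for rows in (Xi_train, Xi_valid) for row in rows)
dfm_params["field_size"] = len(Xi_train[0])

# init a DeepFM model
dfm = DeepFM(**dfm_params)

# fit a DeepFM model
dfm.fit(Xi_train, Xv_train, y_train)

# make predictions
y_pred = dfm.predict(Xi_valid, Xv_valid)

# evaluate a trained model with eval_metric
score = dfm.evaluate(Xi_valid, Xv_valid, y_valid)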

To use early stopping during training, pass the validation data and set early_stopping=True:

dfm.fit(Xi_train, Xv_train, y_train, Xi_valid, Xv_valid, y_valid, early_stopping=True)
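Early stopping needs a rule for deciding when the validation score has stopped improving. The sketch below shows one common rule: stop once the score has worsened for a number of consecutive epochs. should_stop and its patience parameter are hypothetical names, and the repository's exact criterion may differ.

```python
def should_stop(valid_scores, patience=5, greater_is_better=True):
    """Return True once the validation score has worsened for `patience`
    consecutive epochs (a generic rule, not necessarily the repository's)."""
    if len(valid_scores) <= patience:
        return False
    recent = valid_scores[-(patience + 1):]
    for prev, cur in zip(recent, recent[1:]):
        # any epoch that did not get worse resets the streak
        not_worse = cur >= prev if greater_is_better else cur <= prev
        if not_worse:
            return False
    return True
```

For AUC, greater_is_better stays True; for a regression loss such as MSE it should be False.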

To refit the model on the combined training and validation sets, also set refit=True:

dfm.fit(Xi_train, Xv_train, y_train, Xi_valid, Xv_valid, y_valid, early_stopping=True, refit=True)

To use only the FM part, set use_deep to False; to use only the DNN part, set use_fm to False.
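Since the two components are controlled by boolean flags, one way to build the single-component variants is to copy the parameter dict and flip one flag. The sketch uses a trimmed-down base_params (a hypothetical name, omitting the TensorFlow-specific entries) so it runs on its own; in practice you would start from dfm_params.

```python
# A subset of the parameters shown above (TensorFlow-specific entries omitted)
base_params = {
    "use_fm": True,
    "use_deep": True,
    "embedding_size": 8,
    "deep_layers": [32, 32],
    "epoch": 30,
    "batch_size": 1024,
}

# FM only: switch off the deep component (dict(...) makes a copy)
fm_params = dict(base_params, use_deep=False)

# DNN only: switch off the FM component
dnn_params = dict(base_params, use_fm=False)
```

Each variant can then be passed to the constructor in the same way, e.g. DeepFM(**fm_params), after adding the remaining parameters.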

Regression

This implementation also supports regression tasks. To use DeepFM for regression, set loss_type to "mse" and choose an eval_metric suited to regression, e.g., mean squared error (MSE) or mean absolute error (MAE).
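Any function that takes the true targets and the predictions and returns a number can serve as eval_metric, assuming the same (y_true, y_pred) argument order as roc_auc_score in the classification example. The sketch defines two plain-Python regression metrics and a hypothetical regression_overrides dict to merge into the parameters.

```python
import math


def rmse(y_true, y_pred):
    """Root mean squared error (lower is better)."""
    n = len(y_true)
    return math.sqrt(sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / n)


def mae(y_true, y_pred):
    """Mean absolute error (lower is better)."""
    return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)


# Overrides to merge into the parameter dict for a regression task
regression_overrides = {
    "loss_type": "mse",
    "eval_metric": rmse,
}
```

Note that lower is better for both metrics, which matters if you also use early stopping; check whether your version of DeepFM exposes a parameter for the metric's direction.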

Example

The example folder contains an example of using the DeepFM, FM, and DNN models for the Porto Seguro's Safe Driver Prediction competition on Kaggle.

Please download the data from the competition website and put it in the example/data folder.

To train a DeepFM model on this dataset, run

$ cd example
$ python main.py

Please see example/DataReader.py for how to parse the raw dataset into the required format for DeepFM.

Performance

[Figure: DeepFM]

[Figure: FM]

[Figure: DNN]

Some tips

  • You should tune the parameters of each model to get reasonable performance.
  • You can also try ensembling these models with each other or with other models (e.g., XGBoost or LightGBM).
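A simple way to ensemble is to average the predicted probabilities of several models, for instance the outputs of predict from the DeepFM, FM and DNN variants or from a gradient-boosting model. The sketch below is a minimal weighted average; average_ensemble is a hypothetical helper and the weights are assumptions you would tune on validation data.

```python
def average_ensemble(predictions, weights=None):
    """Weighted average of per-model prediction lists of equal length.

    predictions: list of lists, one list of scores per model
    weights: optional list of non-negative weights, one per model
    """
    if weights is None:
        weights = [1.0] * len(predictions)
    total = sum(weights)
    n = len(predictions[0])
    return [
        sum(w * preds[i] for w, preds in zip(weights, predictions)) / total
        for i in range(n)
    ]
```

Averaging works best when the models' scores are on a comparable scale (e.g., all probabilities); otherwise, averaging ranks instead of raw scores is a common alternative.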

Reference

[1] DeepFM: A Factorization-Machine based Neural Network for CTR Prediction, Huifeng Guo, Ruiming Tang, Yunming Ye, Zhenguo Li, Xiuqiang He. IJCAI 2017.

Acknowledgments

This project draws inspiration from the following projects:

License

MIT
