• Stars: 264
• Rank: 155,103 (Top 4 %)
• Language: Python
• License: MIT License
• Created almost 3 years ago
• Updated 7 months ago

Repository Details

The official PyTorch implementation of the paper "SAITS: Self-Attention-based Imputation for Time Series". SAITS is a fast, state-of-the-art (SOTA) deep-learning model for imputing incomplete multivariate time series that contain NaN missing values. https://arxiv.org/abs/2202.08516

The official code repository for the paper SAITS: Self-Attention-based Imputation for Time Series (a preprint is available on arXiv), which has been accepted by the journal Expert Systems with Applications (ESWA) [2022 IF 8.665, CiteScore 12.2, JCR-Q1, CAS-Q1 (Chinese Academy of Sciences, Tier 1), CCF-C]. You may not have heard of ESWA, but this journal was ranked 1st in Google Scholar among the top publications in Artificial Intelligence in 2016.

SAITS is the first work to apply pure self-attention, without any recursive design, to general time series imputation. You can treat it as a validated framework for time series imputation and, more generally, for sequence imputation. You are therefore welcome to modify SAITS for your own research purposes and domain applications. Specific scenarios or input data will probably require some modification of the model structure or loss functions.

🤗 Please cite SAITS in your publications if it helps with your research. Please star🌟 this repo to help others notice SAITS if you think it is useful. It really means a lot to my open-source work. Thank you! BTW, you may also like PyPOTS for easily modeling your partially-observed time-series datasets.

‼️Kind reminder: This document answers many common questions; please read it before you run the code.

📣 Attention please:
SAITS is now available in PyPOTS, a Python toolbox for data mining on POTS (Partially-Observed Time Series). An example of training SAITS to impute the PhysioNet-2012 dataset is shown below. With PyPOTS, easy peasy! 😉

# Install PyPOTS first: pip install pypots==0.1.1
import numpy as np
from sklearn.preprocessing import StandardScaler
from pypots.data import load_specific_dataset, mcar, masked_fill
from pypots.imputation import SAITS
from pypots.utils.metrics import cal_mae
# Data preprocessing. Tedious, but PyPOTS can help. 🤓
data = load_specific_dataset('physionet_2012')  # PyPOTS automatically downloads and extracts datasets in its database.
X = data['X']
num_samples = len(X['RecordID'].unique())
X = X.drop(['RecordID', 'Time'], axis=1)
X = StandardScaler().fit_transform(X.to_numpy())
X = X.reshape(num_samples, 48, -1)  # reshape to (samples, 48 time steps, features)
X_intact, X, missing_mask, indicating_mask = mcar(X, 0.1)  # hold out 10% of the observed values as ground truth
X = masked_fill(X, 1 - missing_mask, np.nan)  # set every position not marked as observed to NaN
# Model training. This is PyPOTS showtime. 💪
saits = SAITS(n_steps=48, n_features=37, n_layers=2, d_model=256, d_inner=128, n_heads=4, d_k=64, d_v=64, dropout=0.1, epochs=10)
dataset = {"X": X}
saits.fit(dataset)  # train the model. Here I use the whole dataset as the training set, because ground truth is not visible to the model.
imputation = saits.impute(dataset)  # impute the originally-missing values and artificially-missing values
mae = cal_mae(imputation, X_intact, indicating_mask)  # calculate mean absolute error on the ground truth (artificially-missing values)
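
The hold-out evaluation used in the example above can be sketched in plain NumPy. This is only an illustration of the idea, not the PyPOTS implementation: the helper names `mcar_holdout` and `masked_mae`, the `seed` parameter, and the toy array are assumptions made for this sketch.

```python
import numpy as np


def mcar_holdout(X, rate, seed=0):
    """Hide a fraction `rate` of the observed entries of X completely at random.

    Hypothetical helper for illustration. Returns
    (X_intact, X_masked, missing_mask, indicating_mask), where missing_mask is 1
    for values still visible in X_masked and indicating_mask is 1 for the
    values hidden here (the artificial ground truth).
    """
    rng = np.random.default_rng(seed)
    X_intact = X.copy()
    observed = ~np.isnan(X)
    # Only originally observed values can be held out.
    holdout = observed & (rng.random(X.shape) < rate)
    X_masked = X.copy()
    X_masked[holdout] = np.nan
    missing_mask = (~np.isnan(X_masked)).astype(np.float32)
    indicating_mask = holdout.astype(np.float32)
    return X_intact, X_masked, missing_mask, indicating_mask


def masked_mae(pred, target, mask):
    """Mean absolute error computed only where mask == 1."""
    pred = np.nan_to_num(pred)      # NaNs outside the mask must not leak into the sum
    target = np.nan_to_num(target)
    return float(np.sum(np.abs(pred - target) * mask) / (np.sum(mask) + 1e-9))


# Toy data: 2 samples, 4 time steps, 3 features, one originally missing value.
X_toy = np.arange(24, dtype=float).reshape(2, 4, 3)
X_toy[0, 0, 0] = np.nan
X_intact, X_masked, missing_mask, indicating_mask = mcar_holdout(X_toy, 0.5)

# A stand-in "imputation" that is off by exactly 1 everywhere.
fake_imputation = np.nan_to_num(X_intact) + 1.0
toy_mae = masked_mae(fake_imputation, X_intact, indicating_mask)
```

Because the error is averaged only over the artificially hidden positions, the metric never rewards or penalizes the model for values whose ground truth is unknown.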

❖ Motivation and Performance

⦿ Motivation: SAITS was developed primarily to overcome the drawbacks of RNN-based imputation models (slow speed, memory constraints, and compounding error) and to achieve state-of-the-art (SOTA) imputation accuracy on partially-observed time series.

⦿ Performance: SAITS outperforms BRITS by 12% ∼ 38% in MAE (mean absolute error) and trains 2.0 ∼ 2.6 times faster. Furthermore, SAITS outperforms Transformer (trained by our joint-optimization approach) by 2% ∼ 19% in MAE with a more efficient model structure: to obtain comparable performance, SAITS needs only 15% ∼ 30% of the parameters of Transformer. Compared to NRTSI, another SOTA self-attention imputation model, SAITS achieves a 7% ∼ 39% smaller mean squared error (above 20% in nine out of sixteen cases) while needing far fewer parameters and less imputation time in practice. Please refer to our full paper for more details about SAITS' performance.

❖ Brief Graphical Illustration of Our Methodology

Here we show only the two main components of our method: the joint-optimization training approach and the SAITS structure. For a detailed description and explanation, please read our full paper.

Fig. 1: Training approach

Fig. 2: SAITS structure
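
As a rough illustration of the joint-optimization idea, the sketch below assumes that the training loss adds a reconstruction error on the values visible to the model to an imputation error on the artificially masked values; please confirm the exact formulation against the full paper. The function names, the weights `ort_weight` and `mit_weight`, and the toy arrays are hypothetical and are not taken from this repository.

```python
import numpy as np


def masked_mae(pred, target, mask):
    """Mean absolute error computed only where mask == 1."""
    return float(np.sum(np.abs(pred - target) * mask) / (np.sum(mask) + 1e-9))


def joint_loss(reconstruction, X_intact, missing_mask, indicating_mask,
               ort_weight=1.0, mit_weight=1.0):
    """Weighted sum of two masked errors (an assumed form of the joint objective).

    - reconstruction term: error on values the model can see (missing_mask == 1)
    - imputation term: error on values hidden on purpose (indicating_mask == 1)
    All inputs are assumed to be NaN-free arrays of the same shape.
    """
    reconstruction_err = masked_mae(reconstruction, X_intact, missing_mask)
    imputation_err = masked_mae(reconstruction, X_intact, indicating_mask)
    return ort_weight * reconstruction_err + mit_weight * imputation_err


# Toy example: one value ([0, 1]) is artificially hidden from the model.
X_intact = np.array([[1.0, 2.0], [3.0, 4.0]])
missing_mask = np.array([[1.0, 0.0], [1.0, 1.0]])
indicating_mask = np.array([[0.0, 1.0], [0.0, 0.0]])
reconstruction = np.array([[1.5, 2.0], [3.0, 5.0]])

loss = joint_loss(reconstruction, X_intact, missing_mask, indicating_mask)
```

In this toy case the hidden value is recovered exactly, so the whole loss comes from the reconstruction term on the three visible values.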

❖ Citing SAITS

If you find SAITS helpful to your work, please cite our paper as shown below, ⭐️star this repository, and recommend it to others who you think may need it. 🤗 Thank you!

@article{DU2023SAITS,
title = {{SAITS: Self-Attention-based Imputation for Time Series}},
journal = {Expert Systems with Applications},
volume = {219},
pages = {119619},
year = {2023},
issn = {0957-4174},
doi = {https://doi.org/10.1016/j.eswa.2023.119619},
url = {https://www.sciencedirect.com/science/article/pii/S0957417423001203},
author = {Wenjie Du and David Cote and Yan Liu},
}

or

Wenjie Du, David Cote, and Yan Liu. SAITS: Self-Attention-based Imputation for Time Series. Expert Systems with Applications, 219:119619, 2023.

❖ Repository Structure

The implementation of SAITS is in the modeling directory. Model configurations are in configs, and dataset links and preprocessing scripts are in dataset_generating_scripts. The NNI_tuning directory contains the hyper-parameter search configurations.

❖ Development Environment

All dependencies of our development environment are listed in the file conda_env_dependencies.yml. You can quickly create a usable Python environment with the conda command conda env create -f conda_env_dependencies.yml.

❖ Datasets

To download and generate the datasets, please check out the scripts in the dataset_generating_scripts directory.

❖ Quick Run

Generate the dataset you need first. To do so, please check out the generating scripts in dir dataset_generating_scripts.

After generating the data, train and test your model. For example:

# create a dir to save logs and results
mkdir NIPS_results

# train a model
nohup python run_models.py \
    --config_path configs/PhysioNet2012_SAITS_best.ini \
    > NIPS_results/PhysioNet2012_SAITS_best.out &

# during training, you can run the command below to read the training log
less NIPS_results/PhysioNet2012_SAITS_best.out

# after training, pick the best model and modify the path of the model for testing in the config file, then run the below command to test the model
python run_models.py \
    --config_path configs/PhysioNet2012_SAITS_best.ini \
    --test_mode

❗️Note that dataset paths and saving directories may differ on your machine, so please check them in the configuration files.
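
One way to check those paths before launching a run is sketched below with Python's standard `configparser`. The helper name `find_missing_paths` is hypothetical, and the rule that a path-like option has "path" or "dir" in its key name is an assumption made for this sketch; adjust it to the keys actually used in your .ini file.

```python
import configparser
import os


def find_missing_paths(config_path):
    """Return (section, key, value) for path-like options that do not exist on disk.

    Hypothetical helper: an option is treated as path-like when its key name
    contains "path" or "dir" (an assumption, not a rule of the repository).
    """
    # interpolation=None keeps values with '%' characters from raising errors.
    parser = configparser.ConfigParser(interpolation=None)
    if not parser.read(config_path):
        raise FileNotFoundError(config_path)

    missing = []
    for section in parser.sections():
        for key, value in parser.items(section):
            path_like = "path" in key or "dir" in key
            if path_like and value and not os.path.exists(value):
                missing.append((section, key, value))
    return missing
```

Calling `find_missing_paths("configs/PhysioNet2012_SAITS_best.ini")` from the repository root returns an empty list when every path-like option points to an existing file or directory. Note that a directory meant to hold results is also reported if it has not been created yet.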

❖ Acknowledgments

Thanks to Ciena, Mitacs, and NSERC (Natural Sciences and Engineering Research Council of Canada) for funding support. Thanks to Ciena for providing computing resources. Thanks to all our reviewers for helping improve the quality of this paper. And thank you all for your attention to this work.

✨Stars/forks/issues/PRs are all welcome!

❖ Last but Not Least

If you have any additional questions or are interested in collaborating, please take a look at my GitHub profile and feel free to contact me 😃.

More Repositories

1. PyPOTS (Python, 683 stars): A Python toolbox/library for reality-centric machine/deep learning and data mining on partially-observed time series with PyTorch, including SOTA neural network models for science analysis tasks of imputation, classification, clustering, forecasting & anomaly detection on incomplete (irregularly-sampled) multivariate TS with NaN missing values

2. TSDB (Python, 114 stars): Time Series Data Beans: a Python toolbox that loads 169 public time-series datasets for machine learning/deep learning with a single line of code.

3. Awesome_Imputation (Python, 64 stars): Awesome Deep Learning Resources for Time-Series Imputation, including a must-read paper list about using deep learning neural networks to impute incomplete time series containing NaN missing values/data

4. BrewPOTS (Jupyter Notebook, 40 stars): The tutorials for PyPOTS.

5. PyGrinder (Python, 20 stars): PyGrinder grinds data beans into the incomplete by introducing missing values with different missing patterns.

6. Google_Scholar_Badge_Generator (Python, 9 stars): This repository helps you automatically generate citation badges of articles/profiles on Google Scholar. With GitHub Actions, you can make yourself a Google Scholar version of shields.io

7. clickLikeInQzone (Python, 6 stars): A crawler built with Python and Selenium that automatically likes and replies to posts in QQ Qzone.

8. WenjieDu (6 stars)

9. eye_game (Python, 5 stars): A Python module for parsing human gaze direction

10. DevNet (Python, 3 stars): An implementation of Deviation Network with a case on the credit card fraud dataset.

11. Spider_on_GitHub_Star_Fork (Python, 3 stars): A spider that crawls user information of stargazers and forkers of given repositories, then saves the information into a .csv file with pandas.

12. PropBag (Python, 3 stars): Prop Pocket: a place for some fun little demos and odds and ends 😁. Feel free to look around.

13. MuLePOTS (2 stars)

14. WeChatAutoReply (Python, 2 stars): A WeChat auto-reply script implemented with itchat.