
📃 Unofficial PyTorch Implementation of DA-RNN (arXiv:1704.02971)

PyTorch Implementation of DA-RNN


Get hands-on experience implementing RNNs (LSTMs) in PyTorch;
Get familiar with financial data in deep learning;

Stargazers over time

Table of Contents

- Dataset
- Usage
- Result
- DA-RNN
- References

Dataset

Download

NASDAQ 100 stock data

Description

This dataset is a subset of the full NASDAQ 100 stock dataset used in [1]. It covers 105 days of stock data, from July 26, 2016 to December 22, 2016. Each day contains 390 data points, except for November 25 (210 data points) and December 22 (180 data points).

Some corporations in the NASDAQ 100 are not included in this dataset because they have too much missing data. In total, 81 major corporations remain, and the missing values are filled in with linear interpolation.

In [1], the first 35,100 data points are used as the training set, the following 2,730 as the validation set, and the last 2,730 as the test set.
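
For concreteness, here is a minimal sketch of this split. The file name `nasdaq100_padding.csv` and the column layout (81 stock columns followed by the index column) are assumptions for illustration, not necessarily this repository's exact layout:

```python
# Sketch of the train/val/test split described above.
# Assumes a CSV whose last column is the NASDAQ 100 index (target)
# and whose other 81 columns are the stock prices (driving series).
import pandas as pd

df = pd.read_csv("nasdaq100_padding.csv")  # hypothetical file name
X = df.iloc[:, :-1].values  # driving series (81 stocks)
y = df.iloc[:, -1].values   # target series (NASDAQ 100 index)

# Split sizes used in [1]: 35,100 train / 2,730 val / 2,730 test.
train_end = 35100
val_end = train_end + 2730

X_train, y_train = X[:train_end], y[:train_end]
X_val, y_val = X[train_end:val_end], y[train_end:val_end]
X_test, y_test = X[val_end:], y[val_end:]
```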

Usage

Train

usage: main.py [-h] [--dataroot DATAROOT] [--batchsize BATCHSIZE]
               [--nhidden_encoder NHIDDEN_ENCODER]
               [--nhidden_decoder NHIDDEN_DECODER] [--ntimestep NTIMESTEP]
               [--epochs EPOCHS] [--lr LR]

PyTorch implementation of paper 'A Dual-Stage Attention-Based Recurrent Neural
Network for Time Series Prediction'

optional arguments:
  -h, --help            show this help message and exit
  --dataroot DATAROOT   path to dataset
  --batchsize BATCHSIZE
                        input batch size [128]
  --nhidden_encoder NHIDDEN_ENCODER
                        size of hidden states for the encoder m [64, 128]
  --nhidden_decoder NHIDDEN_DECODER
                        size of hidden states for the decoder p [64, 128]
  --ntimestep NTIMESTEP
                        the number of time steps in the window T [10]
  --epochs EPOCHS       number of epochs to train [10, 200, 500]
  --lr LR               learning rate [0.001] reduced by 0.1 after each 10000
                        iterations

An example training run:

python3 main.py --lr 0.0001 --epochs 50

Result

Training process

[Figure: training loss]

Prediction

DA-RNN

In the paper "A Dual-Stage Attention-Based Recurrent Neural Network for Time Series Prediction" [1], the authors propose a novel dual-stage attention-based recurrent neural network (DA-RNN) for time series prediction. In the first stage, an input attention mechanism adaptively extracts relevant driving series (a.k.a. input features) at each time step by referring to the previous encoder hidden state. In the second stage, a temporal attention mechanism selects relevant encoder hidden states across all time steps.

For the objective, a square loss is used. With these two attention mechanisms, the DA-RNN can adaptively select the most relevant input features and capture the long-term temporal dependencies of a time series. A graphical illustration of the proposed model is shown in Figure 1.

Figure 1: Graphical illustration of the dual-stage attention-based recurrent neural network.

The Dual-Stage Attention-Based RNN (a.k.a. DA-RNN) model belongs to the general class of Nonlinear Autoregressive Exogenous (NARX) models, which predict the current value of a time series based on historical values of this series plus the historical values of multiple exogenous time series.
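
Concretely, given $n$ driving series $\mathbf{x}_t = (x_t^1, \ldots, x_t^n)^\top$ and the target history $y_1, \ldots, y_{T-1}$, the paper formulates the task as learning a nonlinear mapping $F$ to the current target value:

$$\hat{y}_T = F(y_1, \ldots, y_{T-1}, \mathbf{x}_1, \ldots, \mathbf{x}_T)$$

DA-RNN parameterizes $F$ with the attention-based encoder-decoder described below.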

LSTM

A recurrent neural network (RNN) model is used in this paper. RNNs are powerful at modeling sophisticated dynamic temporal structure in sequential data. They come in many forms, one of which is the Long Short-Term Memory (LSTM) model widely applied in language modeling.
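
As a quick illustration (and the point behind reference [4] below), `torch.nn.LSTM` consumes a 3-D input tensor. A minimal sketch, with dimensions chosen to mirror this dataset rather than taken from this repository:

```python
# nn.LSTM expects input of shape (batch, seq_len, input_size) when
# batch_first=True. Dimensions here mirror the dataset: 81 driving
# series, window T = 10; hidden size 64 is one of the paper's settings.
import torch
import torch.nn as nn

batch_size, T, n_features, hidden_size = 128, 10, 81, 64
lstm = nn.LSTM(input_size=n_features, hidden_size=hidden_size, batch_first=True)

x = torch.randn(batch_size, T, n_features)
output, (h_n, c_n) = lstm(x)
print(output.shape)  # torch.Size([128, 10, 64]) - hidden state at every step
print(h_n.shape)     # torch.Size([1, 128, 64]) - final hidden state
```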

Attention Mechanism

As the paper notes, the attention mechanism performs a form of feature selection: the model keeps only the most useful information at each temporal stage.

Model

The DA-RNN model consists of two LSTM networks with attention mechanisms: an encoder and a decoder.

In the encoder, the authors introduce a novel input attention mechanism that adaptively selects the relevant driving series. In the decoder, a temporal attention mechanism automatically selects relevant encoder hidden states across all time steps.
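
A minimal sketch of the first stage, the encoder's input attention, following Eq. (8) of [1]. The class and variable names are illustrative, not this repository's exact implementation:

```python
# Input attention over n driving series: score each series k against the
# previous encoder state, then softmax-normalize into weights alpha_t^k.
import torch
import torch.nn as nn
import torch.nn.functional as F

class InputAttention(nn.Module):
    def __init__(self, T: int, hidden_size: int):
        super().__init__()
        # e_t^k = v_e^T tanh(W_e [h_{t-1}; s_{t-1}] + U_e x^k)
        self.W_e = nn.Linear(2 * hidden_size, T, bias=False)
        self.U_e = nn.Linear(T, T, bias=False)
        self.v_e = nn.Linear(T, 1, bias=False)

    def forward(self, x, h_prev, s_prev):
        # x:      (batch, T, n) - the n driving series over the window
        # h_prev: (batch, m)    - previous encoder LSTM hidden state
        # s_prev: (batch, m)    - previous encoder LSTM cell state
        query = self.W_e(torch.cat([h_prev, s_prev], dim=1))  # (batch, T)
        keys = self.U_e(x.transpose(1, 2))                    # (batch, n, T)
        e = self.v_e(torch.tanh(query.unsqueeze(1) + keys))   # (batch, n, 1)
        return F.softmax(e.squeeze(-1), dim=1)                # (batch, n)

# Example: weights over 81 driving series with T = 10 and m = 64.
attn = InputAttention(T=10, hidden_size=64)
alpha = attn(torch.randn(128, 10, 81), torch.randn(128, 64), torch.randn(128, 64))
x_tilde_t = alpha * torch.randn(128, 81)  # attended input for the LSTM at step t
```

The decoder's temporal attention is built the same way, except that the scores are computed over the encoder hidden states $h_1, \ldots, h_T$ instead of over the driving series.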

Experiments and Parameters Settings

NASDAQ 100 Stock dataset

In the NASDAQ 100 Stock dataset, we collected the stock prices of 81 major corporations under NASDAQ 100, which are used as the driving time series. The index value of the NASDAQ 100 is used as the target series. The frequency of the data collection is minute-by-minute. This data covers the period from July 26, 2016 to December 22, 2016, 105 days in total. Each day contains 390 data points from the opening to closing of the market except that there are 210 data points on November 25 and 180 data points on December 22. In our experiments, we use the first 35,100 data points as the training set and the following 2,730 data points as the validation set. The last 2,730 data points are used as the test set. This dataset is publicly available and will be continuously enlarged to aid the research in this direction.

Training procedure & Parameters Settings

| Category | Description |
| -------- | ----------- |
| Optimization method | Minibatch stochastic gradient descent (SGD) with the Adam optimizer |
| Number of time steps in the window $T$ | $T = 10$ |
| Size of hidden states for the encoder $m$ | $m = p = 64, 128$ |
| Size of hidden states for the decoder $p$ | $m = p = 64, 128$ |
| Objective (square loss) | $O(y_T, \hat{y}_T) = \frac{1}{N} \sum\limits_{i=1}^{N} \left( y_T^i - \hat{y}_T^i \right)^2$ |
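
A minimal sketch of this optimization setup. The model here is a stand-in, and "reduced by 0.1 after each 10000 iterations" from the `--lr` help text is interpreted as a multiplicative factor of 0.1 (a 10% reduction would use `gamma=0.9`):

```python
# Adam with a step learning-rate schedule and the square loss objective.
import torch
import torch.nn as nn
import torch.nn.functional as F

model = nn.Linear(81, 1)  # stand-in for the DA-RNN encoder/decoder pair
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10000, gamma=0.1)

for iteration in range(30000):
    x = torch.randn(128, 81)  # stand-in minibatch (batch size 128)
    y = torch.randn(128, 1)
    optimizer.zero_grad()
    loss = F.mse_loss(model(x), y)  # square loss O(y_T, y_hat_T)
    loss.backward()
    optimizer.step()
    scheduler.step()  # one scheduler step per iteration (minibatch)
```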

References

[1] Yao Qin, Dongjin Song, Haifeng Chen, Wei Cheng, Guofei Jiang, Garrison W. Cottrell. "A Dual-Stage Attention-Based Recurrent Neural Network for Time Series Prediction". arXiv preprint arXiv:1704.02971 (2017).
[2] Chandler Zuo. "A PyTorch Example to Use RNN for Financial Prediction". (2017).
[3] YitongCU. "Dual Staged Attention Model for Time Series prediction".
[4] PyTorch Forums. "Why 3d input tensors in LSTM?".
