  • Language: Java
  • License: MIT License
  • Created over 7 years ago
  • Updated over 1 year ago


java-reinforcement-learning

This package provides a Java implementation of reinforcement learning algorithms as described in the book "Reinforcement Learning: An Introduction" by Sutton and Barto.


Features

The following reinforcement learning algorithms are implemented:

  • R-Learn
  • Q-Learn
  • Q-Learn with eligibility trace
  • SARSA
  • SARSA with eligibility trace
  • Actor-Critic
  • Actor-Critic with eligibility trace

The package also supports a number of action-selection strategies:

  • soft-max
  • epsilon-greedy
  • greedy
  • Gibbs-soft-max
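The strategies listed above differ in how they turn the Q values of the current state into a choice. As a rough illustration, here is a minimal, library-independent sketch of greedy, epsilon-greedy and soft-max (Boltzmann) selection over a plain array of Q values; the class ActionSelectionSketch and the epsilon and temperature (tau) parameters are illustrative assumptions, not the package's API:

```java
import java.util.Random;

// Illustrative sketch only; this is not the package's implementation.
final class ActionSelectionSketch {

    // Greedy: index of the largest Q value (the lowest index wins ties).
    static int greedy(double[] q) {
        int best = 0;
        for (int a = 1; a < q.length; a++) {
            if (q[a] > q[best]) best = a;
        }
        return best;
    }

    // Epsilon-greedy: explore uniformly with probability epsilon, otherwise act greedily.
    static int epsilonGreedy(double[] q, double epsilon, Random rng) {
        if (rng.nextDouble() < epsilon) return rng.nextInt(q.length);
        return greedy(q);
    }

    // Soft-max probabilities with temperature tau: p(a) = exp(q[a]/tau) / sum_b exp(q[b]/tau).
    // The maximum is subtracted first for numerical stability; this does not change p.
    static double[] softMaxProbabilities(double[] q, double tau) {
        double max = q[greedy(q)];
        double[] p = new double[q.length];
        double sum = 0.0;
        for (int a = 0; a < q.length; a++) {
            p[a] = Math.exp((q[a] - max) / tau);
            sum += p[a];
        }
        for (int a = 0; a < q.length; a++) p[a] /= sum;
        return p;
    }

    // Soft-max selection: sample an action index from the soft-max distribution.
    static int softMax(double[] q, double tau, Random rng) {
        double[] p = softMaxProbabilities(q, tau);
        double r = rng.nextDouble();
        double cumulative = 0.0;
        for (int a = 0; a < p.length; a++) {
            cumulative += p[a];
            if (r < cumulative) return a;
        }
        return p.length - 1; // guard against rounding error in the cumulative sum
    }
}
```

With epsilon = 0 the epsilon-greedy rule reduces to greedy selection, and a large tau pushes soft-max toward a uniform random choice.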


Install

Add the following dependency to your POM file:

<dependency>
  <groupId>com.github.chen0040</groupId>
  <artifactId>java-reinforcement-learning</artifactId>
  <version>1.0.5</version>
</dependency>

Application Samples

Application samples for this library can be found in the following repositories:

Usage

Create Agent

A reinforcement learning agent, for example a Q-Learn agent, can be created with the following Java code:

import com.github.chen0040.rl.learning.qlearn.QAgent;

int stateCount = 100;
int actionCount = 10;
QAgent agent = new QAgent(stateCount, actionCount);

The agent created above has 100 states and 10 different actions to choose from.

For Q-Learn and SARSA, eligibility traces with decay parameter lambda can be enabled by calling:

agent.enableEligibilityTrace(lambda); // lambda is the trace-decay parameter, between 0 and 1
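To illustrate what lambda controls, here is a minimal tabular SARSA(lambda) step with accumulating traces, written from the textbook description rather than from this package's source; the class SarsaLambdaSketch, the step size alpha and the discount gamma are assumptions made for the example:

```java
// Illustrative tabular SARSA(lambda) with accumulating traces; not the package's code.
final class SarsaLambdaSketch {
    final double[][] q;
    final double[][] e; // eligibility traces, same shape as q
    final double alpha, gamma, lambda;

    SarsaLambdaSketch(int stateCount, int actionCount, double alpha, double gamma, double lambda) {
        this.q = new double[stateCount][actionCount];
        this.e = new double[stateCount][actionCount];
        this.alpha = alpha;
        this.gamma = gamma;
        this.lambda = lambda;
    }

    // One step: action a was taken in state s, reward r was received, the agent landed in s2
    // and will take action a2 there next.
    void update(int s, int a, double r, int s2, int a2) {
        double delta = r + gamma * q[s2][a2] - q[s][a]; // one-step TD error
        e[s][a] += 1.0;                                  // accumulating trace for the visited pair
        for (int i = 0; i < q.length; i++) {
            for (int j = 0; j < q[i].length; j++) {
                q[i][j] += alpha * delta * e[i][j]; // recently visited pairs share the TD error
                e[i][j] *= gamma * lambda;          // traces decay geometrically
            }
        }
    }
}
```

With lambda = 0 only the most recent state-action pair is updated, which is plain one-step SARSA; larger values spread each TD error further back along the trajectory. Watkins's Q(lambda) additionally resets all traces to zero after an exploratory (non-greedy) action.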

Select Action

At each time step, the agent selects an action by calling:

int actionId = agent.selectAction().getIndex();

If the problem restricts which actions are available in each state, pass the set of actions allowed in the current state:

// `world` is your own environment object
Set<Integer> actionsAvailableAtCurrentState = world.getActionsAvailable(agent);
int actionTaken = agent.selectAction(actionsAvailableAtCurrentState).getIndex();
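Restricting the choice amounts to taking the best action over the allowed subset only. A minimal greedy sketch of that idea, independent of this package (greedyAmong is a hypothetical helper):

```java
import java.util.Set;

// Illustrative sketch; not the package's code.
final class RestrictedSelectionSketch {
    // Greedy choice among the allowed actions only; returns -1 when no valid action is allowed.
    // Ties are broken by the set's iteration order.
    static int greedyAmong(double[] q, Set<Integer> allowed) {
        int best = -1;
        for (int a : allowed) {
            if (a < 0 || a >= q.length) continue; // ignore ids outside the Q row
            if (best == -1 || q[a] > q[best]) best = a;
        }
        return best;
    }
}
```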

The agent can also switch to a different action-selection policy from the com.github.chen0040.rl.actionselection package. For example, the following code switches the action-selection policy to soft-max:

import com.github.chen0040.rl.actionselection.SoftMaxActionSelectionStrategy;

agent.getLearner().setActionSelection(SoftMaxActionSelectionStrategy.class.getCanonicalName());

State-Action Update

Once the world state has been updated in response to the selected action, the agent's internal state-action Q matrix is updated by calling:

int newStateId = world.update(agent, actionTaken);
double reward = world.reward(agent);

agent.update(actionTaken, newStateId, reward);
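For Q-Learn, the textbook one-step update of the Q matrix after a transition (s, a, r, s') is Q(s,a) <- Q(s,a) + alpha [r + gamma max_a' Q(s',a') - Q(s,a)]. A standalone sketch of that rule follows; the step size alpha and discount gamma are illustrative parameters, and this package may apply them differently:

```java
// Illustrative one-step tabular Q-learning update; not the package's code.
final class QLearningSketch {
    // Q(s,a) <- Q(s,a) + alpha * (r + gamma * max_a' Q(s',a') - Q(s,a))
    static void update(double[][] q, int s, int a, int s2, double r, double alpha, double gamma) {
        double maxNext = Double.NEGATIVE_INFINITY;
        for (double v : q[s2]) maxNext = Math.max(maxNext, v); // value of the best next action
        q[s][a] += alpha * (r + gamma * maxNext - q[s][a]);
    }
}
```

For example, with alpha = 0.5, gamma = 0.9, r = 1 and a best next value of 4, an entry starting at 0 moves to 0.5 × (1 + 0.9 × 4) = 2.3.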

Sample code

Sample code for R-Learn

import com.github.chen0040.rl.learning.rlearn.RAgent;

import java.util.Random;

int stateCount = 100;
int actionCount = 10;
RAgent agent = new RAgent(stateCount, actionCount);

Random random = new Random();
agent.start(random.nextInt(stateCount)); // start in a random state
for (int time = 0; time < 1000; ++time) {

    int actionId = agent.selectAction().getIndex();
    System.out.println("Agent does action-" + actionId);

    // `world` is your own environment: it applies the action and reports the new state and the reward
    int newStateId = world.update(agent, actionId);
    double reward = world.reward(agent);

    System.out.println("Now the new state is " + newStateId);
    System.out.println("Agent receives Reward = " + reward);

    agent.update(actionId, newStateId, reward);
}
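R-Learn targets average-reward (undiscounted, continuing) problems. Below is a minimal sketch of the tabular R-learning update following the formulation in Sutton and Barto's book (first edition): the relative values R(s,a) are updated with step size alpha, and the average-reward estimate rho is updated with step size beta only when the action taken is greedy. The class RLearningSketch and both step sizes are illustrative, not this package's implementation:

```java
// Illustrative tabular R-learning (average-reward) update; not the package's code.
final class RLearningSketch {
    final double[][] values; // relative action values R(s,a)
    double rho = 0.0;        // running estimate of the average reward per step
    final double alpha, beta;

    RLearningSketch(int stateCount, int actionCount, double alpha, double beta) {
        this.values = new double[stateCount][actionCount];
        this.alpha = alpha;
        this.beta = beta;
    }

    static double max(double[] row) {
        double m = Double.NEGATIVE_INFINITY;
        for (double v : row) m = Math.max(m, v);
        return m;
    }

    void update(int s, int a, double reward, int s2) {
        double maxNext = max(values[s2]);
        values[s][a] += alpha * (reward - rho + maxNext - values[s][a]);
        // rho moves only when the action taken is (now) the greedy one in s
        if (values[s][a] == max(values[s])) {
            rho += beta * (reward - rho + maxNext - max(values[s]));
        }
    }
}
```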

Alternatively, use RLearner if you want to learn after the episode ends:

import com.github.chen0040.rl.learning.rlearn.RLearner;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;

// A recorded transition: `action` was taken in `oldState`, leading to `newState` with `reward`
class Move {
    int oldState;
    int newState;
    int action;
    double reward;

    public Move(int oldState, int action, int newState, double reward) {
        this.oldState = oldState;
        this.newState = newState;
        this.reward = reward;
        this.action = action;
    }
}

int stateCount = 100;
int actionCount = 10;
RLearner agent = new RLearner(stateCount, actionCount);

Random random = new Random();
int currentState = random.nextInt(stateCount);
List<Move> moves = new ArrayList<>();
for (int time = 0; time < 1000; ++time) {

    int actionId = agent.selectAction(currentState).getIndex();
    System.out.println("Agent does action-" + actionId);

    // `world` is your own environment object
    int newStateId = world.update(agent, actionId);
    double reward = world.reward(agent);

    System.out.println("Now the new state is " + newStateId);
    System.out.println("Agent receives Reward = " + reward);

    moves.add(new Move(currentState, actionId, newStateId, reward));
    currentState = newStateId;
}

// Learn from the episode, replaying the moves from last to first
for (int i = moves.size() - 1; i >= 0; --i) {
    Move move = moves.get(i);
    agent.update(move.oldState, move.action, move.newState, world.getActionsAvailableAtState(move.newState), move.reward);
}

Sample code for Q-Learn

import com.github.chen0040.rl.learning.qlearn.QAgent;

import java.util.Random;

int stateCount = 100;
int actionCount = 10;
QAgent agent = new QAgent(stateCount, actionCount);

Random random = new Random();
agent.start(random.nextInt(stateCount)); // start in a random state
for (int time = 0; time < 1000; ++time) {

    int actionId = agent.selectAction().getIndex();
    System.out.println("Agent does action-" + actionId);

    // `world` is your own environment: it applies the action and reports the new state and the reward
    int newStateId = world.update(agent, actionId);
    double reward = world.reward(agent);

    System.out.println("Now the new state is " + newStateId);
    System.out.println("Agent receives Reward = " + reward);

    agent.update(actionId, newStateId, reward);
}
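The samples call a `world` object that this package does not provide; it stands for your own environment. A hypothetical minimal environment with the same two operations is sketched below: a chain of states where action 0 moves left, action 1 moves right, and reaching the last state pays a reward of 1 and restarts at state 0. The class ChainWorld is an assumption for illustration, and it drops the agent argument because it serves a single agent:

```java
// Hypothetical toy environment for the samples; not part of the package.
final class ChainWorld {
    private final int stateCount;
    private int state = 0;
    private double lastReward = 0.0;

    ChainWorld(int stateCount) {
        this.stateCount = stateCount;
    }

    int currentState() {
        return state;
    }

    // Applies the action (0 = left, 1 = right) and returns the new state id.
    int update(int actionId) {
        int next = actionId == 1 ? Math.min(state + 1, stateCount - 1) : Math.max(state - 1, 0);
        boolean reachedGoal = next == stateCount - 1;
        lastReward = reachedGoal ? 1.0 : 0.0;
        state = reachedGoal ? 0 : next; // restart at state 0 after reaching the goal
        return state;
    }

    // Reward produced by the most recent update call.
    double reward() {
        return lastReward;
    }
}
```

Inside the Q-Learn loop above, the two world calls would then read `int newStateId = world.update(actionId);` and `double reward = world.reward();`, with actionCount = 2 and stateCount equal to the chain length.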

Alternatively, use QLearner if you want to learn after the episode ends:

import com.github.chen0040.rl.learning.qlearn.QLearner;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;

// A recorded transition: `action` was taken in `oldState`, leading to `newState` with `reward`
class Move {
    int oldState;
    int newState;
    int action;
    double reward;

    public Move(int oldState, int action, int newState, double reward) {
        this.oldState = oldState;
        this.newState = newState;
        this.reward = reward;
        this.action = action;
    }
}

int stateCount = 100;
int actionCount = 10;
QLearner agent = new QLearner(stateCount, actionCount);

Random random = new Random();
int currentState = random.nextInt(stateCount);
List<Move> moves = new ArrayList<>();
for (int time = 0; time < 1000; ++time) {

    int actionId = agent.selectAction(currentState).getIndex();
    System.out.println("Agent does action-" + actionId);

    // `world` is your own environment object
    int newStateId = world.update(agent, actionId);
    double reward = world.reward(agent);

    System.out.println("Now the new state is " + newStateId);
    System.out.println("Agent receives Reward = " + reward);

    moves.add(new Move(currentState, actionId, newStateId, reward));
    currentState = newStateId;
}

// Learn from the episode, replaying the moves from last to first
for (int i = moves.size() - 1; i >= 0; --i) {
    Move move = moves.get(i);
    agent.update(move.oldState, move.action, move.newState, move.reward);
}
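The after-episode samples replay the recorded moves from last to first. A plausible reason, shown below with a standalone tabular Q-learning update (alpha = 1 and gamma = 0.9 are illustrative), is that backward replay carries a reward received at the end of an episode all the way to the first state in a single pass, while forward replay moves it back by only one step per pass. ReplayOrderSketch and its three-step episode are hypothetical:

```java
import java.util.ArrayList;
import java.util.List;

// Illustrative comparison of replay order; not the package's code.
final class ReplayOrderSketch {
    // A recorded transition, mirroring the Move class in the samples.
    static final class Move {
        final int oldState, action, newState;
        final double reward;

        Move(int oldState, int action, int newState, double reward) {
            this.oldState = oldState;
            this.action = action;
            this.newState = newState;
            this.reward = reward;
        }
    }

    static void qUpdate(double[][] q, Move m, double alpha, double gamma) {
        double maxNext = Double.NEGATIVE_INFINITY;
        for (double v : q[m.newState]) maxNext = Math.max(maxNext, v);
        q[m.oldState][m.action] += alpha * (m.reward + gamma * maxNext - q[m.oldState][m.action]);
    }

    // A 4-state chain walked left to right with a single action; only the last step is rewarded.
    static List<Move> chainEpisode() {
        List<Move> moves = new ArrayList<>();
        moves.add(new Move(0, 0, 1, 0.0));
        moves.add(new Move(1, 0, 2, 0.0));
        moves.add(new Move(2, 0, 3, 1.0));
        return moves;
    }

    // One pass over the episode, either last-to-first (backward) or first-to-last.
    static double[][] replay(List<Move> moves, boolean backward) {
        double[][] q = new double[4][1];
        for (int k = 0; k < moves.size(); k++) {
            int i = backward ? moves.size() - 1 - k : k;
            qUpdate(q, moves.get(i), 1.0, 0.9);
        }
        return q;
    }
}
```

After one backward pass the first state is already worth 0.9 × 0.9 = 0.81, whereas after one forward pass it is still 0.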

Sample code for SARSA

import com.github.chen0040.rl.learning.sarsa.SarsaAgent;

import java.util.Random;

int stateCount = 100;
int actionCount = 10;
SarsaAgent agent = new SarsaAgent(stateCount, actionCount);

Random random = new Random();
agent.start(random.nextInt(stateCount)); // start in a random state
for (int time = 0; time < 1000; ++time) {

    int actionId = agent.selectAction().getIndex();
    System.out.println("Agent does action-" + actionId);

    // `world` is your own environment: it applies the action and reports the new state and the reward
    int newStateId = world.update(agent, actionId);
    double reward = world.reward(agent);

    System.out.println("Now the new state is " + newStateId);
    System.out.println("Agent receives Reward = " + reward);

    agent.update(actionId, newStateId, reward);
}
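SARSA differs from Q-Learn only in the target it bootstraps from: the value of the action actually chosen in the next state rather than the best one, which is why the after-episode SarsaLearner update also takes the next action. A standalone sketch of the two textbook targets (TdTargetsSketch is a hypothetical name; gamma and alpha are illustrative):

```java
// Illustrative one-step SARSA and Q-learning targets; not the package's code.
final class TdTargetsSketch {
    // SARSA (on-policy): bootstrap from the action actually chosen in the next state.
    static double sarsaTarget(double[][] q, double r, int s2, int a2, double gamma) {
        return r + gamma * q[s2][a2];
    }

    // Q-learning (off-policy): bootstrap from the best action in the next state.
    static double qLearningTarget(double[][] q, double r, int s2, double gamma) {
        double maxNext = Double.NEGATIVE_INFINITY;
        for (double v : q[s2]) maxNext = Math.max(maxNext, v);
        return r + gamma * maxNext;
    }

    // Both methods then move Q(s,a) toward their target by a fraction alpha.
    static void moveToward(double[][] q, int s, int a, double target, double alpha) {
        q[s][a] += alpha * (target - q[s][a]);
    }
}
```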

Alternatively, use SarsaLearner if you want to learn after the episode ends:

import com.github.chen0040.rl.learning.sarsa.SarsaLearner;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;

// A recorded transition: `action` was taken in `oldState`, leading to `newState` with `reward`
class Move {
    int oldState;
    int newState;
    int action;
    double reward;

    public Move(int oldState, int action, int newState, double reward) {
        this.oldState = oldState;
        this.newState = newState;
        this.reward = reward;
        this.action = action;
    }
}

int stateCount = 100;
int actionCount = 10;
SarsaLearner agent = new SarsaLearner(stateCount, actionCount);

Random random = new Random();
int currentState = random.nextInt(stateCount);
List<Move> moves = new ArrayList<>();
for (int time = 0; time < 1000; ++time) {

    int actionId = agent.selectAction(currentState).getIndex();
    System.out.println("Agent does action-" + actionId);

    // `world` is your own environment object
    int newStateId = world.update(agent, actionId);
    double reward = world.reward(agent);

    System.out.println("Now the new state is " + newStateId);
    System.out.println("Agent receives Reward = " + reward);

    moves.add(new Move(currentState, actionId, newStateId, reward));
    currentState = newStateId;
}

// Learn from the episode, replaying the moves from last to first
for (int i = moves.size() - 1; i >= 0; --i) {
    Move currentMove = moves.get(i);
    // SARSA also needs the action taken in the next state; the last move falls back to its own action
    Move nextMove = (i != moves.size() - 1) ? moves.get(i + 1) : currentMove;
    agent.update(currentMove.oldState, currentMove.action, currentMove.newState, nextMove.action, currentMove.reward);
}

Sample code for Actor Critic Model

import com.github.chen0040.rl.learning.actorcritic.ActorCriticAgent;
import com.github.chen0040.rl.utils.Vec;

import java.util.Random;

int stateCount = 100;
int actionCount = 10;
ActorCriticAgent agent = new ActorCriticAgent(stateCount, actionCount);
Vec stateValues = new Vec(stateCount);

Random random = new Random();
agent.start(random.nextInt(stateCount)); // start in a random state
for (int time = 0; time < 1000; ++time) {

    int actionId = agent.selectAction().getIndex();
    System.out.println("Agent does action-" + actionId);

    // `world` is your own environment: it applies the action and reports the new state and the reward
    int newStateId = world.update(agent, actionId);
    double reward = world.reward(agent);

    System.out.println("Now the new state is " + newStateId);
    System.out.println("Agent receives Reward = " + reward);

    // The state values used by the critic are supplied by the caller;
    // random values stand in for real estimates in this sample
    System.out.println("World state values changed ...");
    for (int stateId = 0; stateId < stateCount; ++stateId) {
        stateValues.set(stateId, random.nextDouble());
    }

    agent.update(actionId, newStateId, reward, stateValues);
}
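In the textbook actor-critic method, the critic computes the TD error delta = r + gamma V(s') - V(s), and both the critic's value V(s) and the actor's preference for the action taken move in proportion to delta. The sample above supplies the state values from outside the agent; the minimal tabular sketch below instead lets its own critic learn V. ActorCriticSketch and the parameters alpha, beta and gamma are illustrative assumptions, not this package's classes:

```java
// Illustrative tabular actor-critic step; not the package's code.
final class ActorCriticSketch {
    final double[] v;      // critic: state values V(s)
    final double[][] pref; // actor: action preferences p(s,a), turned into a soft-max policy
    final double alpha, beta, gamma;

    ActorCriticSketch(int stateCount, int actionCount, double alpha, double beta, double gamma) {
        this.v = new double[stateCount];
        this.pref = new double[stateCount][actionCount];
        this.alpha = alpha;
        this.beta = beta;
        this.gamma = gamma;
    }

    // After taking action a in state s, receiving reward r and reaching state s2.
    void update(int s, int a, double r, int s2) {
        double delta = r + gamma * v[s2] - v[s]; // TD error computed by the critic
        v[s] += alpha * delta;                   // critic moves V(s) toward the TD target
        pref[s][a] += beta * delta;              // actor makes action a more (or less) likely in s
    }

    // Soft-max policy over the preferences of state s.
    double[] policy(int s) {
        double max = Double.NEGATIVE_INFINITY;
        for (double p : pref[s]) max = Math.max(max, p);
        double[] pi = new double[pref[s].length];
        double sum = 0.0;
        for (int a = 0; a < pi.length; a++) {
            pi[a] = Math.exp(pref[s][a] - max);
            sum += pi[a];
        }
        for (int a = 0; a < pi.length; a++) pi[a] /= sum;
        return pi;
    }
}
```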

Alternatively, use ActorCriticLearner if you want to learn after the episode ends:

import com.github.chen0040.rl.learning.actorcritic.ActorCriticLearner;
import com.github.chen0040.rl.utils.Vec;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;

// A recorded transition: `action` was taken in `oldState`, leading to `newState` with `reward`
class Move {
    int oldState;
    int newState;
    int action;
    double reward;

    public Move(int oldState, int action, int newState, double reward) {
        this.oldState = oldState;
        this.newState = newState;
        this.reward = reward;
        this.action = action;
    }
}

int stateCount = 100;
int actionCount = 10;
ActorCriticLearner agent = new ActorCriticLearner(stateCount, actionCount);
Vec stateValues = new Vec(stateCount);

Random random = new Random();
int currentState = random.nextInt(stateCount);
List<Move> moves = new ArrayList<>();
for (int time = 0; time < 1000; ++time) {

    int actionId = agent.selectAction(currentState).getIndex();
    System.out.println("Agent does action-" + actionId);

    // `world` is your own environment object
    int newStateId = world.update(agent, actionId);
    double reward = world.reward(agent);

    System.out.println("Now the new state is " + newStateId);
    System.out.println("Agent receives Reward = " + reward);

    moves.add(new Move(currentState, actionId, newStateId, reward));
    currentState = newStateId;
}

// Random values stand in for the critic's state values, as in the agent sample above
for (int stateId = 0; stateId < stateCount; ++stateId) {
    stateValues.set(stateId, random.nextDouble());
}

// Learn from the episode, replaying the moves from last to first.
// Assumption: like ActorCriticAgent.update, the learner's update takes the state values,
// here as a Function<Integer, Double>; check ActorCriticLearner for the exact parameter list.
for (int i = moves.size() - 1; i >= 0; --i) {
    Move move = moves.get(i);
    agent.update(move.oldState, move.action, move.newState, move.reward, stateValues::get);
}

Save and Load RL models

To save a trained RL model (for example, a QLearner) as JSON:

QLearner learner = new QLearner(stateCount, actionCount);
train(learner); // your own training routine
String json = learner.toJson();

To load the trained RL model from JSON:

QLearner learner = QLearner.fromJson(json);
