Deep Reinforcement Learning Library built on top of Neural Network Libraries

nnablaRL is a deep reinforcement learning library built on top of Neural Network Libraries that is intended to be used for research, development and production.

Installation

Installing nnablaRL is easy!

$ pip install nnabla-rl

nnablaRL only supports Python version >= 3.7 and nnabla version >= 1.17.
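Before installing, you can confirm your interpreter meets the minimum version. This is just an illustrative stdlib check, not part of nnablaRL:

```python
import sys

# Illustrative sanity check: nnablaRL supports Python >= 3.7.
MIN_PYTHON = (3, 7)
if sys.version_info < MIN_PYTHON:
    raise RuntimeError("nnablaRL requires Python >= 3.7")
print("Python version OK:", sys.version.split()[0])
```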

Enabling GPU acceleration (Optional)

nnablaRL algorithms run on the CPU by default. To run an algorithm on a GPU, first install nnabla-ext-cuda as follows, replacing [cuda-version] with the CUDA version installed on your machine.

$ pip install nnabla-ext-cuda[cuda-version]
# Example installation. Supposing CUDA 11.0 is installed on your machine.
$ pip install nnabla-ext-cuda110
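The package suffix is just the CUDA version with the dot removed ("11.0" becomes 110). A tiny, hypothetical helper makes the mapping explicit; `cuda_package_name` is not part of nnablaRL:

```python
def cuda_package_name(cuda_version: str) -> str:
    """Hypothetical helper: derive the nnabla-ext-cuda package name
    from a CUDA version string, e.g. "11.0" -> "nnabla-ext-cuda110"."""
    major, minor = cuda_version.split(".")[:2]
    return f"nnabla-ext-cuda{major}{minor}"

print(cuda_package_name("11.0"))  # -> nnabla-ext-cuda110
```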

After installing nnabla-ext-cuda, set the GPU id to run the algorithm on through the algorithm's configuration.

import nnabla_rl.algorithms as A

config = A.DQNConfig(gpu_id=0) # Use gpu 0. If negative, will run on CPU.
dqn = A.DQN(env, config=config)
...

Features

Friendly API

nnablaRL has a friendly Python API which enables you to start training with only 3 lines of Python code. (NOTE: the code below runs on the CPU. See the instructions above to run on a GPU.)

import nnabla_rl.algorithms as A
from nnabla_rl.utils.reproductions import build_classic_control_env

# Prerequisite: 
# Run below to enable rendering!
# $ pip install nnabla-rl[render]
env = build_classic_control_env("Pendulum-v1", render=True) # 1
ddpg = A.DDPG(env, config=A.DDPGConfig(start_timesteps=200))  # 2
ddpg.train(env)  # 3

To get more details about nnablaRL, see documentation and examples.

Many builtin algorithms

Many famous and state-of-the-art deep reinforcement learning algorithms, such as DQN, SAC, BCQ, and GAIL, are implemented in nnablaRL. The implemented algorithms are carefully tested and evaluated, so you can easily start training your agent using these verified implementations.

For the list of implemented algorithms see here.

You can also find the reproduction and evaluation results of each algorithm here.
Note that you may not get exactly the same results when running the reproduction code on your computer. Results may vary slightly depending on your machine, the nnabla/nnabla-rl package versions, etc.

Seamless switching of online and offline training

In reinforcement learning, there are two main training procedures, online and offline. Online training alternates between data collection and network updates. Conversely, offline training updates the network using only an existing dataset, with no data collection. With nnablaRL, you can switch between these two procedures seamlessly. For example, as shown below, you can easily train a robot's controller online using a simulated environment and then fine-tune it offline with a real-robot dataset.

import nnabla_rl
import nnabla_rl.algorithms as A

simulator = get_simulator() # This is just an example. Assuming that simulator exists
dqn = A.DQN(simulator)
# train online for 1M iterations
dqn.train_online(simulator, total_iterations=1000000)

real_data = get_real_robot_data() # This is also an example. Assuming that you have real robot data
# fine tune the agent offline for 10k iterations using real data
dqn.train_offline(real_data, total_iterations=10000)
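To make the distinction concrete, here is a plain-Python sketch of the two procedures. This is illustrative only and uses none of nnablaRL's actual APIs:

```python
def train_online(collect_step, update, total_iterations):
    # Online: alternate data collection and network updates.
    buffer = []
    for _ in range(total_iterations):
        buffer.append(collect_step())  # interact with the environment
        update(buffer)                 # update using the growing buffer

def train_offline(dataset, update, total_iterations):
    # Offline: update only, using a fixed pre-collected dataset.
    for _ in range(total_iterations):
        update(dataset)

# Tiny demonstration with stub callbacks: the online buffer grows,
# while the offline dataset stays fixed.
sizes = []
train_online(lambda: 0, lambda b: sizes.append(len(b)), 3)
print(sizes)  # -> [1, 2, 3]
```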

Visualization of training graph and training progress

nnablaRL supports visualization of training graphs and training progress with nnabla-browser!

import gym

import nnabla_rl.algorithms as A
import nnabla_rl.hooks as H
import nnabla_rl.writers as W
from nnabla_rl.utils.evaluator import EpisodicEvaluator

# save training computational graph
training_graph_hook = H.TrainingGraphHook(outdir="test")

# evaluation hook with nnabla's Monitor
eval_env = gym.make("Pendulum-v0")
evaluator = EpisodicEvaluator(run_per_evaluation=10)
evaluation_hook = H.EvaluationHook(
    eval_env,
    evaluator,
    timing=10,
    writer=W.MonitorWriter(outdir="test", file_prefix='evaluation_result'),
)

env = gym.make("Pendulum-v0")
sac = A.SAC(env)
sac.set_hooks([training_graph_hook, evaluation_hook])

sac.train_online(env, total_iterations=100)

[Screenshot: training graph visualization]

[Screenshot: training status visualization]

Getting started

Try the interactive demos below to get started.
You can run them directly in Colab from the links in the table below.

| Title | Notebook | Target RL task |
| --- | --- | --- |
| Simple reinforcement learning training to get started | Open In Colab | Pendulum |
| Learn how to use training algorithms | Open In Colab | Pendulum |
| Learn how to use customized network model for training | Open In Colab | Mountain car |
| Learn how to use different network solver for training | Open In Colab | Pendulum |
| Learn how to use different replay buffer for training | Open In Colab | Pendulum |
| Learn how to use your own environment for training | Open In Colab | Customized environment |
| Atari game training example | Open In Colab | Atari games |

Documentation

Full documentation is here.

Contribution guide

Any kind of contribution to nnablaRL is welcome! See the contribution guide for details.

License

nnablaRL is provided under the Apache License Version 2.0.
