  • Stars: 126
  • Rank: 284,543 (Top 6%)
  • Language: Python
  • License: MIT License
  • Created: over 2 years ago
  • Updated: over 1 year ago


Repository Details

Unofficial Gato: A Generalist Agent

[Deepmind Publication] [arXiv Paper]

This repository contains Deepmind's Gato architecture imitation in TensorFlow.

Since Deepmind only describes parts of the architecture in its paper, much about the model is still unknown.
However, I believe the paper is enough to imitate the architecture, and I'm trying to do that with the open-source community's help.

Currently, the repository supports the following operations:

Action tokens are still a mystery in the paper; I need your help.

However, the repository still lacks the following:

  • Datasets (most important, Issue: #1, ThomasRochefortB/torch-gato)
  • Pre-trained tokenizers (No longer required because of E2E model)
  • Training strategy (E2E, WIP)

Still, you can explore the basic architecture of Gato based on the paper.

Usage

$ pip install gato-tf

import tensorflow as tf
from gato import Gato, GatoConfig

# Create model instance
config = GatoConfig.small()
gato = Gato(config)

# Fake inputs for Gato
input_dim = config.input_dim
input_ids = tf.concat([
  # ...
  # observation 1
  tf.random.uniform((1, 1, input_dim)),  # image patch 0
  tf.random.uniform((1, 1, input_dim)),  # image patch 1
  tf.random.uniform((1, 1, input_dim)),  # image patch 2
  # ...
  tf.random.uniform((1, 1, input_dim)),  # image patch 19
  tf.fill((1, 1, input_dim), value=0.25),  # continuous value
  tf.fill((1, 1, input_dim), value=624.0),  # discrete (actions, texts)

  # observation 2
  tf.random.uniform((1, 1, input_dim)),  # image patch 0
  tf.random.uniform((1, 1, input_dim)),  # image patch 1
  tf.random.uniform((1, 1, input_dim)),  # image patch 2
  # ...
  tf.random.uniform((1, 1, input_dim)),  # image patch 19
  tf.fill((1, 1, input_dim), value=0.12),  # continuous value
  tf.fill((1, 1, input_dim), value=295.0)  # discrete (actions, texts)
  # ...
], axis=1)
encoding = tf.constant([
  # 0 - image patch embedding
  # 1 - continuous value embedding
  # 2 - discrete embedding (actions, texts)
  [0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 1, 2]
])
row_pos = (
  tf.constant([[0.00, 0.25, 0.50, 0.75, 0, 0, 0.00, 0.25, 0.50, 0.75, 0, 0]]),  # pos_from
  tf.constant([[0.25, 0.50, 0.75, 1.00, 0, 0, 0.25, 0.50, 0.75, 1.00, 0, 0]])   # pos_to
)
col_pos = (
  tf.constant([[0.00, 0.00, 0.00, 0.80, 0, 0, 0.00, 0.00, 0.00, 0.80, 0, 0]]),  # pos_from
  tf.constant([[0.20, 0.20, 0.20, 1.00, 0, 0, 0.20, 0.20, 0.20, 1.00, 0, 0]])   # pos_to
)
obs = (
  tf.constant([[ 0,  1,  2, 19, 20, 21,  0,  1,  2, 19, 20, 21]]),  # obs token
  tf.constant([[ 1,  1,  1,  1,  1,  0,  1,  1,  1,  1,  1,  0]])   # obs token masking (for action tokens)
)
hidden_states = gato((input_ids, (encoding, row_pos, col_pos), obs))

Dataset and Model Architecture

[Figure: Gato dataset and model architecture]

Paper Reviews

Full Episode Sequence

[Figure: Gato dataset architecture (full episode sequence)]

Architecture Variants

Appendix C.1. Transformer Hyperparameters

In the paper, Deepmind tested Gato with three architecture variants: 1.18B, 364M, and 79M parameters.
I have named them large(), baseline(), and small() respectively in GatoConfig.

Hyperparameters          Large (1.18B)   Baseline (364M)   Small (79M)
Transformer blocks       24              12                8
Attention heads          16              12                24
Layer width              2048            1536              768
Feedforward hidden size  8192            6144              3072
Key/value size           128             128               32
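
Each variant is selected through the corresponding class method on GatoConfig; GatoConfig.small() is the same call used in the Usage example above:

from gato import Gato, GatoConfig

# Choose one of the three variants described in the table above.
config = GatoConfig.baseline()  # or GatoConfig.large() / GatoConfig.small()
gato = Gato(config)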

Residual Embedding

Appendix C.2. Embedding Function

The paper does not mention how many residual networks should be stacked for token embeddings, so I left this configurable in GatoConfig.

However many residual layers there are, full pre-activation is the key; a minimal sketch follows the list below.

The blocks consist of:

  • Version 2 ResNet architecture (based on ResNet50V2)
  • GroupNorm (instead of LayerNorm)
  • GeLU (instead of ReLU)
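
As a rough illustration (not the repository's exact implementation), a full pre-activation residual block with GroupNorm and GELU could look like this in TensorFlow; the filter count and group count are illustrative only:

import tensorflow as tf

class ResidualEmbeddingBlock(tf.keras.layers.Layer):
    """Full pre-activation (ResNetV2-style) block with GroupNorm and GELU.
    A minimal sketch; filter count and number of groups are illustrative."""

    def __init__(self, filters, groups=32, **kwargs):
        super().__init__(**kwargs)
        # Requires tf.keras.layers.GroupNormalization (TensorFlow/Keras >= 2.11).
        self.norm1 = tf.keras.layers.GroupNormalization(groups=groups)
        self.conv1 = tf.keras.layers.Conv2D(filters, 3, padding='same')
        self.norm2 = tf.keras.layers.GroupNormalization(groups=groups)
        self.conv2 = tf.keras.layers.Conv2D(filters, 3, padding='same')

    def call(self, x):
        # Pre-activation: normalize and activate before each convolution,
        # then add the untouched input back through the identity shortcut.
        h = self.conv1(tf.nn.gelu(self.norm1(x)))
        h = self.conv2(tf.nn.gelu(self.norm2(h)))
        return x + h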

Position Encodings

Appendix C.3. Position Encodings

Patch Position Encodings

Like the Vision Transformer (ViT) by Google, Gato takes its input images as raster-ordered 16x16 patches.
Unlike the Vision Transformer, however, Gato splits its patch position encoding strategy into two phases: training and evaluation.

For high-performance computation in TensorFlow, I have used the following expressions.

$C$ and $R$ mean column and row-wise, and $F$ and $T$ mean from and to respectively.

$$ \begin{align} v^R_F &= \begin{bmatrix} 0 & 32 & 64 & 96 \end{bmatrix} \\ v^R_T &= \begin{bmatrix} 32 & 64 & 96 & 128 \end{bmatrix} \\ v^C_F &= \begin{bmatrix} 0 & 26 & 51 & 77 & 102 \end{bmatrix} \\ v^C_T &= \begin{bmatrix} 26 & 51 & 77 & 102 & 128 \end{bmatrix} \\ \\ P_R &= \begin{cases} \mathsf{if} \ \mathsf{training} & v^R_F + \mathsf{uniform}(v^R_T - v^R_F) \\ \mathsf{otherwise} & \mathsf{round}(\frac{v^R_F + v^R_T}{2}) \end{cases} \\ P_C &= \begin{cases} \mathsf{if} \ \mathsf{training} & v^C_F + \mathsf{uniform}(v^C_T - v^C_F) \\ \mathsf{otherwise} & \mathsf{round}(\frac{v^C_F + v^C_T}{2}) \end{cases} \\ \\ E^R_P &= P_R \cdot 1^{\mathsf{T}}_C \\ E^C_P &= 1^{\mathsf{T}}_R \cdot P_C \\ \\ \therefore E &= E_I + E^R_P + E^C_P \end{align} $$
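
A minimal sketch of these expressions (assuming normalized [0, 1] from/to intervals as in the Usage example above, not the repository's exact code):

import tensorflow as tf

def patch_position_index(pos_from, pos_to, embedding_size=128, training=True):
    # Quantize the normalized [0, 1] from/to positions onto the embedding grid.
    v_from = tf.cast(tf.round(pos_from * embedding_size), tf.int32)
    v_to = tf.cast(tf.round(pos_to * embedding_size), tf.int32)
    if training:
        # Training: sample a position uniformly inside each patch interval.
        offset = tf.random.uniform(tf.shape(v_from), maxval=1.0)
        return v_from + tf.cast(offset * tf.cast(v_to - v_from, tf.float32), tf.int32)
    # Evaluation: use the (rounded) interval midpoint instead.
    return (v_from + v_to) // 2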

Local Observation Position Encodings

By the definition in Appendix B., text tokens, image patch tokens, and discrete & continuous values are observation tokens.
When Gato receives those values, they must be encoded with their own (local) time steps.
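
A minimal sketch of this idea, assuming the (index, mask) layout of the obs tuple from the Usage example above; names and sizes are illustrative:

import tensorflow as tf

num_local_positions = 512   # illustrative upper bound on tokens per observation
layer_width = 768           # e.g. the small-variant layer width

local_pos_embedding = tf.keras.layers.Embedding(num_local_positions, layer_width)

obs_index = tf.constant([[0, 1, 2, 19, 20, 21]])  # local (within-observation) time steps
obs_mask = tf.constant([[1, 1, 1, 1, 1, 0]])      # 0 masks out action tokens

local_encoding = local_pos_embedding(obs_index)
local_encoding *= tf.cast(obs_mask, local_encoding.dtype)[..., tf.newaxis]
# local_encoding is then added to the token embeddings before the transformer.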

Requirements

pip install "tensorflow>=2.11.0"

Contributing

This repository is still a work in progress.
Currently, no downloads and no executables are provided.

I welcome any contributors who can help.

License

Licensed under the MIT license.

More Repositories

  1. ProDisplayXDR-ScreenSaver: Pro Display XDR screensaver for macOS (Swift, 16 stars)
  2. homebridge-daelim-smarthome: Verified homebridge plugin for HomeKit-unsupported DL E&C apartments (TypeScript, 11 stars)
  3. Matchmaking: A simple matchmaking API using Redis (Java, 9 stars)
  4. Avis-WatchDog: A plugin that prevents Minecraft users from using hacked clients (Java, 9 stars)
  5. ipdetection4j: An API union of proxy detection services in Java (Java, 6 stars)
  6. aegis-viii: An implementation of an SNN (3rd gen) with synaptic plasticity (Java, 6 stars)
  7. CoRT: Contrastive Rhetorical Tagging - KISTI 2022 AI/ML Competition (Python, 6 stars)
  8. lion-tf: Lion - EvoLved Sign Momentum with the new optimizer API in TensorFlow 2.11+ (Python, 5 stars)
  9. simple-neural-network: A simple implementation of a multilayer perceptron (Java, 5 stars)
  10. 3rd-generation-neural-network-showcase: Showcases of a 3rd generation neural network (AEGIS-VIII) (4 stars)
  11. deeplx-lambda-proxy: DeepLX proxy with multiple AWS Lambda functions (HCL, 4 stars)
  12. best-salary-calculator (Java, 3 stars)
  13. Leveled-Storage: Storage that follows a depth-based saving scheme for data optimization (Java, 3 stars)
  14. simple-midi-trail: A simple MIDI trail GUI created with JUIKit (Java, 2 stars)
  15. fractal-clock: A FractalClock inspired by HackerPoet's FractalClock (Java, 2 stars)
  16. renderer: Static HTML code generator using annotation combinations (Java, 2 stars)
  17. Text-AnimationLib: A pure Java library to make text animation easy (Java, 2 stars)
  18. juikit: Java builder-style UI library (Java, 2 stars)
  19. homebridge-coway: Homebridge plugin for Coway purifier devices (TypeScript, 2 stars)
  20. adam-lr-decay: Adam with layer-wise LR decay (Python, 2 stars)
  21. JSP-Commons: Common libraries for developing JSP projects (Java, 1 star)
  22. FractalClock-ScreenSaver: Fractal Clock screensaver for macOS (Swift, 1 star)
  23. deeplex: DeepL Ex: Unlimited free DeepL translation with glossaries (Python, 1 star)