
Alumina

An experimental deep learning library written in pure Rust. Breakage is expected with each release in the short term. See mnist.rs in the examples folder, or Rusty_SR, for usage samples.

Overview

The key types are Node and Ops, which are Rc-like references to components of a shared mutable Graph. The graph is extended gradually with new tensors and operations via construction functions, and facilities for reverse-mode automatic differentiation are included in the operations, extending the graph as necessary. Typical graph construction and differentiation are shown below:

// 1. Build an MLP neural net graph (98% accuracy @ 10 epochs)
let input = Node::new(&[-1, 28, 28, 1]).set_name("input");
let labels = Node::new(&[-1, 10]).set_name("labels");

let layer1 = elu(affine(&input, 256, msra(1.0))).set_name("layer1");
let layer2 = elu(affine(&layer1, 256, msra(1.0))).set_name("layer2");
let logits = linear(&layer2, 10, msra(1.0)).set_name("logits");

// 2. Define the training loss: softmax cross-entropy plus L2 regularisation of the parameters
let training_loss = add(
  reduce_sum(softmax_cross_entropy(&logits, &labels, -1), &[], false).set_name("loss"),
  scale(l2(logits.graph().nodes_tagged(NodeTag::Parameter)), 1e-3).set_name("regularisation"),
)
.set_name("training_loss");
let accuracy = equal(argmax(&logits, -1), argmax(&labels, -1)).set_name("accuracy");

// 3. Collect the parameters and build the gradients of the training loss w.r.t. them
let parameters = accuracy.graph().nodes_tagged(NodeTag::Parameter);

let grads = Grad::of(training_loss).wrt(parameters).build()?;
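
The important point in the example is that input, layer1, logits, and so on are all lightweight handles into one shared, mutable graph, which each construction function extends. As a rough illustration of that pattern only (GraphInner, add_node, and node_names below are invented for this sketch and are not Alumina's actual internals), the same idea can be modelled with std's Rc and RefCell:

// Illustrative sketch only: these are NOT Alumina's real types.
use std::cell::RefCell;
use std::rc::Rc;

struct GraphInner {
    node_names: Vec<String>, // stand-in for the real tensors, ops, and shapes
}

#[derive(Clone)]
struct Graph(Rc<RefCell<GraphInner>>);

#[derive(Clone)]
struct Node {
    graph: Graph,
    index: usize,
}

impl Graph {
    fn new() -> Self {
        Graph(Rc::new(RefCell::new(GraphInner { node_names: Vec::new() })))
    }

    // A construction function extends the shared graph and returns a cheap handle.
    fn add_node(&self, name: &str) -> Node {
        let mut inner = self.0.borrow_mut();
        inner.node_names.push(name.to_string());
        Node {
            graph: self.clone(),
            index: inner.node_names.len() - 1,
        }
    }
}

impl Node {
    // A handle can always reach back into the graph it belongs to,
    // much like `logits.graph()` in the example above.
    fn graph(&self) -> &Graph {
        &self.graph
    }
}

fn main() {
    let graph = Graph::new();
    let input = graph.add_node("input");
    let layer1 = input.graph().add_node("layer1"); // extends the same shared graph
    assert_eq!(layer1.index, 1);
    assert_eq!(graph.0.borrow().node_names.len(), 2);
}

Cloning such a handle never copies the graph; it only bumps a reference count, which is what "Rc-like references" refers to above.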

Current work is focused on improving the high-level graph construction API and on better supporting dynamic/define-by-run graphs.

Contributions

Issues are a great place for discussion, problem reports, and feature requests.

Documentation

Documentation will be patchy until experimentation with the library API ends, particularly until the graph construction API is finalised.

Progress

  • Computation hypergraph
  • NN
    • Dense Connection and Bias operations
    • N-dimensional Convolution
      • Arbitrary padding
      • Strides
      • Reflection padding
    • Categorical Cross Entropy
    • Binary Cross Entropy
  • Boolean
    • Equal
    • Greater_Equal
    • Greater_Than
    • Less_Equal
    • Less_Than
    • Not
  • Elementwise
    • Abs
    • Ceil
    • Cos
    • Div
    • Elu
    • Exp
    • Floor
    • Identity
    • Leaky_relu
    • Ln
    • Logistic
    • Max
    • Min
    • Mul
    • Negative
    • Offset
    • Reciprocal
    • Relu
    • Robust
    • Round
    • Scale
    • Sign
    • Sin
    • SoftPlus
    • SoftSign
    • Sqr
    • Sqrt
    • Srgb
    • Subtract
    • Tanh
  • Grad
    • Stop_grad
  • Manip
    • Concat
    • Slice
    • Permute_axes
    • Expand_dims
    • Remove_dims
  • Math
    • Argmax
    • Broadcast
  • Pooling
    • N-dimensional Avg_Pool
    • Max pool
    • N-dimensional spaxel shuffling for "Sub-pixel Convolution"
    • N-dimensional Linear-Interpolation
    • Global Pooling
  • Reduce
    • Reduce_Prod
    • Reduce_Sum
  • Regularisation (common definitions are sketched after this list)
    • L1
    • L2
    • Hoyer_squared
    • Robust
  • Shapes
    • Shape inference and constraint propagation
  • Data Loading
    • Mnist
    • Cifar
    • Image Folders
    • Imagenet (ILSVRC)
  • SGD
  • RMSProp
  • ADAM
  • Basic numerical tests
  • Limit Optimiser evaluation batch size to stay within memory limits
  • Selectively disable calculation of forward values, node derivatives and parameter derivatives
  • Builder patterns for operation construction
  • Split Graph struct into mutable GraphBuilder and immutable Sub-Graphs
    • Replace 'accidentally quadratic' graph algorithms
    • Replace up-front allocation with Sub-Graph optimised allocation/deallocation patterns based on liveness analysis of nodes
  • Overhaul data ingestion, particularly buffering input processing/reads.
  • Move tensor format to bluss' ndarray
  • Improve naming inter/intra-library consistency
  • Operator overloading for simple ops
  • Complete documentation
  • Reduce ability to express illegal states in API
  • Move from panics to error-chain
  • Move from error-chain to thiserror
  • Guard unsafe code rigorously
  • Comprehensive tests
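
For reference, the regularisation entries above correspond to standard penalties; the sketch below shows their common textbook forms over a flat slice of weights. It is illustrative only and not Alumina's actual l1/l2/hoyer_squared operations, which may scale or reduce differently (the Robust penalty is omitted here).

// Illustrative sketch only: common textbook forms of the listed penalties,
// not Alumina's actual l1/l2/hoyer_squared operations.
fn l1(x: &[f32]) -> f32 {
    // sum of absolute values
    x.iter().map(|v| v.abs()).sum()
}

fn l2(x: &[f32]) -> f32 {
    // sum of squares (often written with an extra factor of 1/2)
    x.iter().map(|v| v * v).sum()
}

fn hoyer_squared(x: &[f32]) -> f32 {
    // (sum |x|)^2 / (sum x^2): a scale-invariant, sparsity-inducing penalty
    let abs_sum: f32 = x.iter().map(|v| v.abs()).sum();
    let sq_sum: f32 = x.iter().map(|v| v * v).sum();
    if sq_sum == 0.0 { 0.0 } else { abs_sum * abs_sum / sq_sum }
}

fn main() {
    let weights = [0.5f32, -1.0, 0.0, 2.0];
    println!(
        "l1 = {}, l2 = {}, hoyer_squared = {}",
        l1(&weights),
        l2(&weights),
        hoyer_squared(&weights)
    );
}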

Distant

  • Optionally typed tensors
  • Arrayfire as an option for sgemm on APUs
  • Graph optimisation passes and inplace operations
  • Support for both dynamic and static graphs
    • RNNs

License

MIT