Swarm training framework using Haiku + JAX + Ray for layer parallel transformer language models on unreliable, heterogeneous nodes

Pipelined Swarm Training

Swarm training "framework" using Haiku + JAX + Ray.

Designed (eventually) for training large language models in a model-parallel fashion on unreliable, heterogeneous nodes.

See swarm_run.py for an example that trains a character-level transformer on enwik8.
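The following is a minimal sketch of the layer-parallel idea in plain JAX. It assumes each node owns exactly one layer. It also assumes that "activation reconstruction" means a node keeps only its layer's input and recomputes the forward pass during the backward pass. The names LayerNode, layer_fn, init_nodes and train_step are hypothetical and do not come from swarm_run.py. A dense layer with an MSE loss stands in for a transformer block with cross-entropy. The Ray actors, the network transfer and the actual pipelined schedule are all omitted.

```python
import jax
import jax.numpy as jnp


def layer_fn(params, x):
    # Hypothetical stand-in for one transformer block: dense layer + tanh.
    w, b = params
    return jnp.tanh(x @ w + b)


class LayerNode:
    """One pipeline stage that owns a single layer (hypothetical, not the repo's API)."""

    def __init__(self, params):
        self.params = params
        self.saved_inputs = {}  # microbatch id -> input activations of this layer

    def forward(self, mb_id, x):
        # Store only the layer input; intermediate activations are not kept.
        self.saved_inputs[mb_id] = x
        return layer_fn(self.params, x)

    def backward(self, mb_id, grad_out):
        x = self.saved_inputs.pop(mb_id)
        # Activation reconstruction: re-run the forward pass from the saved
        # input, then pull the incoming gradient back through it.
        _, vjp_fn = jax.vjp(layer_fn, self.params, x)
        grad_params, grad_x = vjp_fn(grad_out)
        return grad_params, grad_x  # grad_x is what gets sent to the previous node


def init_nodes(key, n_layers, dim):
    nodes = []
    for k in jax.random.split(key, n_layers):
        w = jax.random.normal(k, (dim, dim)) / jnp.sqrt(dim)
        nodes.append(LayerNode((w, jnp.zeros(dim))))
    return nodes


def train_step(nodes, mb_id, x, y, lr=0.1):
    # Forward: activations flow from node to node (over the network in a real swarm).
    h = x
    for node in nodes:
        h = node.forward(mb_id, h)
    loss = jnp.mean((h - y) ** 2)
    # Gradient of the mean-squared-error loss with respect to the final output.
    grad = 2.0 * (h - y) / h.size
    # Backward: gradients flow through the nodes in reverse order; each node
    # applies a plain SGD update to its own parameters only.
    for node in reversed(nodes):
        grad_params, grad = node.backward(mb_id, grad)
        node.params = jax.tree_util.tree_map(
            lambda p, g: p - lr * g, node.params, grad_params
        )
    return loss
```

Each node stores only its layer inputs. Its memory therefore grows with the number of in-flight microbatches rather than with the layer's internal activations. The price is one extra forward pass per layer during the backward pass. The gradients match what jax.grad gives for the whole stack, because each node applies the chain rule locally.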

TODOs

  • Forward passes
  • Backward passes with activation reconstruction
  • Run optimizer
  • Logging
  • Checkpointing
  • Actually do pipelining
  • fp16 with static loss scaling
  • Integer quantization for activations and gradients between layers
  • Get rid of pipeline stalls from running optimizer
  • Data parallelism with multiple nodes per layer and gradient/weight aggregation
  • Heterogeneous nodes with potentially multiple layers per node
  • Handle unbalanced and unreliable nodes (layerdrop)
  • Dynamic node addition
  • 1T or bust?
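Two of the TODO items, integer quantization and fp16 with static loss scaling, both shrink the tensors passed between layers. The sketch below shows one common way to do each, assuming symmetric per-tensor int8 quantization with a single float32 scale. It is not necessarily the scheme this repo adopts, and quantize_int8, dequantize_int8 and fp16_roundtrip are hypothetical helpers. Scaling the loss by a constant S multiplies every gradient by S, so fp16_roundtrip simulates that effect on one gradient tensor as it is cast to fp16 and back.

```python
import jax.numpy as jnp


def quantize_int8(x):
    # Symmetric per-tensor quantization: map [-max|x|, max|x|] onto [-127, 127].
    scale = jnp.max(jnp.abs(x)) / 127.0
    scale = jnp.where(scale == 0, 1.0, scale)  # an all-zero tensor would otherwise divide by zero
    q = jnp.clip(jnp.round(x / scale), -127, 127).astype(jnp.int8)
    return q, scale  # send the int8 payload plus one float32 scale


def dequantize_int8(q, scale):
    # Per-element rounding error is at most scale / 2.
    return q.astype(jnp.float32) * scale


def fp16_roundtrip(grad, loss_scale):
    # Static loss scaling: multiply by a fixed constant before casting to fp16
    # so small gradients do not flush to zero, then unscale in fp32.
    # The constant must keep |grad * loss_scale| below 65504 (the fp16 maximum).
    scaled = (grad * loss_scale).astype(jnp.float16)
    return scaled.astype(jnp.float32) / loss_scale
```

The int8 payload is a quarter of the size of the float32 tensor. For the fp16 path, a gradient of 1e-8 becomes 0 after a plain fp16 cast (loss_scale=1.0), while fp16_roundtrip(g, 2.0 ** 12) recovers roughly 1e-8. A static scale is chosen once and never adjusted, so it has to trade off underflow of small gradients against overflow of large ones.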