JAX/MJX Training

JAX/MJX is the GPU-accelerated training backend for Mesozoic Labs. It runs thousands of parallel environments on a single GPU using MuJoCo MJX and trains with a pure-JAX PPO implementation built on Flax and Optax.

Overview

The JAX backend provides:

  • Massive parallelism -- 2,048-8,192 environments running simultaneously via jax.vmap
  • 10-100x speedup over the SB3/CPU path for large-scale training
  • GPU/TPU support -- works on NVIDIA GPUs and Google Cloud TPUs
  • Same curriculum -- identical 3-stage curriculum learning as the SB3 path
  • All three species -- T-Rex, Velociraptor, and Brachiosaurus

When to Use JAX vs SB3

|             | SB3 (CPU)                                   | JAX/MJX (GPU)                               |
|-------------|---------------------------------------------|---------------------------------------------|
| Install     | pip install -e ".[train]"                   | pip install -e ".[jax]"                     |
| Hardware    | Any CPU                                     | NVIDIA GPU or TPU                           |
| Parallelism | 4-32 envs (SubprocVecEnv)                   | 2,048-8,192 envs (jax.vmap)                 |
| Algorithm   | PPO or SAC                                  | PPO (JAX-native)                            |
| Best for    | Quick experiments, debugging, no-GPU setups | Large-scale training, hyperparameter sweeps |

Both backends share the same MJCF model files, TOML stage configs, and reward logic.

Installation

Install the JAX optional dependencies:

pip install -e ".[jax]"

This installs mujoco-mjx, jax[cuda12], flax, and optax. The JAX backend uses lazy imports, so the rest of the codebase works fine without these packages installed.

Basic Usage

Python API

from environments.shared.jax_training import train_jax

# Train T-Rex Stage 1 (balance) with 2048 parallel environments
train_jax(species="trex", stage=1, num_envs=2048, seed=42)

CLI

# Single stage
python -m environments.shared.jax_training --species trex --stage 1

# Full 3-stage curriculum
python -m environments.shared.jax_training --species trex --curriculum

Colab Notebook

The notebooks/jax_training.ipynb notebook supports all three species. Set the SPECIES variable at the top of the notebook:

SPECIES = "trex"  # or "velociraptor" or "brachiosaurus"

The notebook handles dependency installation, GPU/TPU detection, and curriculum training automatically.

Architecture

The JAX path mirrors the SB3 path but replaces Python-level loops with JIT-compiled JAX functions:

MJXDinoEnv(species="trex", stage=1)
+-- mjx.put_model(mj_model) # GPU-resident physics model
+-- jax.vmap(step_fn) # 2048+ parallel environments
+-- JAX PPO (Flax+Optax) # Pure-JAX policy + optimizer
+-- Running-mean norm # JAX-based observation normalization
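The vmap pattern at the heart of this design can be shown with a toy step function (a stand-in for the MJX physics step, not the real MJXDinoEnv):

```python
import jax
import jax.numpy as jnp

def step_fn(state, action):
    """Toy single-environment step: integrate a 1-D position
    (a stand-in for the MJX physics step)."""
    new_state = state + 0.01 * action
    reward = -jnp.abs(new_state)  # reward peaks when the state is at 0
    return new_state, reward

# One vmap call turns the single-env step into a 2048-env batched step,
# and jit compiles the whole batch into a single GPU kernel launch.
batched_step = jax.jit(jax.vmap(step_fn))

states = jnp.zeros(2048)
actions = jnp.ones(2048)
states, rewards = batched_step(states, actions)
```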

Key Components

| Module                                   | Description                                                               |
|------------------------------------------|---------------------------------------------------------------------------|
| environments/shared/mjx_env.py           | JAX-native batched environment with functional step/reset                 |
| environments/shared/jax_ppo.py           | Flax ActorCritic network with PPO loss, GAE, and update functions         |
| environments/shared/jax_trainer.py       | High-level training loop with hooks for logging and checkpointing         |
| environments/shared/jax_training.py      | Entry point that loads TOML configs and runs the training pipeline        |
| environments/shared/jax_normalization.py | Running-mean observation normalization (equivalent to SB3's VecNormalize) |
| environments/shared/jax_curriculum.py    | Curriculum stage management with threshold-based advancement              |
| environments/shared/jax_eval.py          | Policy evaluation, metric collection, and video generation                |
| environments/shared/jax_viz.py           | Trajectory rendering and visualization tools                              |
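The running-mean normalization in jax_normalization.py can be sketched with the standard parallel mean/variance update (names assumed; written with plain array operations so it runs on NumPy and JAX arrays alike):

```python
def update_norm_stats(mean, var, count, batch):
    """Combine running mean/variance with a new batch of observations.

    batch has shape [num_envs, obs_dim]. Uses the parallel-variance
    combination formula, so it matches the statistics you would get
    from all observations at once. A sketch of the VecNormalize-style
    logic, not the repository's exact code.
    """
    batch_mean = batch.mean(axis=0)
    batch_var = batch.var(axis=0)
    batch_count = batch.shape[0]

    delta = batch_mean - mean
    total = count + batch_count
    new_mean = mean + delta * batch_count / total

    # Combine the two sums of squared deviations, then renormalize.
    m2 = var * count + batch_var * batch_count \
        + delta**2 * count * batch_count / total
    return new_mean, m2 / total, total
```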

Dual-Backend Design

The SB3 and JAX paths share:

  • MJCF model files (*.xml) -- identical physics models, no changes needed
  • Reward logic -- extracted into pure functions that work with both NumPy and JAX arrays
  • Stage configs (configs/*/stage*.toml) -- same TOML files drive both backends
  • Evaluation rendering -- CPU MuJoCo rendering for both (MJX has no native renderer)
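A reward term shared across both backends looks like an ordinary pure function that sticks to operations NumPy and jax.numpy have in common (a hypothetical illustration, not the repository's actual reward code):

```python
def upright_reward(torso_height, target_height, alive_bonus):
    """Pure reward term that works on NumPy and JAX arrays alike.

    Only arithmetic operators are used, which both libraries support,
    so the same function serves the SB3 (NumPy) and MJX (JAX) paths.
    Hypothetical example; names and shapes are assumptions.
    """
    height_error = (torso_height - target_height) ** 2
    return alive_bonus - height_error
```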

They differ in:

  • Environment wrapper -- Gymnasium step()/reset() vs. JAX functional step_fn/reset_fn
  • Training loop -- SB3 callbacks vs. JIT-compiled JAX rollout+update
  • Parallelism -- SubprocVecEnv (4-32 envs) vs. jax.vmap (2,048-8,192 envs)

PPO Hyperparameters (JAX)

The JAX PPO uses the same hyperparameter progression as the SB3 path, loaded from the per-species TOML configs:

| Parameter       | Stage 1 | Stage 2 | Stage 3 | Description                          |
|-----------------|---------|---------|---------|--------------------------------------|
| learning_rate   | 3e-4    | 1e-4    | 5e-5    | Network learning rate                |
| num_envs        | 2048    | 2048    | 2048    | Parallel environments (via jax.vmap) |
| rollout_len     | 64      | 64      | 64      | Steps per rollout before update      |
| num_minibatches | 4       | 4       | 4       | Minibatches per update epoch         |
| update_epochs   | 10      | 10      | 10      | Epochs per PPO update                |
| gamma           | 0.99    | 0.99    | 0.995   | Discount factor                      |
| gae_lambda      | 0.95    | 0.95    | 0.95    | GAE lambda                           |
| clip_range      | 0.2     | 0.2     | 0.1     | PPO clip range                       |
| ent_coef        | 0.01    | 0.01    | 0.001   | Entropy coefficient                  |
| vf_coef         | 0.5     | 0.5     | 0.5     | Value function coefficient           |

With 2,048 environments and a rollout length of 64, each update uses 131,072 transitions -- far more data per update than the SB3 path.
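The GAE step named in the table is typically written as a reverse jax.lax.scan over the rollout axis (a common JAX formulation, shown as a sketch rather than the repository's exact code):

```python
import jax
import jax.numpy as jnp

def compute_gae(rewards, values, last_value, dones, gamma=0.99, lam=0.95):
    """Generalized Advantage Estimation via a reverse scan.

    rewards, values, dones have shape [rollout_len, num_envs];
    last_value has shape [num_envs]. A standard JAX formulation,
    not necessarily the repository's exact implementation.
    """
    def step(carry, inputs):
        gae, next_value = carry
        reward, value, done = inputs
        # TD error, masking the bootstrap value at episode boundaries.
        delta = reward + gamma * next_value * (1.0 - done) - value
        gae = delta + gamma * lam * (1.0 - done) * gae
        return (gae, value), gae

    _, advantages = jax.lax.scan(
        step,
        (jnp.zeros_like(last_value), last_value),
        (rewards, values, dones),
        reverse=True,  # scan from the last timestep back to the first
    )
    return advantages, advantages + values
```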

3-Stage Curriculum

The JAX path follows the same curriculum as SB3:

  1. Stage 1 -- Balance: Stand upright without falling (forward_vel_weight=0, high alive_bonus)
  2. Stage 2 -- Locomotion: Walk and run forward (increase forward_vel_weight, add gait rewards)
  3. Stage 3 -- Behavior: Species-specific task (strike for Velociraptor, bite for T-Rex, food reach for Brachiosaurus)

Stage transitions use the same threshold-based advancement as the SB3 path. The jax_curriculum.py module checks evaluation metrics against the TOML-defined thresholds and advances when criteria are met for consecutive evaluations.
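The advancement check reduces to a small predicate over recent evaluation results (a hypothetical sketch of the jax_curriculum.py logic; names are assumed):

```python
def should_advance(eval_rewards, threshold, required_consecutive=2):
    """Advance to the next stage once the mean eval reward clears the
    TOML-defined threshold for N consecutive evaluations.

    eval_rewards is the history of evaluation rewards, oldest first.
    Hypothetical sketch; the actual module may track more metrics.
    """
    if len(eval_rewards) < required_consecutive:
        return False
    return all(r >= threshold for r in eval_rewards[-required_consecutive:])
```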

W&B Integration

Enable Weights & Biases logging by passing wandb_project:

train_jax(
    species="trex",
    stage=1,
    num_envs=2048,
    wandb_project="mesozoic-labs",
)

The JAX trainer logs per-update metrics (reward, episode length, policy loss, value loss, entropy) to W&B via the built-in WandbHook.

Checkpointing

Model parameters, optimizer state, and running normalization statistics are saved periodically during training. Checkpoints use JAX-compatible serialization so training can be resumed from any saved state.

# Resume from a checkpoint
train_jax(
    species="trex",
    stage=2,
    checkpoint_dir="results/trex/jax/stage1",
)

GPU Memory Requirements

The number of parallel environments is the main factor in GPU memory usage:

| Environments | Approximate VRAM | Recommended GPU    |
|--------------|------------------|--------------------|
| 512          | ~2 GB            | T4 (16 GB)         |
| 2,048        | ~6 GB            | T4, V100, A100     |
| 4,096        | ~12 GB           | V100 (32 GB), A100 |
| 8,192        | ~22 GB           | A100 (40/80 GB)    |

The default of 2,048 environments works on most modern GPUs. Reduce num_envs if you encounter out-of-memory errors.

JIT Compilation

The first training step triggers JAX's JIT compilation, which takes 1-3 minutes depending on the model complexity and GPU. This is a one-time cost -- subsequent steps run at full speed. This is expected behavior and not an error.
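The warm-up cost is easy to observe by timing the first and second calls to any jitted function (a toy stand-in for the real update step):

```python
import time

import jax
import jax.numpy as jnp

@jax.jit
def update_step(x):
    # Toy stand-in for one PPO update: some nonlinear math plus a matmul.
    return jnp.tanh(x) @ x.T

x = jnp.ones((512, 512))

t0 = time.perf_counter()
update_step(x).block_until_ready()  # first call: traces + compiles
compile_time = time.perf_counter() - t0

t0 = time.perf_counter()
update_step(x).block_until_ready()  # second call: cached, full speed
run_time = time.perf_counter() - t0
```

block_until_ready() matters here: JAX dispatches asynchronously, so without it the timers would measure only the dispatch, not the actual work.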

Vertex AI with JAX

For cloud training with the JAX backend on Vertex AI, use an A100 GPU machine type:

job = aiplatform.CustomJob(
    display_name="trex-jax-curriculum",
    worker_pool_specs=[
        {
            "machine_spec": {
                "machine_type": "a2-highgpu-1g",
                "accelerator_type": "NVIDIA_TESLA_A100",
                "accelerator_count": 1,
            },
            "replica_count": 1,
            "container_spec": {
                "image_uri": IMAGE_URI,
                "command": ["python"],
                "args": [
                    "-m", "environments.shared.jax_training",
                    "--species", "trex",
                    "--curriculum",
                    "--num-envs", "4096",
                ],
            },
        }
    ],
)
job.run(sync=False)

See Training on Vertex AI for full setup instructions.

Comparison with SB3 Results

The JAX backend trains significantly faster in wall-clock time due to massive parallelism. Both backends converge to similar final performance since they use the same reward functions and curriculum stages.

| Backend   | Envs  | Steps/sec | Stage 1 Wall Time |
|-----------|-------|-----------|-------------------|
| SB3 (CPU) | 4     | ~2,000    | ~3 hours          |
| JAX (GPU) | 2,048 | ~500,000  | ~5 minutes        |

Times are approximate and vary by hardware.