JAX memory profiling, natively
Stormlog now tracks XLA allocations for JAX workloads, including jit, pmap, and sharded computations on CPU, GPU, and TPU, with the same workflow you already use for PyTorch and TensorFlow.
- Profile jax.jit / XLA allocations through profile_context and the profile_function decorator
- jaxmemprof CLI for info, monitor, track, and diagnose sessions
- Multi-device aggregation across jax.sharding and jax.pmap on GPU and TPU
- OOM flight recorder, telemetry sinks, and rank-aware artifacts carried over from the core profiler
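The text above does not describe how Stormlog combines per-device numbers, so the following is only a hypothetical sketch of one way to aggregate memory from a sharded or pmapped run. It assumes every device was sampled at the same timestamps. `aggregate`, `Aggregate`, and the device labels are invented for illustration. It shows why two different totals are worth reporting: the sum of per-device peaks is an upper bound, while the peak of the summed usage can be lower when devices peak at different times.

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class Aggregate:
    per_device_peak: Dict[str, int]  # bytes: highest sample seen on each device
    sum_of_peaks: int                # upper bound, since peaks may occur at different times
    peak_of_sum: int                 # highest total across all devices at one timestamp
    busiest_device: str              # device with the largest individual peak


def aggregate(samples: Dict[str, List[int]]) -> Aggregate:
    """Hypothetical helper: combine per-device byte counts sampled at shared timestamps."""
    if not samples:
        raise ValueError("need at least one device")
    lengths = {len(series) for series in samples.values()}
    if len(lengths) != 1 or 0 in lengths:
        raise ValueError("every device needs the same, non-zero number of samples")
    per_device_peak = {dev: max(series) for dev, series in samples.items()}
    # Sum usage across devices at each timestamp, then take the maximum over time.
    totals = [sum(step) for step in zip(*samples.values())]
    return Aggregate(
        per_device_peak=per_device_peak,
        sum_of_peaks=sum(per_device_peak.values()),
        peak_of_sum=max(totals),
        busiest_device=max(per_device_peak, key=per_device_peak.get),
    )


if __name__ == "__main__":
    # Two devices whose peaks do not coincide in time.
    samples = {"dev0": [100, 400, 150], "dev1": [300, 100, 200]}
    result = aggregate(samples)
    print(result.sum_of_peaks, result.peak_of_sum, result.busiest_device)  # 700 500 dev0
```

With this sample data the per-device peaks add up to 700 bytes, but the highest simultaneous total is only 500 bytes. Reporting both helps avoid over-provisioning a multi-device job.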
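import jax
import jax.numpy as jnp
from stormlog.jax import JAXMemoryProfiler

fast_training_step = jax.jit(lambda x: jnp.tanh(x @ x.T))  # placeholder: substitute your own jitted step
x = jnp.ones((1024, 1024))

profiler = JAXMemoryProfiler()
with profiler.profile_context("jitted_step"):
    y = fast_training_step(x)
    y.block_until_ready()  # JAX dispatches asynchronously; block so the step completes inside the context

results = profiler.get_results()
print(f"Peak memory: {results.peak_memory_mb:.2f} MB")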
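To illustrate what a scoped measurement like `profile_context`, or a decorator like `profile_function`, reports, here is a hypothetical, dependency-free sketch. It is not Stormlog's code. `PeakTracker`, `scope`, `track`, `sample`, and `read_bytes` are invented names that only play analogous roles. The tracker sees memory only at entry, at exit, and at explicit `sample()` calls, so short-lived peaks between samples are missed. With JAX installed, one plausible reader is `lambda: sum(a.nbytes for a in jax.live_arrays())`. That counts live array buffers but not XLA's temporary workspace inside a compiled computation.

```python
import functools
from contextlib import contextmanager
from typing import Callable, Dict, Iterator, List


class PeakTracker:
    """Hypothetical snapshot-based peak tracker (illustration only, not Stormlog's code)."""

    def __init__(self, read_bytes: Callable[[], int]) -> None:
        self.read_bytes = read_bytes      # callable returning current usage in bytes
        self.peaks: Dict[str, int] = {}   # label -> highest usage sampled while open
        self._open: List[str] = []        # labels of the currently open scopes

    def sample(self) -> None:
        """Read usage once and credit it to every open scope."""
        current = self.read_bytes()
        for label in self._open:
            self.peaks[label] = max(self.peaks.get(label, 0), current)

    @contextmanager
    def scope(self, label: str) -> Iterator[None]:
        """Track a block of code; samples on entry and on exit."""
        self._open.append(label)
        self.sample()
        try:
            yield
        finally:
            self.sample()                 # runs even if the block raised
            self._open.pop()              # scopes nest, so the last one opened closes first

    def track(self, fn: Callable) -> Callable:
        """Decorator: wrap every call of fn in a scope named after the function."""
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            with self.scope(fn.__name__):
                return fn(*args, **kwargs)
        return wrapper


if __name__ == "__main__":
    usage = [0]                           # fake allocator counter standing in for real memory
    tracker = PeakTracker(lambda: usage[0])

    @tracker.track
    def step():
        usage[0] = 512                    # pretend a large intermediate is live
        tracker.sample()
        usage[0] = 128                    # intermediate freed before returning
        return "done"

    step()
    print(tracker.peaks)                  # {'step': 512}
```

Because this sketch depends on explicit samples, a real profiler would more likely read allocator counters directly, such as the `peak_bytes_in_use` entry that `memory_stats()` can report on JAX GPU and TPU devices.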



