GyroJAX
GyroJAX is a GPU-accelerated gyrokinetic particle-in-cell (PIC) code for plasma turbulence, written entirely in JAX. It simulates ion temperature gradient (ITG), kinetic ballooning mode (KBM), and trapped electron mode (TEM) instabilities in tokamak flux-tube geometry. The code is fully JIT-compiled, benchmarked against GENE and GX, and intended for research-grade plasma physics studies.
Key Features
- δf PIC gyrokinetic simulation in field-aligned flux-tube geometry (ψ, θ, α)
- Electrostatic and electromagnetic (β > 0) modes via FFT-based Poisson and Ampère field solvers
- Drift-kinetic electron model for TEM instability simulation
- Fused RK4 integrator that advances particle positions and δf weights together, giving a 3.78× GPU speedup over split integration
- Physics fixes: symmetric gyroaveraging scatter and a radially resolved g^αα metric coefficient in the Poisson solve
- Stability controls: absorbing wall boundary condition, zonal-preserving weight spread, and pullback transformation
- float32 throughout, which keeps a 500k-particle run comfortably within 8 GB of VRAM
- Multi-GPU capable via JAX `pmap`
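The field-solver bullet relies on FFTs. The core idea can be sketched as follows: in Fourier space the Laplacian becomes multiplication by −k², so the solve reduces to a pointwise division. This is a minimal sketch under a simplifying assumption. A plain Laplacian on a doubly periodic 2D box stands in for the gyrokinetic field equation, and the function `solve_poisson_periodic` is hypothetical. GyroJAX's actual solve also involves the polarization term and flux-tube metric coefficients such as g^αα, which are omitted here.

```python
import jax.numpy as jnp


def solve_poisson_periodic(rho, Lx, Ly):
    """Solve -(d^2/dx^2 + d^2/dy^2) phi = rho on a doubly periodic grid via FFT."""
    nx, ny = rho.shape
    # Angular wavenumbers in the same order as the FFT output.
    kx = 2 * jnp.pi * jnp.fft.fftfreq(nx, d=Lx / nx)
    ky = 2 * jnp.pi * jnp.fft.fftfreq(ny, d=Ly / ny)
    KX, KY = jnp.meshgrid(kx, ky, indexing="ij")
    k2 = KX**2 + KY**2
    rho_k = jnp.fft.fft2(rho)
    # -laplacian -> multiplication by k^2; the k = 0 mode is undetermined,
    # so it is set to zero (zero-mean potential).
    k2_safe = jnp.where(k2 == 0, 1.0, k2)
    phi_k = jnp.where(k2 == 0, 0.0, rho_k / k2_safe)
    return jnp.real(jnp.fft.ifft2(phi_k))


# Usage: rho = sin(x) on [0, 2*pi) has the exact solution phi = sin(x).
n = 32
L = 2 * jnp.pi
x = jnp.arange(n) * L / n
X, Y = jnp.meshgrid(x, x, indexing="ij")
phi = solve_poisson_periodic(jnp.sin(X), L, L)
```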
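The fused RK4 bullet contrasts integrating particle motion and δf weights together with integrating them separately. Here is a minimal sketch of the fused idea, with both quantities stored in one pytree and every RK4 stage evaluating a single right-hand side for both. The toy dynamics in `rhs` (constant drift, weights decaying as dw/dt = −w) and the names `rhs`, `axpy`, and `rk4_step` are hypothetical. This is not GyroJAX's integrator, and it does not by itself account for the measured 3.78× speedup.

```python
import jax
import jax.numpy as jnp


def rhs(state, t):
    """Hypothetical toy dynamics, not GyroJAX's equations of motion.

    Positions drift at a fixed velocity and delta-f weights decay as
    dw/dt = -w. One call returns both derivatives, so each RK4 stage
    advances particles and weights together.
    """
    return {"x": state["v"], "v": jnp.zeros_like(state["v"]), "w": -state["w"]}


def axpy(a, y, k):
    """Return y + a * k leaf by leaf over the particle pytree."""
    return jax.tree_util.tree_map(lambda yi, ki: yi + a * ki, y, k)


@jax.jit
def rk4_step(state, t, dt):
    """One classical RK4 step over positions and weights in one jitted function."""
    k1 = rhs(state, t)
    k2 = rhs(axpy(dt / 2, state, k1), t + dt / 2)
    k3 = rhs(axpy(dt / 2, state, k2), t + dt / 2)
    k4 = rhs(axpy(dt, state, k3), t + dt)
    return jax.tree_util.tree_map(
        lambda y, a, b, c, d: y + dt / 6 * (a + 2 * b + 2 * c + d),
        state, k1, k2, k3, k4,
    )


# Usage: four particles with unit velocity and unit weight, one step of dt = 0.1.
state = {"x": jnp.zeros(4), "v": jnp.ones(4), "w": jnp.ones(4)}
new_state = rk4_step(state, 0.0, 0.1)
```

Because XLA compiles the whole step as one program, intermediate stage values need not round-trip through separate Python-level calls. That is one plausible source of savings from fusion.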
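One plausible reading of "symmetric gyroaveraging scatter" is that charge deposition uses exactly the same gyro-ring points and weights as the field gather, which makes the scatter the transpose of the gather. The sketch below assumes a four-point ring average (a common choice in gyrokinetic PIC) with nearest-grid-point interpolation on a periodic 2D grid. All function names are hypothetical, and GyroJAX's actual stencil and interpolation in (ψ, θ, α) coordinates may differ.

```python
import jax.numpy as jnp

# Unit offsets to the four gyro-ring points at 0, 90, 180 and 270 degrees.
RING_X = jnp.array([1.0, 0.0, -1.0, 0.0])
RING_Y = jnp.array([0.0, 1.0, 0.0, -1.0])


def ring_cells(x, y, rho, shape, Lx, Ly):
    """Nearest-grid-point cell indices of each particle's 4 ring points (periodic)."""
    nx, ny = shape
    px = x[:, None] + rho[:, None] * RING_X
    py = y[:, None] + rho[:, None] * RING_Y
    i = jnp.round(px / (Lx / nx)).astype(jnp.int32) % nx
    j = jnp.round(py / (Ly / ny)).astype(jnp.int32) % ny
    return i, j


def gyro_gather(phi, x, y, rho, Lx, Ly):
    """Gyroaveraged field at each particle: mean of phi over its 4 ring points."""
    i, j = ring_cells(x, y, rho, phi.shape, Lx, Ly)
    return phi[i, j].mean(axis=1)


def gyro_scatter(w, x, y, rho, shape, Lx, Ly):
    """Deposit weights using the same ring points and 1/4 factors as the gather."""
    i, j = ring_cells(x, y, rho, shape, Lx, Ly)
    contrib = jnp.broadcast_to(w[:, None] / 4.0, i.shape)
    return jnp.zeros(shape).at[i, j].add(contrib)


# Usage: symmetry means <scatter(w), phi> equals <w, gather(phi)>.
shape, Lx, Ly = (16, 16), 8.0, 8.0
phi_grid = jnp.sin(jnp.arange(256.0).reshape(shape) * 0.37)
x = jnp.linspace(0.0, 7.5, 10)
y = jnp.linspace(7.5, 0.0, 10)
rho = jnp.full(10, 1.3)
w = jnp.linspace(-1.0, 1.0, 10)
deposit_dot = jnp.sum(gyro_scatter(w, x, y, rho, shape, Lx, Ly) * phi_grid)
gather_dot = jnp.sum(w * gyro_gather(phi_grid, x, y, rho, Lx, Ly))
```

When the two operators are exact transposes, the work done by the field on the particles matches the field energy change implied by their deposited charge. This is the usual motivation for keeping gather and scatter symmetric.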
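The `pmap` bullet refers to data-parallel execution across devices. A generic PIC pattern, sketched here as an assumption rather than GyroJAX's documented decomposition, splits particles across devices, deposits each shard onto a local grid, and sums the partial grids with `jax.lax.psum`. The function `deposit`, the 16-cell grid, and nearest-grid-point deposition are illustrative choices. On a single-device machine the same code runs with one shard.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.pmap, axis_name="devices")
def deposit(cell_idx, weights):
    """Deposit this device's particles onto a 16-cell grid, then sum over devices."""
    local = jnp.zeros(16).at[cell_idx].add(weights)
    # After psum, every device holds the full (global) density.
    return jax.lax.psum(local, axis_name="devices")


# Usage: the leading axis of each input is the device axis.
n_dev = jax.local_device_count()
cell_idx = (jnp.arange(n_dev * 64) % 16).reshape(n_dev, 64)
weights = jnp.ones((n_dev, 64))
density = deposit(cell_idx, weights)  # shape (n_dev, 16), identical rows
```

Newer JAX releases also offer sharding-based alternatives to `pmap`. The reduce-after-local-deposit structure stays the same either way.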
Performance
| Configuration | Result |
|---|---|
| 100k–500k particles (RTX 3070 Ti Laptop, 8 GB) | ~27 steps/sec |
| Fused RK4 vs. split integration | 3.78× speedup |
| Memory (500k particles, float32) | < 8 GB VRAM |
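Steps/sec figures in JAX depend on how the timing is done. The first call includes JIT compilation, and dispatch is asynchronous, so a timer must wait for results before stopping. Below is a minimal, generic timing helper. `steps_per_second` and `toy_step` are hypothetical, the toy step is not a GyroJAX step, and this section does not state how the table's numbers were measured.

```python
import time

import jax
import jax.numpy as jnp


def steps_per_second(step_fn, state, n_steps=50):
    """Steady-state throughput of a jitted step function, in steps per second."""
    # Warm-up call triggers JIT compilation so compile time is not counted.
    state = jax.block_until_ready(step_fn(state))
    t0 = time.perf_counter()
    for _ in range(n_steps):
        state = step_fn(state)
    # JAX dispatches asynchronously; wait for the device before stopping the clock.
    jax.block_until_ready(state)
    return n_steps / (time.perf_counter() - t0)


@jax.jit
def toy_step(x):
    """Hypothetical stand-in for one PIC step: a cheap elementwise update."""
    return x + 0.01 * jnp.sin(x)


rate = steps_per_second(toy_step, jnp.zeros(500_000))
print(f"{rate:.1f} steps/sec")
```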
Quick Install
```bash
git clone https://github.com/runlaiagent/GyroJAX.git
cd GyroJAX
pip install -e .
```
Then run a basic ITG simulation:
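```python
import jax
from gyrojax.simulation_fa import SimConfigFA, run_simulation_fa

# 50k particles, 200 steps, R0/L_T = 6.9 (the Cyclone Base Case gradient)
cfg = SimConfigFA(N_particles=50_000, n_steps=200, R0_over_LT=6.9)
# Unpack the diagnostics history, final state, potential phi, and geometry
diags, state, phi, geom = run_simulation_fa(cfg, key=jax.random.PRNGKey(42))
print(f"Final phi_max: {diags[-1].phi_max:.3e}")
```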
Validated Benchmarks
GyroJAX has been validated against published reference results:
- Dimits shift threshold R₀/L_T ≈ 6.0 (Dimits et al. 2000)
- Cyclone Base Case (CBC) linear growth rate γ = 0.172 v_ti/R₀ (1.9% error relative to GENE/GX)
- KBM β_crit ≈ 0.010–0.012 (Pueschel et al. 2008)
- TEM: γ > 0 at R₀/L_Te = 9.0, stable at R₀/L_Te = 5.0
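Linear growth rates like those above are typically measured by fitting exponential growth of a field amplitude. The usual approach is a least-squares fit of log(amplitude) against time over the linear phase. Below is a minimal sketch on a synthetic signal. How GyroJAX itself extracts γ is not specified here, `fit_growth_rate` is a hypothetical helper, and collecting time and `phi_max` arrays from `diags` depends on diagnostic fields not documented in this section.

```python
import jax.numpy as jnp


def fit_growth_rate(t, amplitude):
    """Fit log(amplitude) = gamma * t + c by least squares and return gamma.

    t and amplitude should already be restricted to the linear-growth window,
    after initial transients and before nonlinear saturation.
    """
    slope, _intercept = jnp.polyfit(t, jnp.log(amplitude), 1)
    return slope


# Usage with a synthetic signal growing at gamma = 0.2 (not simulation data).
t = jnp.linspace(0.0, 50.0, 101)
amplitude = 1e-6 * jnp.exp(0.2 * t)
gamma = fit_growth_rate(t, amplitude)
```

If t is measured in units of R₀/v_ti, the fitted γ comes out in v_ti/R₀, the units quoted for the CBC benchmark.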