Quick Start
Requirements
- Python 3.11+
- JAX 0.4.x+ (with CUDA support for GPU)
- NVIDIA GPU (8GB+ VRAM recommended for 500k particles)
- NumPy, SciPy, Matplotlib
Installation
```bash
git clone https://github.com/runlaiagent/GyroJAX.git
cd GyroJAX
pip install -e .
```
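For GPU support, install JAX with CUDA 12 wheels. Recent JAX releases name the extra `cuda12` (`pip install -U "jax[cuda12]"`); older releases used the `cuda12_pip` extra with a find-links URL:
```bash
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```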
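Before launching a large run, it is worth confirming that JAX actually picked up the GPU. This sketch uses only standard JAX calls (`jax.default_backend()` and `jax.devices()`); nothing GyroJAX-specific is assumed.

```python
import jax

# "gpu" when a CUDA (or ROCm) backend loaded successfully, otherwise "cpu"
backend = jax.default_backend()
# Devices available on the default backend
devices = jax.devices()

print(f"Default backend: {backend}")
for d in devices:
    print(f"  {d.platform}: {d}")

if backend == "cpu":
    print("Warning: no GPU backend found; simulations will run on the CPU.")
```

If this prints `cpu` on a machine with an NVIDIA GPU, the CUDA-enabled JAX wheel is usually not the one installed in the active environment.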
First Run: ITG Simulation (CBC Parameters)
The simplest simulation uses the Cyclone Base Case (CBC) parameters for ITG turbulence:
```python
import jax
from gyrojax.simulation_fa import SimConfigFA, run_simulation_fa

cfg = SimConfigFA(
    Npsi=16, Ntheta=32, Nalpha=32,
    N_particles=50_000,
    n_steps=200,
    dt=0.05,
    R0_over_LT=6.9,  # CBC parameters
)
diags, state, phi, geom = run_simulation_fa(cfg, key=jax.random.PRNGKey(42))
print(f"Final phi_max: {diags[-1].phi_max:.3e}")
```
This runs 200 steps of dt = 0.05 at the CBC operating point (R₀/LT = 6.9, above both the linear ITG threshold and the Dimits-shifted nonlinear threshold of ~6.0), so you should observe growing ITG fluctuations.
Plotting Results
```python
import matplotlib.pyplot as plt
import jax.numpy as jnp

phi_rms = jnp.array([d.phi_rms for d in diags])
plt.semilogy(phi_rms)
plt.xlabel("Step")
plt.ylabel("φ_rms")
plt.title("ITG linear growth (CBC)")
plt.show()
```
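To turn the φ_rms curve into a number, fit a straight line to ln φ_rms over the linear phase; the slope is the growth rate γ. This is a minimal NumPy sketch, assuming one diagnostic entry per time step (so t = step × dt, in the code's normalized time units). `fit_growth_rate` is a hypothetical helper, not part of GyroJAX.

```python
import numpy as np


def fit_growth_rate(phi_rms, dt, start=0, stop=None):
    """Least-squares fit of ln(phi_rms) = gamma * t + c over steps [start, stop).

    Assumes (hypothetically) that phi_rms holds one sample per time step of size dt.
    """
    phi = np.asarray(phi_rms, dtype=float)[start:stop]
    t = dt * np.arange(start, start + phi.size)
    gamma, _intercept = np.polyfit(t, np.log(phi), 1)  # slope of the log-linear fit
    return gamma


# Usage with the CBC run (restrict the fit to the visibly exponential window):
# gamma = fit_growth_rate([d.phi_rms for d in diags], dt=0.05, start=50, stop=150)
```

Choose `start` and `stop` by eye from the semilog plot: the early steps are dominated by initial noise, and the fit is only meaningful where the curve is a straight line.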
KBM (Electromagnetic) Simulation
To enable electromagnetic effects with finite plasma β:
```python
import jax
from gyrojax.simulation_fa import SimConfigFA, run_simulation_fa

cfg = SimConfigFA(
    Npsi=16, Ntheta=32, Nalpha=32,
    N_particles=100_000,
    n_steps=300,
    dt=0.05,
    R0_over_LT=6.9,
    beta=0.01,       # plasma beta (electromagnetic coupling)
    fused_rk4=True,  # fused RK4 integrator (3.78× speedup)
)
diags, state, phi, geom = run_simulation_fa(cfg, key=jax.random.PRNGKey(0))
```
β_crit for KBM
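KBM onset occurs around β ≈ 0.010–0.012 for CBC-like parameters. Below β_crit, ITG dominates; above it, KBM replaces ITG as the fastest-growing mode. The example above uses beta=0.01, at the low end of this range, so increase beta to be clearly in the KBM regime.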
TEM Simulation (Drift-Kinetic Electrons)
To simulate trapped electron modes, enable the drift-kinetic electron model:
```python
import jax
from gyrojax.simulation_fa import SimConfigFA, run_simulation_fa

cfg = SimConfigFA(
    Npsi=16, Ntheta=32, Nalpha=32,
    N_particles=100_000,
    n_steps=400,
    dt=0.02,
    R0_over_LT=3.0,    # low ion gradient (below ITG threshold)
    R0_over_LTe=9.0,   # strong electron temperature gradient → TEM
    R0_over_Ln=2.2,
    electron_model='drift_kinetic',
    subcycles_e=10,    # subcycle electron push for stability
)
diags, state, phi, geom = run_simulation_fa(cfg, key=jax.random.PRNGKey(1))
```
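Electron subcycling is needed because electrons move much faster than ions (thermal-speed ratio ≈ √(m_i/m_e) at equal temperature, about 43 for hydrogen and 60 for deuterium), so an explicit electron push taken at the ion step can exceed its stability limit. The toy below is not GyroJAX's electron pusher: it assumes a single fast linear mode dx/dt = -k x with an illustrative rate k, advanced by forward Euler, and assumes that `subcycles_e` splits each outer step dt into that many substeps of dt/subcycles_e.

```python
def euler_step(x, k, h):
    """One forward-Euler step for dx/dt = -k x (toy stand-in for a fast electron mode)."""
    return x * (1.0 - k * h)


def advance(x, k, dt, n_steps, subcycles=1):
    """Advance n_steps outer steps of size dt, each split into `subcycles` substeps."""
    h = dt / subcycles
    for _ in range(n_steps):
        for _ in range(subcycles):
            x = euler_step(x, k, h)
    return x


k = 150.0  # illustrative fast rate (hypothetical, not a GyroJAX quantity)
dt = 0.02  # outer time step from the TEM example

# Forward Euler is stable only if k*h <= 2.
x_plain = advance(1.0, k, dt, n_steps=20)              # k*dt = 3.0: amplitude grows
x_sub = advance(1.0, k, dt, n_steps=20, subcycles=10)  # k*dt/10 = 0.3: amplitude decays
print(f"no subcycling: {x_plain:.3e}, 10 subcycles: {x_sub:.3e}")
```

Subcycling only shortens the step for the fast species, so the ion push and field solve still run once per outer step.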
TEM vs ITG
In this setup, TEM is driven by the electron temperature gradient (R₀/LTe, set via R0_over_LTe). Use R0_over_LTe ≥ 9.0 to see positive growth rates; at R₀/LTe = 5.0 the mode is stable.
Noise Control Options
For long nonlinear runs, consider enabling noise reduction:
```python
cfg = SimConfigFA(
    ...  # other parameters as in the examples above
    use_pullback=True,             # periodic f₀ pullback (controls weight growth)
    pullback_interval=50,          # pullback every 50 steps
    use_weight_spread=True,        # GTC-style weight spreading
    weight_spread_interval=10,     # spread weights every 10 steps
    zonal_preserving_spread=True,
    semi_implicit_weights=True,    # Crank-Nicolson weight update (unconditionally stable)
)
```
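The `semi_implicit_weights` option refers to a Crank-Nicolson (CN) weight update. To show why CN is unconditionally stable, this sketch assumes, for illustration only, a linear weight equation dw/dt = -λ w with λ ≥ 0; GyroJAX's actual weight equation contains drive terms that are not modeled here, and both function names are hypothetical.

```python
def explicit_euler(w, lam, dt):
    """Forward Euler: w_{n+1} = (1 - lam*dt) * w_n, stable only if lam*dt <= 2."""
    return (1.0 - lam * dt) * w


def crank_nicolson(w, lam, dt):
    """Crank-Nicolson: average the right-hand side at t_n and t_{n+1}, then solve.

    w_{n+1} - w_n = -(lam*dt/2) * (w_n + w_{n+1})
    =>  w_{n+1} = (1 - lam*dt/2) / (1 + lam*dt/2) * w_n
    The amplification factor has magnitude <= 1 for every lam*dt >= 0.
    """
    half = 0.5 * lam * dt
    return (1.0 - half) / (1.0 + half) * w


lam, dt = 100.0, 0.05  # deliberately stiff: lam*dt = 5
w_fe = w_cn = 1.0
for _ in range(50):
    w_fe = explicit_euler(w_fe, lam, dt)
    w_cn = crank_nicolson(w_cn, lam, dt)
print(f"forward Euler: {w_fe:.3e}, Crank-Nicolson: {w_cn:.3e}")
```

CN is stable but not strongly damping: as λΔt grows, its amplification factor approaches -1, so very stiff components oscillate in sign and decay slowly instead of being removed in one step.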
Running Benchmarks
Run the full test suite:
```bash
pytest tests/ -q
```
Or run an individual benchmark script:
```bash
python benchmarks/dimits_shift.py
python benchmarks/rosenbluth_hinton.py
```
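The Rosenbluth-Hinton test concerns the collisionless zonal-flow residual. A standard reference value is the large-aspect-ratio estimate φ(∞)/φ(0) = 1/(1 + 1.6 q²/√ε), with ε = r/R₀. The snippet evaluates it for the usual CBC values q = 1.4 and ε = 0.18; whether `rosenbluth_hinton.py` compares against exactly this formula and these values is an assumption, and `rh_residual` is a hypothetical helper.

```python
import math


def rh_residual(q, eps):
    """Rosenbluth-Hinton zonal-flow residual in the large-aspect-ratio limit:
    phi(inf)/phi(0) = 1 / (1 + 1.6 q^2 / sqrt(eps))."""
    return 1.0 / (1.0 + 1.6 * q**2 / math.sqrt(eps))


# Standard Cyclone Base Case values (assumed here): safety factor q = 1.4, eps = r/R0 = 0.18
print(f"RH residual (CBC): {rh_residual(1.4, 0.18):.3f}")
```

A residual of roughly 0.12 means about 12% of an initial zonal potential should survive after the geodesic acoustic oscillations damp out.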