Contributing¶
Contributions to GyroJAX are welcome! This guide covers how to set up the development environment, run tests, and submit pull requests.
Dev Setup¶
git clone https://github.com/runlaiagent/GyroJAX.git
cd GyroJAX
pip install -e ".[dev]"
Required dependencies:
- JAX 0.4.x+ (with CUDA for GPU)
- NumPy, SciPy, Matplotlib
- pytest (for tests)
- mkdocs-material (for docs)
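To confirm that the editable install pulled in everything listed above, a small check along these lines may help. It is only a sketch: the module names are the usual import names for these packages (`material` is assumed to be the module that mkdocs-material installs), and the helper `missing_dependencies` is hypothetical, not part of GyroJAX.

```python
from importlib.util import find_spec

# Import names for the dependencies listed above. "material" is assumed to be
# the top-level module installed by mkdocs-material.
REQUIRED = ["jax", "numpy", "scipy", "matplotlib", "pytest", "material"]


def missing_dependencies(names=REQUIRED):
    """Return the names that cannot be found in the current environment."""
    return [name for name in names if find_spec(name) is None]


missing = missing_dependencies()
print("All dependencies found" if not missing else f"Missing: {', '.join(missing)}")
```

If anything is reported missing, re-running the pip install command above is the first thing to try.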
Running Tests¶
pytest tests/ -q
Run specific test files:
pytest tests/test_simulation_fa.py -v
pytest tests/test_poisson.py -v
Run with GPU (default if available):
JAX_PLATFORM_NAME=gpu pytest tests/ -q
Run CPU-only (for CI/testing without GPU):
JAX_PLATFORM_NAME=cpu pytest tests/ -q
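The same platform choice can be made from Python, for instance in a test setup file. The sketch below assumes a hypothetical helper named `select_platform`. It is not something the repository is known to provide. The variable should be set before JAX is imported, since JAX reads it at import time.

```python
import os


def select_platform(default: str = "cpu") -> str:
    """Pick the JAX platform, keeping any value already set in the environment."""
    # setdefault only writes the variable when it is not already defined,
    # so a value passed on the command line still wins.
    return os.environ.setdefault("JAX_PLATFORM_NAME", default)


# Call this before the first `import jax`.
platform = select_platform()
print(f"JAX_PLATFORM_NAME={platform}")
```

Using `setdefault` rather than a plain assignment keeps command-line overrides such as the GPU invocation above working.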
Adding Benchmarks¶
Benchmarks live in benchmarks/. A benchmark should:
- Define a SimConfigFA with the reference parameters
- Run the simulation using run_simulation_fa
- Extract the relevant diagnostic (growth rate, threshold, etc.)
- Compare against the published reference value
- Print a PASS/FAIL summary
Example structure:
# benchmarks/my_benchmark.py
import jax
import jax.numpy as jnp

from gyrojax.simulation_fa import SimConfigFA, run_simulation_fa


def run():
    cfg = SimConfigFA(
        N_particles=50_000,
        n_steps=300,
        R0_over_LT=6.9,
        # ...
    )
    diags, state, phi, geom = run_simulation_fa(cfg, key=jax.random.PRNGKey(0))

    # Extract growth rate from log(phi_rms)
    phi_rms = jnp.array([d.phi_rms for d in diags])
    # ... fit growth rate and assign it to gamma ...

    ref_value = 0.169  # GENE/GX reference
    tol = 0.05  # relative tolerance (5%)
    assert abs(gamma - ref_value) / ref_value < tol, f"Growth rate {gamma:.3f} outside tolerance"
    print(f"PASS: γ = {gamma:.3f} (ref: {ref_value})")


if __name__ == "__main__":
    run()
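The fit step elided in the example above could look roughly like the following. This is a minimal sketch, not the project's own fitting routine: `fit_growth_rate`, `check`, the sample spacing `dt`, and the `start_frac` fit window are all assumptions made for illustration, and the data here is synthetic. It also returns a PASS/FAIL message instead of raising, so a failing benchmark still prints a summary line.

```python
import numpy as np


def fit_growth_rate(phi_rms, dt, start_frac=0.5):
    """Estimate the linear growth rate from a phi_rms time series.

    Args:
        phi_rms: 1D array of RMS potential values, one per saved step.
        dt: Time between consecutive samples (assumed known from the config).
        start_frac: Fraction of the series skipped before fitting, to drop
            the initial transient.

    Returns:
        Slope of log(phi_rms) versus time over the fit window.
    """
    phi_rms = np.asarray(phi_rms, dtype=float)
    t = dt * np.arange(phi_rms.size)
    i0 = int(start_frac * phi_rms.size)
    # If phi_rms ~ exp(gamma * t), log(phi_rms) is linear in t with slope gamma.
    slope, _intercept = np.polyfit(t[i0:], np.log(phi_rms[i0:]), 1)
    return float(slope)


def check(gamma, ref_value, tol):
    """Return (passed, message) so that a FAIL line can be printed."""
    rel_err = abs(gamma - ref_value) / abs(ref_value)
    passed = rel_err < tol
    status = "PASS" if passed else "FAIL"
    return passed, f"{status}: gamma = {gamma:.3f} (ref: {ref_value}, rel. err: {rel_err:.1%})"


# Synthetic check: a pure exponential with a known growth rate.
t = 0.1 * np.arange(300)
gamma = fit_growth_rate(1e-6 * np.exp(0.169 * t), dt=0.1)
passed, message = check(gamma, ref_value=0.169, tol=0.05)
print(message)
```

In a real benchmark the fit window matters: it should cover the linear growth phase only, after the initial transient and before nonlinear saturation.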
Code Style¶
- Functional style: prefer pure functions, avoid mutable state
- JAX idioms: use jax.lax.scan, jax.vmap, and jax.jit where appropriate
- Type hints: use for all public functions
- Docstrings: Google style, include parameter descriptions
- No raw loops over particle arrays: use JAX vectorized ops
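The style points above can be illustrated with a toy example. The sketch below is not GyroJAX code: `push_one` and `advance` are hypothetical names, and the "push" is plain free streaming. It shows a pure per-particle function, vectorized with `jax.vmap`, compiled with `jax.jit`, and stepped in time with `jax.lax.scan` in place of Python loops.

```python
import jax
import jax.numpy as jnp


def push_one(x: jax.Array, v: jax.Array, dt: float) -> jax.Array:
    """Advance a single particle by one step (toy free-streaming push).

    Args:
        x: Particle position.
        v: Particle velocity.
        dt: Time step.

    Returns:
        The updated position.
    """
    return x + v * dt


# Vectorize over the particle axis of x and v; dt is shared by all particles.
push_all = jax.jit(jax.vmap(push_one, in_axes=(0, 0, None)))


def advance(x0: jax.Array, v: jax.Array, dt: float, n_steps: int):
    """Run n_steps pushes with lax.scan, recording the mean position per step.

    Args:
        x0: Initial positions, shape (n_particles,).
        v: Velocities, shape (n_particles,).
        dt: Time step.
        n_steps: Number of steps (a Python int, fixed at trace time).

    Returns:
        A tuple (final positions, mean position at each step).
    """

    def step(x, _):
        x_new = push_all(x, v, dt)
        return x_new, jnp.mean(x_new)

    return jax.lax.scan(step, x0, None, length=n_steps)


x_final, mean_history = advance(jnp.zeros(4), jnp.arange(4.0), dt=0.1, n_steps=10)
```

No state is mutated: each step returns new arrays, and the scan carries them forward, which keeps the whole time loop traceable by JAX.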
Physics Changes¶
When changing physics (push equations, Poisson solver, gyroaveraging):
- Verify against at least one benchmark in benchmarks/
- Add a unit test in tests/ checking the modified component
- Update docs/physics.md if the model changes
- Note the change in the PR description with the relevant equations
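As an illustration of a unit test that checks one component against a known analytic result, here is a self-contained sketch. `solve_poisson_1d` is a toy stand-in written for this example (a periodic 1D FFT solver), not GyroJAX's Poisson solver. A real test would import the modified function from the package instead.

```python
import numpy as np


def solve_poisson_1d(rho, length):
    """Solve -d2(phi)/dx2 = rho on a periodic 1D grid with an FFT (toy stand-in)."""
    n = rho.size
    k = 2.0 * np.pi * np.fft.fftfreq(n, d=length / n)
    k2 = k**2
    rho_hat = np.fft.fft(rho)
    # The k = 0 mode is undetermined; set it to zero (zero-mean potential).
    safe_k2 = np.where(k2 == 0.0, 1.0, k2)
    phi_hat = np.where(k2 == 0.0, 0.0, rho_hat / safe_k2)
    return np.fft.ifft(phi_hat).real


def test_poisson_recovers_single_mode():
    """For rho = sin(x), the zero-mean solution is phi = sin(x)."""
    n, length = 64, 2.0 * np.pi
    x = np.arange(n) * length / n
    phi = solve_poisson_1d(np.sin(x), length)
    np.testing.assert_allclose(phi, np.sin(x), atol=1e-12)
```

Saved under tests/, a function named `test_...` like this one is picked up by the pytest commands shown earlier.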
PR Guidelines¶
- Fork the repository and create a feature branch: git checkout -b feature/my-feature
- Write tests for any new functionality
- Run the full test suite before submitting: pytest tests/ -q
- Keep PRs focused: one feature or fix per PR
- Describe the physics in the PR description if adding/changing a physics model
- Reference issues if the PR closes one: Closes #42
PRs are reviewed for:
- Correctness (physics and numerics)
- JAX/JIT compatibility (no Python-side loops over dynamic sizes)
- Test coverage
- Documentation
Building Docs Locally¶
pip install mkdocs-material
mkdocs serve
Then open http://localhost:8000 in your browser.
Contact¶
Open an issue on GitHub for questions, bugs, or feature requests.