jax-development
JAX Development
Use this skill for substantial JAX work. The agent should behave like a strong JAX reviewer and performance engineer: preserve functional semantics, choose the right transformations, explain the trace/compile/runtime split clearly, and avoid making performance claims that were not measured.
This version is designed to be unusually agent-friendly. It does not just bundle references; it gives the agent an operating workflow, decision matrices, a code-review rubric, and scripts that help verify environment, lowering, recompilation risk, and benchmark claims.
Core promise
When this skill is active, the default standard is:
- produce runnable JAX code, not generic advice
- explain why the change works in JAX terms
- call out likely sharp bits even if the user did not ask
- verify claims with the bundled scripts when possible
- separate compile-time, run-time, transfer, and sharding issues instead of mixing them together
When this skill should own the task
Use this skill when the difficult part of the request is any of the following:
- translating NumPy, SciPy, TensorFlow, or PyTorch code into idiomatic JAX
- fixing tracer, control-flow, PRNG, shape, dtype, or side-effect bugs
- choosing between jit, vmap, scan, fori_loop, while_loop, cond, grad, jacrev, jacfwd, remat, shard_map, or export
- removing recompiles, host-device round trips, Python overhead, or dishonest benchmarking
- reasoning about jax.Array, meshes, PartitionSpec, NamedSharding, explicit sharding, pmap migration, multi-host semantics, or collectives
- using jax.debug.print, checkify, make_jaxpr, lowering, compiler IR, profiler traces, or memory profiling
- using custom derivatives, export, AOT lowering, custom partitioning, Pallas, or the JAX source tree
Compose this skill with framework-specific skills when needed, but let this one own the JAX-specific reasoning.
Do not over-apply the skill
Do not force JAX when the real problem is one of these instead:
- pure NumPy optimisation where JAX is explicitly out of scope
- generic CUDA, Triton, NCCL, or driver debugging with no meaningful JAX component
- framework-only design questions whose hard part is not JAX
- irregular dynamic object-heavy Python where the right answer is probably to keep the hot path outside JAX
When in doubt, ask: “Is the root of the problem tracing, transformations, array semantics, compilation, sharding, or the JAX runtime?” If yes, use this skill.
First-response workflow
1. Classify the task
Put the request into one or more lanes immediately:
- code design or porting
- debugging or correctness
- performance or compilation
- sharding or distributed execution
- advanced extension points
- JAX repo navigation or source-level questions
Then open the matching reference file:
- references/EXPERT-WORKFLOW.md for the overall workflow
- references/MENTAL-MODEL.md for tracing and staging semantics
- references/TRANSFORM-DECISION-MATRIX.md for choosing primitives
- references/PORTING-PATTERNS.md for NumPy or PyTorch rewrites
- references/CODE-REVIEW-RUBRIC.md for self-review before replying
- references/DEBUGGING-TRIAGE.md for error diagnosis
- references/PERFORMANCE-PLAYBOOK.md for speed, memory, and compile-time work
- references/SHARDING-PLAYBOOK.md for distributed and multi-device design
- references/ADVANCED-EXTENSIONS.md for custom autodiff, export, Pallas, FFI, and internals
- references/REPO-MAP.md for local source-tree navigation
- references/SOURCES.md for provenance and maintenance notes
2. Inspect before guessing
If the problem could be environment-, backend-, or project-specific, inspect first.
Environment:
python3 scripts/jax_env_report.py --format json
Static project scan:
python3 scripts/jax_project_scan.py PATH --format json
Benchmark a callable honestly:
python3 scripts/jax_benchmark_harness.py --help
Inspect jaxpr, lowering, and IR:
python3 scripts/jax_compile_probe.py --help
Check likely recompile behaviour across cases:
python3 scripts/jax_recompile_explorer.py --help
Search a local JAX checkout:
python3 scripts/jax_repo_locator.py --help
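When the bundled scripts are unavailable, the same first-pass facts can be gathered by hand. The sketch below is a rough equivalent built only on public JAX APIs (it is not the scripts' actual code); env_summary and f are hypothetical names chosen for illustration. It separates the three views the workflow cares about: the environment, the trace-time jaxpr, and the lowered and compiled module.

```python
import jax
import jax.numpy as jnp


def env_summary():
    # Version, default backend, and visible devices: the facts most likely
    # to explain backend-specific behaviour.
    return {
        "jax_version": jax.__version__,
        "backend": jax.default_backend(),
        "devices": [str(d) for d in jax.devices()],
        # float32 unless x64 mode is enabled, in which case float64.
        "default_float_dtype": str(jnp.asarray(1.0).dtype),
    }


def f(x, y):
    # Hypothetical function under investigation.
    return jnp.tanh(x @ y).sum()


x = jnp.ones((4, 3), dtype=jnp.float32)
y = jnp.ones((3, 2), dtype=jnp.float32)

# Trace-time view: the jaxpr JAX stages out for these shapes and dtypes.
jaxpr = jax.make_jaxpr(f)(x, y)

# Lowering view: compiler IR text for the same signature; nothing executes yet.
lowered = jax.jit(f).lower(x, y)
ir_text = lowered.as_text()

# Compile view: the backend executable produced from the lowered module.
compiled = lowered.compile()
```

Keeping the three stages separate makes it clear whether a problem appears while tracing, while lowering, or only at run time.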
3. Reduce to a minimal reproducer
Prefer the smallest function that still exhibits the behaviour. JAX problems get much easier once shapes, dtypes, batching axes, randomness, and transformation boundaries are explicit.
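A minimal sketch of what such a reproducer can look like, assuming the suspected bug is a Python branch on a traced value (step and fixed_step are hypothetical names). It pins the input shape and dtype, shows that eager and jit behave differently, and uses jax.eval_shape with jax.ShapeDtypeStruct to state the intended output signature without running any computation.

```python
import jax
import jax.numpy as jnp


def step(x):
    # Python branch on a traced value: the suspected bug.
    if x.sum() > 0:
        return x * 2
    return x


# Explicit shape and dtype for the smallest failing input.
x = jnp.arange(4, dtype=jnp.float32)

# 1. Eager works, because x.sum() is a concrete value here.
eager_out = step(x)

# 2. The same call under jit fails at trace time.
try:
    jax.jit(step)(x)
    jit_error = None
except jax.errors.ConcretizationTypeError as err:
    jit_error = type(err).__name__


def fixed_step(x):
    # Value-dependent choice expressed as array code instead of Python control flow.
    return jnp.where(x.sum() > 0, x * 2, x)


# 3. Abstract evaluation confirms the output signature without executing.
spec = jax.ShapeDtypeStruct((4,), jnp.float32)
out_spec = jax.eval_shape(fixed_step, spec)
```

Recent JAX versions raise TracerBoolConversionError here, which is a subclass of ConcretizationTypeError, so catching the parent covers both older and newer releases.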
4. Choose the least powerful mechanism that solves the problem
Default ordering:
- pure eager jax.numpy first
- then jit or value_and_grad
- then vmap or scan
- then explicit sharding
- then shard_map
- then custom derivative, export, custom partitioning, or Pallas
- then FFI or JAX internals
Escalate only with evidence.
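The first three rungs of that ladder can be illustrated on a single hypothetical least-squares loss (loss, w, x, and y are illustrative names). Each rung adds one transformation to a function that already works eagerly, so any change in behaviour can be attributed to the transformation just added.

```python
import jax
import jax.numpy as jnp


def loss(w, x, y):
    # Pure jax.numpy: established as correct eagerly before any transform.
    pred = x @ w
    return jnp.mean((pred - y) ** 2)


w = jnp.zeros((3,), dtype=jnp.float32)
x = jnp.ones((5, 3), dtype=jnp.float32)
y = jnp.ones((5,), dtype=jnp.float32)

# Rung 1: eager.
eager_loss = loss(w, x, y)

# Rung 2: jit plus value_and_grad, only once the eager version is trusted.
loss_and_grad = jax.jit(jax.value_and_grad(loss))
val, grad_w = loss_and_grad(w, x, y)

# Rung 3: vmap over a batch of parameter vectors instead of a Python loop;
# in_axes=None broadcasts x and y to every batch element.
ws = jnp.stack([w, w + 1.0])
batched_loss = jax.vmap(loss, in_axes=(0, None, None))(ws, x, y)
```

Sharding, shard_map, and custom rules would come after this only if profiling shows the simpler rungs are insufficient.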
5. End with a high-signal answer
Unless the user asked for something else, the reply should end with:
- diagnosis or design choice
- corrected code or patch
- why it works in JAX terms
- how to verify it
- remaining risks, backend caveats, or performance unknowns
Expert operating rules
- Treat JAX functions as pure. Inputs in, outputs out. Hidden mutation, global state, or implicit randomness are usually design bugs once transforms enter the picture.
- Make randomness explicit. Thread keys through the program, split once per consumer, and return updated keys when state continues.
- Keep the hot path in JAX space. Host conversion inside transformed code is almost always a bug or a sync point.
- Separate static and dynamic values. Shapes, dtypes, Python objects, and some configuration values influence tracing and compilation.
- Use structured control flow. If a branch or loop depends on array values, use JAX control-flow primitives instead of Python.
- Benchmark honestly. Warm up, block, and distinguish transfer cost, compile cost, and steady-state execution.
- Optimise after evidence. Use project scans, compile probes, profiler traces, or lowering inspection before proposing deep rewrites.
- Prefer current JAX idioms. Typed keys, jax.Array, and modern sharding APIs are the default unless the codebase is intentionally legacy.
- Think globally for sharding first. Start with global-view code and explicit placement before dropping to per-device manual code.
- Never bluff backend-specific behaviour. CPU, GPU, TPU, and multi-host runs differ materially. Say what was verified and what was inferred.
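Two of these rules, explicit randomness and the static/dynamic split, are sketched below. noisy_step, run, and num_layers are hypothetical names; jax.random.key creates a typed key and assumes a reasonably recent JAX release. The key is split once per consumer and returned, and the layer count is declared static because it changes the traced program.

```python
from functools import partial

import jax
import jax.numpy as jnp


def noisy_step(key, x):
    # Split once per consumer and hand the fresh key back to the caller.
    key, noise_key = jax.random.split(key)
    noise = jax.random.normal(noise_key, x.shape, dtype=x.dtype)
    return key, x + 0.1 * noise


# num_layers changes the traced program, so it is static:
# each distinct value triggers its own compilation.
@partial(jax.jit, static_argnames="num_layers")
def run(key, x, num_layers):
    # A Python loop over a static int is unrolled at trace time. That is fine
    # for small counts; large counts belong in lax.scan or lax.fori_loop.
    for _ in range(num_layers):
        key, x = noisy_step(key, x)
    return key, x


key = jax.random.key(0)  # typed PRNG key
x = jnp.zeros((3,), dtype=jnp.float32)
key, x1 = run(key, x, num_layers=2)
key, x2 = run(key, x, num_layers=2)  # same x, fresh key: different noise
```

Because the updated key is threaded back out, calling run twice never reuses randomness, and restarting from jax.random.key(0) reproduces the first result exactly.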
Default red flags to proactively check
Always scan for these, even if the user did not mention them:
- np.asarray, .item(), .tolist(), jax.device_get, or printing arrays in a hot path
- Python if, for, or while inside transformed code
- shape construction or indexing based on traced values
- global or reused PRNG keys
- repeated creation of jitted callables inside loops
- changing shapes, dtypes, or static arguments causing compile storms
- very large Python loops that should be scan or fori_loop
- pmap code that may be better expressed with modern sharding APIs
- unexplained precision assumptions or implicit x64 expectations
- replicated-versus-sharded confusion in distributed code
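Several of these red flags often appear together, for example a Python loop that calls .item() and branches on the result. A minimal sketch of the usual fix, assuming a hypothetical running sum of positive values: the loop becomes lax.scan, the branch becomes jnp.where, and nothing leaves the device inside the jitted function.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Hypothetical original, with a host sync per element and a Python branch:
#   total = 0.0
#   for v in xs:
#       if v.item() > 0:
#           total += v


def positive_running_sum(xs):
    def body(total, v):
        # jnp.where replaces the value-dependent Python if.
        total = total + jnp.where(v > 0, v, 0.0)
        return total, total  # (new carry, per-step output)

    init = jnp.zeros((), dtype=xs.dtype)
    final, partials = lax.scan(body, init, xs)
    return final, partials


xs = jnp.array([1.0, -2.0, 3.0, -4.0], dtype=jnp.float32)
final, partials = jax.jit(positive_running_sum)(xs)
```

jnp.where evaluates both sides, which is cheap here; when the two branches do substantially different or expensive work, lax.cond is the usual alternative.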
Available scripts
- scripts/jax_env_report.py: report versions, backend, devices, config, env vars, and an optional smoke test.
- scripts/jax_project_scan.py: AST-based scan for common JAX sharp bits and migration targets.
- scripts/jax_benchmark_harness.py: benchmark a callable with warm-up, blocking, optional jit, and optional donation.
- scripts/jax_compile_probe.py: inspect eval_shape, jaxpr, lowering, and compiler IR; optionally write artefacts to disk.
- scripts/jax_recompile_explorer.py: run several input cases through a jitted function and flag likely recompiles or signature drift.
- scripts/jax_repo_locator.py: search a local JAX checkout for relevant docs, tests, or source files by topic.
All scripts are non-interactive, support --help, and default to structured JSON output.
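For a quick check without the harness, the core of an honest timing can be written by hand. This is a sketch of the general pattern, not the harness's implementation; timed and matmul_tanh are hypothetical names. The host-to-device transfer happens before timing, the first call (tracing plus compilation) is reported separately, and every measurement blocks on the result so asynchronous dispatch cannot hide work.

```python
import time

import numpy as np
import jax
import jax.numpy as jnp


def timed(fn, *args, repeats=10):
    # First call includes tracing and compilation.
    t0 = time.perf_counter()
    jax.block_until_ready(fn(*args))
    first_call_s = time.perf_counter() - t0

    # Steady state: the compiled executable is cached, so this measures run time.
    times = []
    for _ in range(repeats):
        t0 = time.perf_counter()
        jax.block_until_ready(fn(*args))
        times.append(time.perf_counter() - t0)
    median_s = sorted(times)[len(times) // 2]
    return first_call_s, median_s


@jax.jit
def matmul_tanh(a, b):
    return jnp.tanh(a @ b)


# Host-to-device transfer happens here, outside the timed region.
a = jax.device_put(np.ones((256, 256), dtype=np.float32))
b = jax.device_put(np.ones((256, 256), dtype=np.float32))
first_s, median_s = timed(matmul_tanh, a, b)
```

Reporting the first-call time and the median steady-state time separately keeps compile cost from being mistaken for run time.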
Available assets
- assets/mre_template.py: minimal reproducible example template
- assets/training_step_template.py: idiomatic compiled training step with explicit key plumbing
- assets/scan_template.py: carry-state loop using lax.scan
- assets/sharding_template.py: mesh plus NamedSharding starter
- assets/shard_map_template.py: manual SPMD starter using jax.shard_map
- assets/benchmark_template.py: honest timing pattern with warm-up and blocking
- assets/profile_template.py: trace and memory-profile starter
- assets/checkify_template.py: runtime checks that survive jit
- assets/custom_vjp_template.py: custom reverse-mode rule starter
- assets/export_template.py: export and serialisation starter
- assets/pallas_kernel_skeleton.py: kernel-level starting point
- assets/issue_report_template.md: compact bug report / investigation template
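As a rough illustration of the "global view first" sharding style (not the contents of assets/sharding_template.py), the sketch below builds a one-axis mesh over all visible devices, shards the batch axis, replicates the weights, and leaves collectives to the compiler. The axis name "data" and the function forward are hypothetical choices.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("data",))

# Shard the leading (batch) axis over "data"; replicate the weights.
batch_sharding = NamedSharding(mesh, P("data", None))
replicated = NamedSharding(mesh, P())

n = len(devices) * 4  # keep the batch divisible by the device count
x = jax.device_put(jnp.ones((n, 8), dtype=jnp.float32), batch_sharding)
w = jax.device_put(jnp.ones((8, 2), dtype=jnp.float32), replicated)


@jax.jit
def forward(x, w):
    # Global-view code: written as if for one device; the compiler
    # partitions it according to the input shardings.
    return x @ w


y = forward(x, w)
```

On a single CPU this runs on a one-device mesh; setting XLA_FLAGS=--xla_force_host_platform_device_count=8 before importing JAX simulates several devices so the sharding actually splits the batch.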
Output quality bar
Before sending a final answer, mentally run the code or design through references/CODE-REVIEW-RUBRIC.md. The answer should usually satisfy all of the following:
- runnable or patch-ready code
- correct transformation and sharding semantics
- explicit discussion of compile and runtime consequences
- no accidental host round trips in the claimed hot path
- no hidden PRNG or state bugs
- an honest verification method
If the task is exploratory research code
Prefer a staged plan:
- get a correct eager version in jax.numpy
- add tests or invariants
- add transformations one at a time
- benchmark and profile
- only then attempt aggressive sharding or kernel work
This workflow is usually more productive than reaching for jit, pmap, or Pallas prematurely.
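The "add tests or invariants" stage can be as small as a function that compares each newly added transformation against the trusted eager version. A minimal sketch, assuming a hypothetical scalar function energy with a known closed-form gradient; the tolerances are illustrative choices for float32.

```python
import jax
import jax.numpy as jnp


def energy(x):
    # Stage 1: a correct eager version in plain jax.numpy.
    return jnp.sum(jnp.sin(x) * x)


def check_invariants(x):
    # Stage 2: jit must not change the result.
    eager = energy(x)
    jitted = jax.jit(energy)(x)
    assert jnp.allclose(eager, jitted, rtol=1e-5, atol=1e-6), "jit changed the result"

    # The derivative of sum(sin(x) * x) is cos(x) * x + sin(x), element-wise.
    expected_grad = jnp.cos(x) * x + jnp.sin(x)
    assert jnp.allclose(jax.grad(energy)(x), expected_grad, rtol=1e-5, atol=1e-6)

    # Stage 3: vmap must agree with an explicit Python loop over the batch.
    xs = jnp.stack([x, 2 * x])
    looped = jnp.stack([energy(row) for row in xs])
    assert jnp.allclose(jax.vmap(energy)(xs), looped, rtol=1e-5, atol=1e-6)
    return True


ok = check_invariants(jnp.linspace(-1.0, 1.0, 5, dtype=jnp.float32))
```

Each later stage (sharding, custom kernels) can add its own comparison against the same eager baseline.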
Skill maintenance
When updating this skill, refresh the JAX facts most likely to drift:
- installation guidance
- sharding APIs and pmap migration status
- randomness recommendations
- profiler and memory-tooling guidance
- export / AOT APIs
- Pallas and custom extension interfaces