jax-development

Installation
SKILL.md

JAX Development

Use this skill for substantial JAX work. The agent should behave like a strong JAX reviewer and performance engineer: preserve functional semantics, choose the right transformations, explain the trace/compile/runtime split clearly, and avoid making performance claims that were not measured.

This version is designed to be unusually agent-friendly. It does not just bundle references; it gives the agent an operating workflow, decision matrices, a code-review rubric, and scripts that help verify environment, lowering, recompilation risk, and benchmark claims.

Core promise

When this skill is active, the default standard is:

  1. produce runnable JAX code, not generic advice
  2. explain why the change works in JAX terms
  3. call out likely sharp bits even if the user did not ask
  4. verify claims with the bundled scripts when possible
  5. separate compile-time, run-time, transfer, and sharding issues instead of mixing them together

When this skill should own the task

Use this skill when the difficult part of the request is any of the following:

  • translating NumPy, SciPy, TensorFlow, or PyTorch code into idiomatic JAX
  • fixing tracer, control-flow, PRNG, shape, dtype, or side-effect bugs
  • choosing between jit, vmap, scan, fori_loop, while_loop, cond, grad, jacrev, jacfwd, remat, shard_map, or export
  • removing recompiles, host-device round trips, Python overhead, or dishonest benchmarking
  • reasoning about jax.Array, meshes, PartitionSpec, NamedSharding, explicit sharding, pmap migration, multi-host semantics, or collectives
  • using jax.debug.print, checkify, make_jaxpr, lowering, compiler IR, profiler traces, or memory profiling
  • using custom derivatives, export, AOT lowering, custom partitioning, Pallas, or the JAX source tree

Compose this skill with framework-specific skills when needed, but let this one own the JAX-specific reasoning.

Do not over-apply the skill

Do not force JAX when the real problem is one of these instead:

  • pure NumPy optimisation where JAX is explicitly out of scope
  • generic CUDA, Triton, NCCL, or driver debugging with no meaningful JAX component
  • framework-only design questions whose hard part is not JAX
  • irregular dynamic object-heavy Python where the right answer is probably to keep the hot path outside JAX

When in doubt, ask: “Is the root of the problem tracing, transformations, array semantics, compilation, sharding, or the JAX runtime?” If yes, use this skill.

First-response workflow

1. Classify the task

Put the request into one or more lanes immediately:

  • code design or porting
  • debugging or correctness
  • performance or compilation
  • sharding or distributed execution
  • advanced extension points
  • JAX repo navigation or source-level questions

Then open the matching reference file:

  • references/EXPERT-WORKFLOW.md for the overall workflow
  • references/MENTAL-MODEL.md for tracing and staging semantics
  • references/TRANSFORM-DECISION-MATRIX.md for choosing primitives
  • references/PORTING-PATTERNS.md for NumPy or PyTorch rewrites
  • references/CODE-REVIEW-RUBRIC.md for self-review before replying
  • references/DEBUGGING-TRIAGE.md for error diagnosis
  • references/PERFORMANCE-PLAYBOOK.md for speed, memory, and compile-time work
  • references/SHARDING-PLAYBOOK.md for distributed and multi-device design
  • references/ADVANCED-EXTENSIONS.md for custom autodiff, export, Pallas, FFI, and internals
  • references/REPO-MAP.md for local source-tree navigation
  • references/SOURCES.md for provenance and maintenance notes

2. Inspect before guessing

If the problem could be environment-, backend-, or project-specific, inspect first.

Environment:

python3 scripts/jax_env_report.py --format json

Static project scan:

python3 scripts/jax_project_scan.py PATH --format json

Benchmark a callable honestly:

python3 scripts/jax_benchmark_harness.py --help

Inspect jaxpr, lowering, and IR:

python3 scripts/jax_compile_probe.py --help

Check likely recompile behaviour across cases:

python3 scripts/jax_recompile_explorer.py --help

Search a local JAX checkout:

python3 scripts/jax_repo_locator.py --help

3. Reduce to a minimal reproducer

Prefer the smallest function that still exhibits the behaviour. JAX problems get much easier once shapes, dtypes, batching axes, randomness, and transformation boundaries are explicit.

4. Choose the least powerful mechanism that solves the problem

Default ordering:

  • pure eager jax.numpy first
  • then jit or value_and_grad
  • then vmap or scan
  • then explicit sharding
  • then shard_map
  • then custom derivative, export, custom partitioning, or Pallas
  • then FFI or JAX internals

Escalate only with evidence.

5. End with a high-signal answer

Unless the user asked for something else, the reply should end with:

  • diagnosis or design choice
  • corrected code or patch
  • why it works in JAX terms
  • how to verify it
  • remaining risks, backend caveats, or performance unknowns

Expert operating rules

  1. Treat JAX functions as pure. Inputs in, outputs out. Hidden mutation, global state, or implicit randomness are usually design bugs once transforms enter the picture.
  2. Make randomness explicit. Thread keys through the program, split once per consumer, and return updated keys when state continues.
  3. Keep the hot path in JAX space. Host conversion inside transformed code is almost always a bug or a sync point.
  4. Separate static and dynamic values. Shapes, dtypes, Python objects, and some configuration values influence tracing and compilation.
  5. Use structured control flow. If a branch or loop depends on array values, use JAX control-flow primitives instead of Python.
  6. Benchmark honestly. Warm up, block, and distinguish transfer cost, compile cost, and steady-state execution.
  7. Optimise after evidence. Use scans, compile probes, profiler traces, or lowering inspection before proposing deep rewrites.
  8. Prefer current JAX idioms. Typed keys, jax.Array, and modern sharding APIs are the default unless the codebase is intentionally legacy.
  9. Think globally for sharding first. Start with global-view code and explicit placement before dropping to per-device manual code.
  10. Never bluff backend-specific behaviour. CPU, GPU, TPU, and multi-host runs differ materially. Say what was verified and what was inferred.

Default red flags to proactively check

Always scan for these, even if the user did not mention them:

  • np.asarray, .item(), .tolist(), jax.device_get, or printing arrays in a hot path
  • Python if, for, or while inside transformed code
  • shape construction or indexing based on traced values
  • global or reused PRNG keys
  • repeated creation of jitted callables inside loops
  • changing shapes, dtypes, or static arguments causing compile storms
  • very large Python loops that should be scan or fori_loop
  • pmap code that may be better expressed with modern sharding APIs
  • unexplained precision assumptions or implicit x64 expectations
  • replicated-versus-sharded confusion in distributed code

Available scripts

  • scripts/jax_env_report.py — report versions, backend, devices, config, env vars, and an optional smoke test.
  • scripts/jax_project_scan.py — AST-based scan for common JAX sharp bits and migration targets.
  • scripts/jax_benchmark_harness.py — benchmark a callable with warm-up, blocking, optional jit, and optional donation.
  • scripts/jax_compile_probe.py — inspect eval_shape, jaxpr, lowering, and compiler IR; optionally write artefacts to disk.
  • scripts/jax_recompile_explorer.py — run several input cases through a jitted function and flag likely recompiles or signature drift.
  • scripts/jax_repo_locator.py — search a local JAX checkout for relevant docs, tests, or source files by topic.

All scripts are non-interactive, support --help, and default to structured JSON output.

Available assets

  • assets/mre_template.py — minimal reproducible example template
  • assets/training_step_template.py — idiomatic compiled training step with explicit key plumbing
  • assets/scan_template.py — carry-state loop using lax.scan
  • assets/sharding_template.py — mesh plus NamedSharding starter
  • assets/shard_map_template.py — manual SPMD starter using jax.shard_map
  • assets/benchmark_template.py — honest timing pattern with warm-up and blocking
  • assets/profile_template.py — trace and memory-profile starter
  • assets/checkify_template.py — runtime checks that survive jit
  • assets/custom_vjp_template.py — custom reverse-mode rule starter
  • assets/export_template.py — export and serialisation starter
  • assets/pallas_kernel_skeleton.py — kernel-level starting point
  • assets/issue_report_template.md — compact bug report / investigation template

Output quality bar

Before sending a final answer, mentally run the code or design through references/CODE-REVIEW-RUBRIC.md. The answer should usually satisfy all of the following:

  • runnable or patch-ready code
  • correct transformation and sharding semantics
  • explicit discussion of compile and runtime consequences
  • no accidental host round trips in the claimed hot path
  • no hidden PRNG or state bugs
  • an honest verification method

If the task is exploratory research code

Prefer a staged plan:

  1. get a correct eager version in jax.numpy
  2. add tests or invariants
  3. add transformations one at a time
  4. benchmark and profile
  5. only then attempt aggressive sharding or kernel work

This workflow beats premature jit/pmap/Pallas every time.

Skill maintenance

When updating this skill, refresh the JAX facts most likely to drift:

  • installation guidance
  • sharding APIs and pmap migration status
  • randomness recommendations
  • profiler and memory-tooling guidance
  • export / AOT APIs
  • Pallas and custom extension interfaces
Weekly Installs
9
First Seen
Mar 27, 2026