Related: [[HighJax docs]]
Naming conventions, style rules, and JAX-specific patterns used throughout the HighJax codebase.
Almost all JAX arrays use the pattern `foo_by_bar_by_baz`. Each `by_something` is an axis (or sometimes multiple axes). Axes go right-to-left: `foo_by_bar_by_baz` is a 2D array where axis 0 is baz, axis 1 is bar, and the value is foo.
Examples:
- `reward_by_e_by_t` -- shape (n_ts, n_es), reward at each (timestep, episode)
- `p_by_action_by_e` -- shape (n_es, *action_shape), action probabilities
Multi-dimensional items like position or action occupy multiple axes. The number of axes that action occupies depends on the environment.
This applies to ALL arrays -- including intermediate variables, temporaries, results, diffs, distances. Even a simple subtraction result should be named `diff_by_e_by_t`, not `diff`.
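A minimal sketch of the convention (array sizes are hypothetical):

```python
import jax.numpy as jnp

n_ts, n_es = 100, 8  # hypothetical sizes

# Rightmost suffix is axis 0: reward_by_e_by_t has shape (n_ts, n_es).
reward_by_e_by_t = jnp.zeros((n_ts, n_es))
baseline_by_e_by_t = jnp.ones((n_ts, n_es))

# Even a throwaway intermediate keeps the full axis suffix:
diff_by_e_by_t = reward_by_e_by_t - baseline_by_e_by_t  # not just `diff`
```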
When unpacking array shapes to get dimension sizes, use `n_foos_per_bar`:
- `foos`: plural of what's being counted
- `bar`: the containing unit
Examples:
```python
n_es_per_epoch, n_cells_per_e, n_tokens_per_vocabulary = logit_by_token_by_cell_by_e.shape
n_es_per_epoch, n_cells_per_e = token_by_cell_by_e.shape
```

When naming a variable that's an index number, use `i_foo`. However, these are exceptions that don't need the `i_` prefix: `epoch`, `t`, `e`.
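A sketch of shape unpacking together with index naming (sizes hypothetical):

```python
import jax.numpy as jnp

# Hypothetical array: axis 0 is e, axis 1 is cell.
token_by_cell_by_e = jnp.zeros((8, 16), dtype=jnp.int32)
n_es_per_epoch, n_cells_per_e = token_by_cell_by_e.shape

i_cell = 3  # index variables take the i_ prefix...
token_by_e = token_by_cell_by_e[:, i_cell]

for e in range(n_es_per_epoch):  # ...except epoch, t, and e
    pass
```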
Only use shorthands that already exist in the codebase. Don't invent new ones. Established shorthands:
- `p` for probability
- `e` for episode
- `t` for timestep
- `ft` for flat timestep (full flattened pool, in minibatch code)
- `mt` for minibatch timestep (within one minibatch slice)
- `ts` for timesteps (plural, in CLI args like `--n-ts`)
- `es` for episodes (plural, in CLI args like `--n-es`)
- `v` for value estimate
- `vf` for value function
- `kld` for KL divergence (never bare `kl`)
- `obs` for observation (matches gymnax API convention)
- `nz` for normalized (prefix, e.g. `nz_speed`, `nz_return`, `nz_advantage`)
- `theta` for model parameters (neural network param dicts)
- `vital` for alive-mask arrays (not post-crash)
- `tendency` for log-probability of chosen action
- `epilogue` for post-final-step values (e.g. `epilogue_v_by_agent_by_e`)
- `lunge` for action dimension in multi-discrete action spaces
- `deed` for a choice within a lunge
- `mb` for minibatch (in axis names like `_by_mt_by_mb`)
- `sweep` for one pass over all minibatches
Don't shorten `position` to `pos`, etc.
- Python 3.12+ required
- Single quotes everywhere, unless there's a quote-in-quote situation
- Maximum line length: 100 characters
- Type annotations using builtins (`list`, `tuple`, `dict`), not `List`, `Tuple`, `Dict`
- `from __future__ import annotations` at the top of every file
- Import order: `__future__` > stdlib > third-party > highjax
- snake_case for variables/functions, PascalCase for classes
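A hypothetical file header and function illustrating several of these rules at once (the names are illustrative, and the highjax import position is shown as a comment):

```python
from __future__ import annotations

import jax.numpy as jnp  # third-party, after __future__ and stdlib

# highjax imports would come last, e.g. `from highjax import envs` (hypothetical)


class ReturnNormalizer:  # PascalCase for classes
    pass


def nz_returns(returns: list[float]) -> jnp.ndarray:
    # snake_case for functions; builtin generics (`list`), not typing.List
    return_by_e = jnp.asarray(returns)
    return (return_by_e - jnp.mean(return_by_e)) / (jnp.std(return_by_e) + 1e-8)
```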
The highjax environment package has docstrings on public API functions (reset, step, etc.) since it serves as a library. The highjax_trainer package generally avoids function docstrings — the code should be self-explanatory. Add comments sparingly, only when the code is genuinely difficult to understand otherwise.
Many functions that process arrays need to be JIT-compiled. This means:
- No Python control flow on array values (use `jnp.where` instead of `if`)
- Be careful with loops (use `jax.lax.scan` or `jax.vmap` instead)
- Use `jax.Array` and pytree-compatible data structures
HighJax uses `@flax.struct.dataclass` for most data classes. These are JAX-compatible (pytree-registered) frozen dataclasses. Fields can be marked `pytree_node=False` via `flax.struct.field(pytree_node=False)` to exclude them from JAX tracing (e.g., config objects).
The gradient computation pipeline has a specific data flow with its own naming:
Ascender -> SweepMaster (flatten to `_by_ft`) -> Sweeper (shuffle to `_by_mt_by_mb`) -> Minibatcher (`_by_mt` per minibatch)
The Minibatcher computes the composite actor objective (PPO clipped surrogate + entropy) and produces gradients via jax.grad. The critic is updated separately.
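The flatten/shuffle/slice steps can be sketched with plain reshapes (sizes and the `advantage` array are hypothetical; the actual classes do more than this):

```python
import jax
import jax.numpy as jnp

n_ts, n_es, n_mbs = 6, 4, 3  # hypothetical sizes
advantage_by_e_by_t = jnp.arange(n_ts * n_es, dtype=jnp.float32).reshape(n_ts, n_es)

# Flatten (t, e) into one flat-timestep axis (the SweepMaster step).
advantage_by_ft = advantage_by_e_by_t.reshape(-1)

# Shuffle the pool and split it into minibatches (the Sweeper step).
perm = jax.random.permutation(jax.random.PRNGKey(0), advantage_by_ft.shape[0])
advantage_by_mt_by_mb = advantage_by_ft[perm].reshape(n_mbs, -1)

# Each minibatch then sees a single _by_mt slice (the Minibatcher step).
advantage_by_mt = advantage_by_mt_by_mb[0]
```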
- Golden tests (`test_golden_runs/`): Deterministic training runs with exact expected values. When the training pipeline changes, these need regeneration. Each test defines its own `train()` function; run it, capture the new values, update `golden_data`.
- Unit tests: Everything else -- estimators, objectives, masking, freezing, trainer integration, etc.
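A hypothetical skeleton of the golden-test shape described above (names and values are illustrative only, not from the codebase):

```python
import jax.numpy as jnp

# Expected values captured from a previous deterministic run;
# regenerate when the training pipeline changes.
golden_data = {'final_reward': 1.2345}


def train():
    # Stand-in for this test's deterministic training run.
    return {'final_reward': 1.2345}


def test_golden_run():
    result = train()
    assert jnp.allclose(result['final_reward'], golden_data['final_reward'])
```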