4.0 #402

Open
jsuarez5341 wants to merge 703 commits into 3.0 from 4.0

@jsuarez5341
Contributor

This PR will not be merged. We are targeting EoY and 4.0 will just become master. Key goals:

  • Sweeps for all envs, largest ever public dataset of RL experiments
  • Constellation
  • Major perf enhancements

TBD: cpp/barracuda, final constellation features, xlstm, advantage calc tweaks

Infatoshi and others added 14 commits April 20, 2026 15:26
Rewrites craftax_spawn_mobs_native to strip JAX-isms that are pointless
on CPU:
 - bool[48][48] validity mask -> compact (int16, int16) coord list
   collected in one pass over the bounding box around the player
 - bounding-box scan: mobs can only spawn within MOB_DESPAWN_DISTANCE=14,
   so we only visit the up-to-27x27 window instead of the full 48x48 map
 - early return when can_spawn is already false from the mob-cap or
   probability roll, skipping the scan + choice
 - merged count_mobs3 + first_empty_mobs3 into a single loop
 - inlined the block-type and distance checks

Choice arithmetic uses the same FP expressions as baseline so the selected
cell is bitwise-identical for any given (valid_count, rng_key) pair. The
baseline quirk of writing type_id[level][slot] unconditionally even when
no mob spawns is preserved.
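
The bounding-box scan described above can be sketched in Python as follows. This is a minimal illustration, not the C code from the commit: `collect_spawn_candidates` and `is_valid` are hypothetical names, and the sketch assumes a cell qualifies only when both offsets from the player are strictly smaller than MOB_DESPAWN_DISTANCE, which is what yields a window of at most 27x27 cells. The actual distance metric used by the env is not stated in the commit message.

```python
# Values quoted in the commit message; the names are reused for illustration only.
MAP_SIZE = 48
MOB_DESPAWN_DISTANCE = 14


def collect_spawn_candidates(is_valid, player_row, player_col):
    """Collect valid (row, col) cells in one pass over the window around the player.

    Assumption: only cells with |d_row| < 14 and |d_col| < 14 can qualify,
    so the window is clipped to the map and is at most 27 cells wide.
    """
    reach = MOB_DESPAWN_DISTANCE - 1
    row_lo = max(0, player_row - reach)
    row_hi = min(MAP_SIZE - 1, player_row + reach)
    col_lo = max(0, player_col - reach)
    col_hi = min(MAP_SIZE - 1, player_col + reach)

    candidates = []
    # Row-major order, so the k-th entry is the k-th valid cell a full-map
    # row-major scan restricted to the same window would have found.
    for row in range(row_lo, row_hi + 1):
        for col in range(col_lo, col_hi + 1):
            if is_valid(row, col):
                candidates.append((row, col))
    return candidates
```

Keeping the list in row-major order is what lets an index chosen from `valid_count` select the same cell as an index into the set bits of a full mask.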

Phase timing (single-thread, random actions):
  craftax_spawn_mobs_native: 17.06 us -> 0.30 us (57x)
  full c_step:               29.6 us -> 12.3 us  (2.4x)

Verified bitwise-equal to the prior implementation over 1.28M paired
steps (128 envs x 10000 steps, random actions, reset exercised).

c_reset and the c_step auto-reset path now optionally memcpy from a
pre-generated pool of worlds instead of running generate_world each
episode. Pool size is a runtime kwarg (reset_pool_size) read by
my_init, default 1024 via config/ocean/craftax.ini. Set to 0 to
disable and regenerate every reset (required for strict per-seed
determinism in tests/craftax_parity.py).

Trade-off: at most reset_pool_size unique maps are seen per process. With
1024 and ~270-step random-action episodes, diversity is plentiful for
training. Memory cost: 1024 * sizeof(CraftaxState) ~= 267 MB once at
startup.

Two reset entry points are now distinguished:
 - craftax_reset_state_from_reset_key: direct (used by parity harness),
   always calls generate_state_from_world_key, pool-free for exact
   per-key determinism.
 - craftax_reset_state_on_done: hot-path used by c_step on terminal,
   consults the pool when enabled, falls through to generate_world
   otherwise. Pool index derived from reset_key.word[0].
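
A rough Python sketch of the pooled reset path is shown below. It is an illustration under assumptions, not the C implementation: `ResetPool` and its methods are hypothetical names, the modulo used to map the key word to a pool slot is an assumed derivation (the commit only says the index is derived from reset_key.word[0]), and seeding the pool with slot numbers is also an assumption. `copy.deepcopy` stands in for the memcpy.

```python
import copy


class ResetPool:
    """Hypothetical pool of pre-generated worlds (stand-in for the C pool)."""

    def __init__(self, generate_world, pool_size):
        # generate_world is supplied by the caller: seed -> world state.
        self.generate_world = generate_world
        # Pay the worldgen cost once at startup; pool_size == 0 disables the pool.
        self.worlds = [generate_world(slot) for slot in range(pool_size)]

    def reset_on_done(self, reset_key_word0):
        """Hot path: copy a pooled world, or regenerate when the pool is disabled."""
        if not self.worlds:
            return self.generate_world(reset_key_word0)
        # Assumed mapping from key word to pool slot.
        index = reset_key_word0 % len(self.worlds)
        # Copy so the episode can mutate its state without touching the pool.
        return copy.deepcopy(self.worlds[index])
```

With the pool enabled, two resets whose key words map to the same slot start from identical worlds, which is why the parity harness needs the pool-free entry point.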

tests/craftax_parity.py picks up raylib's include path since craftax.h
now pulls raylib.h (from the shared renderer).

Measurements (single-thread, random actions):
  worldgen:            2.69 ms -> 6.9 us memcpy (~390x)
  full c_step:         12.3 us -> 2.35 us (5.25x)
  training SPS:        450K -> 506K (+12%)
  1-thread sim SPS:    81K -> 425K (5.25x)
  16-thread sim SPS:   1.14M -> 5.53M (4.85x)

The five move_* helpers (melee/passive/ranged mobs + mob/player
projectiles) now return immediately when mask=false. JAX's branchless
"compute-then-mask" pattern is pointless on CPU: dead slots' output
never feeds observations, rewards, or mob_map, so skipping the body
and the RNG draws is semantically equivalent.

Defining CRAFTAX_JAX_PARITY at build time restores the branchless
slow path for bitwise replay against JAX (required by
tests/craftax_parity.py). Default build uses the early-out.
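
The difference between the two build modes can be illustrated with a small Python sketch. Everything here is hypothetical: `move_mob`, `draw_step`, and the `jax_parity` argument are illustrative stand-ins (the real switch is the compile-time CRAFTAX_JAX_PARITY define), and positions are reduced to a single integer.

```python
JAX_PARITY = False  # stand-in for the CRAFTAX_JAX_PARITY build-time define


def move_mob(position, mask, draw_step, jax_parity=JAX_PARITY):
    """Move one mob slot; dead slots (mask=False) never change position.

    draw_step is a caller-supplied function that consumes one RNG draw
    and returns a step.
    """
    if not mask and not jax_parity:
        # Early-out: skip the body and the RNG draw for dead slots.
        return position
    # Branchless-style path: always compute, then mask the result.
    step = draw_step()
    candidate = position + step
    return candidate if mask else position
```

Both paths return the same position for a dead slot, but only the parity path consumes an RNG draw. Under that assumption the random streams of the two builds stop matching once a slot is dead, which would be consistent with parity only holding when the flag is defined.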

Also drops craftax_step_jax_index(player_level, NUM_LEVELS) clamps at
the top of each move_* -- state->player_level is maintained in
[0, NUM_LEVELS-1] by change_floor_native (explicit bounds checks) and
by the worldgen init. Six redundant clamps per step eliminated.

Measurements (single-thread, random actions, pool=1024):
  update_mobs phase:   1.392 us -> 0.285 us (4.88x)
  full c_step:         2.35 us -> 1.22 us
  1-thread sim SPS:    425K -> 819K (1.93x)
  16-thread sim SPS:   5.53M -> 10.04M (1.82x)
  training SPS:        506K -> 544K (+7%)

Parity test with CRAFTAX_JAX_PARITY defined passes 8 seeds * 1000
steps over 27 terminals. Without the flag, parity diverges at the
first mob death -- by design.

These 10 tests were written incrementally as each subsystem (noise,
threefry, worldgen, 7 step subsystems) was ported from JAX, to catch
divergence at each layer. Now that tests/craftax_parity.py passes
end-to-end against the JAX reference, they are redundant: any bug
they'd catch also breaks the integration parity test.

Dropping ~5400 LOC of scaffolding. Kept:
 - craftax_parity.py         (JAX<->C integration parity harness)
 - craftax_state_fixtures.py (state-flattening helpers used by parity)
 - craftax_parity_stress.py  (adversarial action sequences)
 - craftax_step_full_test.py (pytest wrapper -> parity.run)

The dashboard and CSV logger only need to surface a handful of
milestones along the tech/exploration curve, not every achievement.
The env still tracks all 67 internally for reward computation and
for the normalized 'perf' aggregate -- we just stop shipping every
one through the log Dict each episode flush.

Checkpoints chosen to span the learning curve:
  collect_wood         first resource (tier 1)
  make_wood_pickaxe    first tool
  make_stone_pickaxe   stone tier
  collect_iron         iron tier resource
  make_iron_pickaxe    iron tier tool (major milestone)
  collect_diamond      diamond tier resource
  enter_gnomish_mines  first dungeon (exploration)
  defeat_necromancer   endgame boss
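
As a sketch of the filtering idea, the snippet below builds a per-episode log from only the checkpoint achievements listed above. The function name `build_log` and the shape of its arguments are hypothetical, and the meta field names in the usage lines are placeholders (apart from 'perf', the commit does not name them).

```python
# The eight checkpoint achievements listed in the commit message.
LOGGED_ACHIEVEMENTS = (
    "collect_wood",
    "make_wood_pickaxe",
    "make_stone_pickaxe",
    "collect_iron",
    "make_iron_pickaxe",
    "collect_diamond",
    "enter_gnomish_mines",
    "defeat_necromancer",
)


def build_log(meta, achievements, n):
    """Return a log dict with the meta fields, the checkpoint achievements, and n.

    achievements may hold every tracked achievement; only the checkpoints
    are copied into the log, and missing ones default to 0.0.
    """
    log = dict(meta)
    for name in LOGGED_ACHIEVEMENTS:
        log[name] = achievements.get(name, 0.0)
    log["n"] = n
    return log


# Usage with placeholder meta field names (hypothetical except 'perf').
meta = {"perf": 0.1, "score": 2.0, "episode_return": 3.5, "episode_length": 270}
log = build_log(meta, {"collect_wood": 1.0, "eat_plant": 1.0}, n=1)
```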

Log Dict now carries 4 meta + 8 achievements + 1 n = 13 fields, well
under the stock create_dict(32) capacity. This removes the need for the
capacity bump in src/bindings* (reverted in the following commit).

- config/ocean/craftax.ini -> config/craftax.ini
- config/ocean/craftax_classic.ini -> config/craftax_classic.ini
- ocean/craftax/textures.bin -> resources/craftax/textures.bin
- scripts/craftax_convergence_bench.py -> tests/craftax_convergence_bench.py
- drop empty scripts/ directory
- pack_textures.py: write to resources/craftax/textures.bin
- craftax.h / craftax_classic.h: fopen textures from resources/craftax/

Used by the craftax parity harness to compile with
-DCRAFTAX_JAX_PARITY, which disables the update_mobs early-out so the
C env replays bitwise against JAX. Default training builds leave
EXTRA_CFLAGS empty and keep the ~2x sim-SPS early-out enabled.

Craftax Full: native C port + optimizations + renderer
@KTibow
Contributor

KTibow commented Apr 27, 2026

Can probably be closed :)
