
[Refactor] Reduce memory usage in HardPackDataset via shared memory #1602

Open
HAOCHENYE wants to merge 1 commit into InternLM:main from HAOCHENYE:yehc/speed-up-packing

Conversation

@HAOCHENYE
Collaborator

  • Add SharedPoolExecutor to bind fn + ndarray kwargs once per worker via POSIX shared memory, eliminating per-task serialization overhead
  • Vectorize _hard_pack_chunk_core with batched searchsorted
  • Replace incremental _update pattern with collect-all + _merge_pack_infos
  • Save pack_infos to mmap-backed files so all ranks on a node share the same physical pages instead of each rank holding its own copy
  • Add is_local_rank0() utility to xtuner.v1.utils
  • Add unit tests for SharedPoolExecutor

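A toy illustration of the batched-searchsorted idea from the second bullet (illustrative only, not the PR's actual code): given cumulative sample lengths, one vectorized np.searchsorted call locates every pack boundary at once, replacing a per-sample Python loop.

```python
import numpy as np

# Cumulative token lengths of 5 samples (lengths 3, 4, 2, 5, 4).
cu_len = np.array([0, 3, 7, 9, 14, 18])

# Pack boundaries every 8 tokens; one vectorized call finds, for each
# boundary, the index of the sample that contains it.
pack_starts = np.arange(0, cu_len[-1], 8)          # [0, 8, 16]
boundaries = np.searchsorted(cu_len, pack_starts, side="right") - 1
print(boundaries.tolist())  # → [0, 2, 4]
```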
HAOCHENYE force-pushed the yehc/speed-up-packing branch from 780e296 to e106651 on March 19, 2026 at 07:53
@HAOCHENYE
Collaborator Author

@claude review

pack_infos_list = []
lengths_arr = (cu_arr[1:] - cu_arr[:-1]).astype(np.int64)
all_results: list[dict[str, np.ndarray]] = []


Claude: Bug — non-deterministic shuffle breaks seeded reproducibility

The old code used self.random.shuffle(inds) where self.random is a random.Random(seed) instance, producing deterministic results for a given seed. This new code uses np.random.shuffle(inds) which draws from the global numpy random state — it is not seeded by self.seed.

This means:

  1. Two runs with the same seed will produce different pack_infos (non-reproducible).
  2. The content-addressed cache key in _build_pack_infos includes the seed, so the cache will serve the result from whichever run happened first, masking the non-determinism on subsequent runs — but the first run is always non-deterministic.

Suggested fix:

rng = np.random.RandomState(self.seed)
inds = np.arange(len(dataset))
rng.shuffle(inds)

Or pass the Random instance / seed into _compute_pack_infos and use it there.
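The reproducibility point can be checked with a minimal, self-contained sketch (function name is illustrative, not the PR's code): a per-instance seeded RandomState makes the permutation a pure function of the seed, unlike the global np.random state.

```python
import numpy as np

def shuffled_indices(n: int, seed: int) -> np.ndarray:
    # Seeded per-call RNG: the permutation depends only on (n, seed).
    rng = np.random.RandomState(seed)
    inds = np.arange(n)
    rng.shuffle(inds)
    return inds

# Two runs with the same seed yield an identical packing order.
assert (shuffled_indices(10, seed=42) == shuffled_indices(10, seed=42)).all()
```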

Comment on lines +414 to +416
"start_offset": np.concatenate([r["start_offset"] for r in infos]),
"end_offset": np.concatenate([r["end_offset"] for r in infos]),
"longest": np.concatenate([r["longest"] for r in infos]),

Claude: Warning — process group created but never destroyed (resource leak)

dist.new_group() allocates a new NCCL communicator. After the barrier, pack_pg is never destroyed. Over time (e.g., if multiple HardPackDataset instances are created), this leaks NCCL communicators.

Suggested change:

pack_pg = dist.new_group(timeout=datetime.timedelta(seconds=7200))
dist.barrier(group=pack_pg)
dist.destroy_process_group(pack_pg)

Comment on lines +71 to +72
dist.barrier()


Claude: Warning — barrier uses default process group, may timeout on large datasets

_build_pack_infos creates a dedicated process group with a 7200s timeout specifically because pack computation can be slow. But ndarray_to_mmap uses the default process group for its barrier. If rank 0 is slow writing a large array, non-rank-0 processes will block here on the default group's barrier — which has the default NCCL timeout (typically 30 minutes or less).

Consider accepting an optional group parameter, or creating a similarly long-timeout group here, to stay consistent with the caller's intent.
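One way to thread the caller's group through, as a hedged sketch: the signature below is hypothetical (the real ndarray_to_mmap does more than this), and the dist calls are guarded so the sketch runs standalone. `group=None` preserves today's default-group behavior while letting _build_pack_infos pass its long-timeout group.

```python
import numpy as np

try:
    import torch.distributed as dist
    _DIST_READY = dist.is_available() and dist.is_initialized()
except ImportError:
    _DIST_READY = False

def ndarray_to_mmap(arr: np.ndarray, path: str, group=None) -> np.ndarray:
    # Simplified stand-in for the rank-0 write in the real implementation.
    np.save(path, arr)
    if _DIST_READY:
        # Barrier on the caller's group, inheriting its timeout;
        # group=None falls back to the default process group.
        dist.barrier(group=group)
    return np.load(path, mmap_mode="r")
```

The caller could then invoke ndarray_to_mmap(arr, path, group=pack_pg) so the barrier inherits the 7200s timeout.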

Comment on lines +191 to +195
p.join()
self._result_queue.put(_StopSignal())
self._drain_thread.join()
for shm in self._shm_handles:
shm.close()

Claude: Nit — shutdown can hang indefinitely if a worker is stuck

p.join() has no timeout. If a worker process deadlocks or crashes without consuming its _StopSignal, this will block forever. Consider adding a timeout and force-killing unresponsive workers:

for p in self._workers:
    p.join(timeout=30)
    if p.is_alive():
        p.kill()
        p.join()

cu_parts: list[np.ndarray] = []
offset = np.int64(0)
for r in infos:
    cu_parts.append(r["indices_cu_len"] + offset)

Claude: Warning — writes to pack_dir are not atomic

Unlike ndarray_to_mmap which correctly uses temp-file + atomic rename, _save_pack_infos writes .npy files directly into the final directory. If the process is killed mid-write (OOM, preemption), other ranks will find a partially-written directory and _load_pack_infos_from_mmap will silently load truncated/corrupt files.

Consider writing to a temp directory first, then atomically renaming:

tmp_dir = Path(tempfile.mkdtemp(dir=pack_dir.parent))
for field in self._PACK_INFO_FIELDS:
    arr = pack_infos.get(field)
    if arr is not None:
        np.save(str(tmp_dir / f"{field}.npy"), arr)
os.rename(str(tmp_dir), str(pack_dir))
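A runnable version of that write-to-temp-then-rename pattern (directory and field names here are illustrative, not taken from the PR): readers see either no directory at all or a fully written one, never a partial write.

```python
import os
import tempfile
from pathlib import Path
import numpy as np

def save_atomically(arrays: dict, final_dir: Path) -> None:
    # Stage all .npy files in a fresh temp directory on the same
    # filesystem, then publish with a single rename (atomic on POSIX
    # when source and destination share a filesystem).
    tmp_dir = Path(tempfile.mkdtemp(dir=final_dir.parent))
    for name, arr in arrays.items():
        np.save(str(tmp_dir / f"{name}.npy"), arr)
    os.rename(str(tmp_dir), str(final_dir))

base = Path(tempfile.mkdtemp())
save_atomically({"num_tokens": np.arange(5)}, base / "pack_infos")
print(sorted(p.name for p in (base / "pack_infos").iterdir()))
# → ['num_tokens.npy']
```

Note that os.rename will fail if the destination directory already exists, so a stale pack_dir would still need to be cleared (or the rename failure tolerated as "another rank won the race") before publishing.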

@claude

claude bot commented Mar 19, 2026

Claude: ## Review Summary

Nice refactor — the shared-memory approach for large ndarray kwargs and mmap-backed pack_infos should significantly reduce per-rank memory in distributed settings. The vectorized _hard_pack_chunk_core and SharedPoolExecutor are well-structured.

Issues

Critical

  • Non-deterministic shuffle (packing.py:443): _compute_pack_infos uses np.random.shuffle (global state) instead of the seeded self.random instance. This breaks reproducibility — two runs with the same seed produce different packings.

Warning

  • Process group leak (packing.py:414-416): dist.new_group() is never destroyed after the barrier.
  • Non-atomic pack_infos write (packing.py:408): _save_pack_infos writes directly to the final directory without atomic rename, risking corrupt cache on crash.
  • Default-group barrier timeout (utils.py:71-72): ndarray_to_mmap barriers use the default process group timeout, inconsistent with the 7200s timeout in _build_pack_infos.

Nit

  • No join timeout in SharedPoolExecutor.shutdown (executor.py:191-195): p.join() without timeout can hang indefinitely.

Verdict

REQUEST_CHANGES
