[Refactor] Reduce memory usage in HardPackDataset via shared memory#1602
HAOCHENYE wants to merge 1 commit into InternLM:main
Conversation
HAOCHENYE commented on Mar 19, 2026
- Add SharedPoolExecutor to bind fn + ndarray kwargs once per worker via POSIX shared memory, eliminating per-task serialization overhead
- Vectorize _hard_pack_chunk_core with batched searchsorted
- Replace incremental _update pattern with collect-all + _merge_pack_infos
- Save pack_infos to mmap-backed files so all ranks on a node share the same physical pages instead of each rank holding its own copy
- Add is_local_rank0() utility to xtuner.v1.utils
- Add unit tests for SharedPoolExecutor
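The shared-memory binding in the first bullet can be illustrated with a minimal, self-contained sketch. The `publish`/`attach` helpers below are hypothetical names for illustration, not SharedPoolExecutor's actual API: an ndarray is copied into a POSIX shared-memory segment once, and a worker re-attaches by name instead of receiving a pickled copy per task.

```python
from multiprocessing import shared_memory

import numpy as np

def publish(arr):
    """Copy an array into a POSIX shared-memory segment once.

    Returns the segment (kept alive by the owner) plus the metadata a
    worker needs to re-attach; nothing is pickled per task.
    """
    shm = shared_memory.SharedMemory(create=True, size=arr.nbytes)
    view = np.ndarray(arr.shape, dtype=arr.dtype, buffer=shm.buf)
    view[:] = arr
    return shm, arr.shape, arr.dtype

def attach(name, shape, dtype):
    """What a worker would do: map the existing segment, zero-copy."""
    shm = shared_memory.SharedMemory(name=name)
    return shm, np.ndarray(shape, dtype=dtype, buffer=shm.buf)

src = np.arange(8, dtype=np.int64)
owner, shape, dtype = publish(src)
worker_shm, view = attach(owner.name, shape, dtype)
roundtrip = view.tolist()
assert roundtrip == src.tolist()
del view            # drop the buffer reference before closing the segment
worker_shm.close()
owner.close()
owner.unlink()
```

The `del view` matters: `SharedMemory.close()` raises `BufferError` while an ndarray still holds a reference into the mapped buffer.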
780e296 to e106651
@claude review
```python
pack_infos_list = []
lengths_arr = (cu_arr[1:] - cu_arr[:-1]).astype(np.int64)
all_results: list[dict[str, np.ndarray]] = []
```
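The `lengths_arr` line above derives per-sample lengths from cumulative offsets in one vectorized subtraction; the same cumulative array also supports the batched `searchsorted` mentioned in the PR description. A small illustration with made-up offsets (not the PR's actual data):

```python
import numpy as np

# cu_arr holds cumulative token offsets: sample i spans [cu_arr[i], cu_arr[i+1]).
cu_arr = np.array([0, 3, 7, 12, 20], dtype=np.int64)

# Per-sample lengths in a single vectorized subtraction, no Python loop.
lengths_arr = (cu_arr[1:] - cu_arr[:-1]).astype(np.int64)
assert lengths_arr.tolist() == [3, 4, 5, 8]

# Batched searchsorted: map many flat token offsets to their owning
# sample index in one call.
offsets = np.array([0, 5, 11, 19])
sample_ids = np.searchsorted(cu_arr, offsets, side="right") - 1
assert sample_ids.tolist() == [0, 1, 2, 3]
```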
Claude: Bug — non-deterministic shuffle breaks seeded reproducibility
The old code used self.random.shuffle(inds) where self.random is a random.Random(seed) instance, producing deterministic results for a given seed. This new code uses np.random.shuffle(inds) which draws from the global numpy random state — it is not seeded by self.seed.
This means:
- Two runs with the same seed will produce different pack_infos (non-reproducible).
- The content-addressed cache key in _build_pack_infos includes the seed, so the cache will serve the result from whichever run happened first, masking the non-determinism on subsequent runs; the first run, however, is always non-deterministic.
Suggested fix:
```python
rng = np.random.RandomState(self.seed)
inds = np.arange(len(dataset))
rng.shuffle(inds)
```

Or pass the Random instance / seed into _compute_pack_infos and use it there.
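To see why the suggested fix restores reproducibility, a standalone sketch (`shuffled_indices` is a hypothetical helper, not code from this PR):

```python
import numpy as np

def shuffled_indices(n, seed):
    # A local RNG seeded explicitly; the global np.random state is untouched,
    # so the result depends only on the seed, not on what other code does.
    rng = np.random.RandomState(seed)
    inds = np.arange(n)
    rng.shuffle(inds)
    return inds

a = shuffled_indices(10, seed=42)
b = shuffled_indices(10, seed=42)
assert (a == b).all()                          # same seed, same order
assert sorted(a.tolist()) == list(range(10))   # still a permutation
```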
```python
"start_offset": np.concatenate([r["start_offset"] for r in infos]),
"end_offset": np.concatenate([r["end_offset"] for r in infos]),
"longest": np.concatenate([r["longest"] for r in infos]),
```
Claude: Warning — process group created but never destroyed (resource leak)
dist.new_group() allocates a new NCCL communicator. After the barrier, pack_pg is never destroyed. Over time (e.g., if multiple HardPackDataset instances are created), this leaks NCCL communicators.
Suggested change:

```python
pack_pg = dist.new_group(timeout=datetime.timedelta(seconds=7200))
dist.barrier(group=pack_pg)
dist.destroy_process_group(pack_pg)
```
```python
dist.barrier()
```
Claude: Warning — barrier uses default process group, may timeout on large datasets
_build_pack_infos creates a dedicated process group with a 7200s timeout specifically because pack computation can be slow. But ndarray_to_mmap uses the default process group for its barrier. If rank 0 is slow writing a large array, non-rank-0 processes will block here on the default group's barrier — which has the default NCCL timeout (typically 30 minutes or less).
Consider accepting an optional group parameter, or creating a similarly long-timeout group here, to stay consistent with the caller's intent.
```python
p.join()
self._result_queue.put(_StopSignal())
self._drain_thread.join()
for shm in self._shm_handles:
    shm.close()
```
Claude: Nit — shutdown can hang indefinitely if a worker is stuck
p.join() has no timeout. If a worker process deadlocks or crashes without consuming its _StopSignal, this will block forever. Consider adding a timeout and force-killing unresponsive workers:
```python
for p in self._workers:
    p.join(timeout=30)
    if p.is_alive():
        p.kill()
        p.join()
```

```python
cu_parts: list[np.ndarray] = []
offset = np.int64(0)
for r in infos:
    cu_parts.append(r["indices_cu_len"] + offset)
```
Claude: Warning — writes to pack_dir are not atomic
Unlike ndarray_to_mmap which correctly uses temp-file + atomic rename, _save_pack_infos writes .npy files directly into the final directory. If the process is killed mid-write (OOM, preemption), other ranks will find a partially-written directory and _load_pack_infos_from_mmap will silently load truncated/corrupt files.
Consider writing to a temp directory first, then atomically renaming:
```python
tmp_dir = Path(tempfile.mkdtemp(dir=pack_dir.parent))
for field in self._PACK_INFO_FIELDS:
    arr = pack_infos.get(field)
    if arr is not None:
        np.save(str(tmp_dir / f"{field}.npy"), arr)
os.rename(str(tmp_dir), str(pack_dir))
```
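A runnable, self-contained sketch of the temp-dir-then-rename pattern, with generic names rather than xtuner's actual helpers; it also shows the mmap-backed read from the PR description, where `np.load(mmap_mode="r")` lets readers share physical pages:

```python
import os
import tempfile
from pathlib import Path

import numpy as np

def save_arrays_atomically(arrays, out_dir):
    """Write .npy files into a temp dir, then publish with a single rename.

    Readers either see the complete directory or nothing; a crash mid-write
    leaves only an orphaned temp dir, never a half-written out_dir.
    """
    out_dir = Path(out_dir)
    tmp_dir = Path(tempfile.mkdtemp(dir=out_dir.parent))
    for name, arr in arrays.items():
        np.save(str(tmp_dir / f"{name}.npy"), arr)
    os.rename(str(tmp_dir), str(out_dir))  # atomic on POSIX (same filesystem)

base = Path(tempfile.mkdtemp())
save_arrays_atomically({"lengths": np.arange(5)}, base / "pack_infos")

# mmap-backed load: the file's pages are shared by every process reading it.
loaded = np.load(str(base / "pack_infos" / "lengths.npy"), mmap_mode="r")
assert loaded.tolist() == [0, 1, 2, 3, 4]
```

Note `os.rename` fails if the destination directory already exists, so this pattern suits a first-writer-wins cache layout.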
Claude: ## Review Summary

Nice refactor — the shared-memory approach for large ndarray kwargs and mmap-backed pack_infos should significantly reduce per-rank memory in distributed settings. The vectorized …

### Issues

**Critical**
- Non-deterministic shuffle breaks seeded reproducibility (`np.random.shuffle` ignores `self.seed`)

**Warning**
- Process group created but never destroyed (NCCL communicator leak)
- `ndarray_to_mmap` barrier uses the default process group and may time out on large datasets
- Writes to `pack_dir` are not atomic

**Nit**
- Shutdown can hang indefinitely if a worker is stuck

**Verdict:** REQUEST_CHANGES