
Reduce JsonlDataset memory by using mmap array and use npy to store jsonl meta#1604

Open
jayhenry wants to merge 6 commits into InternLM:main from jayhenry:json_npy

Conversation

@jayhenry
Collaborator

No description provided.

@jayhenry
Collaborator Author

@claude review

```python
fpath = os.path.join(dir_path, fname)
arr = np.load(fpath, mmap_mode="r" if mmap else None, allow_pickle=True)
if arr.dtype == object:
    arr = np.load(fpath, allow_pickle=True)
```

Claude: Warning: allow_pickle=True on npy files from disk

np.load(..., allow_pickle=True) can execute arbitrary code during deserialization of object-dtype arrays. While these files are typically written by the same process or a trusted pipeline, if the cache directory is on a shared filesystem, a malicious actor could place a crafted .npy file there.

Consider: (1) only using allow_pickle=True when arr.dtype == object is confirmed (which you do on the retry), and removing it from the first np.load call, or (2) documenting that cache directories must be trusted.
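Option (1) could look like the following sketch (a hypothetical helper, not the PR's actual code): attempt the load with pickle disabled, and only retry with `allow_pickle=True` after NumPy itself rejects the file as an object-dtype array.

```python
import numpy as np


def load_array(fpath, mmap=True):
    """Load a .npy file without enabling pickle up front.

    Plain numeric arrays load (optionally memory-mapped) with pickle
    disabled. Only when NumPy refuses the file as an object array do we
    retry with allow_pickle=True, which is safe only if the cache
    directory is trusted.
    """
    try:
        return np.load(fpath, mmap_mode="r" if mmap else None)
    except ValueError:
        # np.load raises ValueError both for object arrays when
        # allow_pickle=False and for attempts to memory-map an
        # object-dtype file; object arrays cannot be mmapped anyway.
        return np.load(fpath, allow_pickle=True)
```

This keeps the fast mmap path pickle-free, so a crafted object-dtype `.npy` dropped into a shared cache can no longer execute code during the first load.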

@claude

claude bot commented Mar 19, 2026

Claude: ## Summary

This PR reduces JsonlDataset memory by switching self.sampled from a Python list to numpy arrays (removing it as an instance attribute), replacing NPZ cache with per-key NPY directories for mmap-friendly loading, and adding cross-rank mmap sharing via /tmp npy files.

Issues

Critical

  • jsonl.py:82: Lost _has_chunk guard on max_length filtering. The original code explicitly skips max_length filtering when LongTextPretrainTokenizeFunction is used (not self._has_chunk). The new _filter_sampled_indices function always applies max_length filtering, which will incorrectly discard valid long-text samples that have been split into chunks.

Warning

  • custom_pack.py:209-211: Pack validation significantly weakened — sample_idx range checks and token range validation removed. token_end=0 sentinel resolution also removed.
  • jsonl.py:73: allow_pickle=True on first np.load call — security concern if cache dir is on shared filesystem.
  • jsonl.py:466: MD5-based tmp_dir can collide across jobs when tokenize_fn is not CachableTokenizeFunction.
  • test_jsonl_dataset.py:5: import parametrize — unusual third-party package instead of pytest.mark.parametrize.

Nit

  • jsonl.py:220: Commented-out code # return dist.local_rank() == 0 should be removed.

Verdict

REQUEST_CHANGES

- Introduced `enable_mmap_shared` parameter in `DatasetConfig` and `JsonlDataset` to allow shared memory mapping for improved memory efficiency.
- Updated smoke tests to validate the impact of shared memory on PSS and RSS metrics.
- Enhanced dataset building process to utilize mmap for loading data, reducing memory footprint during distributed training.
- Refactored test cases to include new functionality and ensure correctness of data loading and memory usage.
…ndly loading

- Replace save_dict_to_npz/load_dict_from_npz with save_dict_to_npy_dir/load_dict_from_npy_dir
  (each _meta key stored as a separate .npy file under jsonl_meta/ dir)
- load_dict_from_npy_dir uses mmap_mode='r' by default, enabling OS page sharing
- mmap shared-memory branch now uses save/load_dict_to/from_npy_dir for self._meta
  (stored under tmp_dir/meta/ subdir to separate from offsets/num_tokens)
- Refactor _set_meta_attrs as shared method called by both mmap and non-mmap paths
- Add tests: npy-dir round-trip, cache_dir consistency, cache_tag consistency

Made-with: Cursor
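A minimal sketch of the pair described in the commit message above (hypothetical and simplified; object-dtype handling omitted):

```python
import os

import numpy as np


def save_dict_to_npy_dir(d, dir_path):
    """Store each ndarray value of d as <key>.npy under dir_path."""
    os.makedirs(dir_path, exist_ok=True)
    for key, value in d.items():
        if not isinstance(key, str) or not isinstance(value, np.ndarray):
            raise TypeError("keys must be str and values np.ndarray")
        np.save(os.path.join(dir_path, key + ".npy"), value)


def load_dict_from_npy_dir(dir_path, mmap=True):
    """Load every .npy file in dir_path; mmap_mode='r' lets the OS share
    pages across ranks instead of each process holding a private copy."""
    result = {}
    for fname in os.listdir(dir_path):
        if not fname.endswith(".npy"):
            continue
        fpath = os.path.join(dir_path, fname)
        result[fname[: -len(".npy")]] = np.load(
            fpath, mmap_mode="r" if mmap else None
        )
    return result
```

Compared to a single `.npz`, per-key `.npy` files can each be memory-mapped independently, which is what enables the cross-rank page sharing this PR targets.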
@jayhenry
Collaborator Author

@claude review

Comment on lines +482 to +483
```python
np.save(os.path.join(tmp_dir, "offsets.npy"), self.offsets)
if self.num_tokens is not None:
```

Claude: Warning — atexit race condition with mmap files.

atexit.register(shutil.rmtree, tmp_dir, True) is only registered by local rank 0. But all ranks mmap from this directory. If rank 0's process exits and its atexit handler fires before other ranks finish, the mmap backing files are deleted.

On Linux, open mmap file descriptors survive unlink, so this usually works — but it's fragile. If any rank hasn't opened the files yet when cleanup runs, it will fail. Consider:

  1. Using a dist.barrier() before cleanup, or
  2. Having each rank register its own cleanup, or
  3. Not auto-cleaning /tmp (it's a tmpdir, the OS handles it).
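Option 2 could look like the following sketch, where `local_rank` and `barrier` are stand-ins for e.g. `dist.local_rank()` and `dist.barrier()` (hypothetical, not the PR's code):

```python
import atexit
import shutil


def register_tmpdir_cleanup(tmp_dir, local_rank, barrier):
    """Every rank registers a handler that first synchronizes, then only
    local rank 0 removes the shared mmap directory, so no backing file
    disappears while another rank might still need to open it.

    Caveat: collectives run from atexit can hang if a peer already died,
    so explicit teardown before exit is more robust than relying on this.
    """
    def _cleanup():
        barrier()  # wait until every rank has reached shutdown
        if local_rank == 0:
            shutil.rmtree(tmp_dir, ignore_errors=True)

    atexit.register(_cleanup)
    return _cleanup
```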

```python
import time
import random

import tracemalloc
```

Claude: Nit — import parametrize uses a third-party package.

This imports the parametrize package (pypi: parametrize) instead of using the standard pytest.mark.parametrize. The project already depends on pytest per CLAUDE.md guidelines.

Suggested change

```python
import tracemalloc
import pytest
```

And change @parametrize.parametrize(...) to @pytest.mark.parametrize(...) accordingly.
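For reference, the standard form looks like this:

```python
import pytest


# Runs the test once per parameter value; no third-party `parametrize`
# package is needed.
@pytest.mark.parametrize("mmap", [True, False])
def test_load_roundtrip(mmap):
    assert isinstance(mmap, bool)
```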

```python
result = {}
for fname in os.listdir(dir_path):
    if not fname.endswith(".npy"):
        continue
```

Claude: Nit — save_dict_to_npy_dir silently accepts non-ndarray values.

The old save_dict_to_npz had explicit type checks (value must be np.ndarray, key must be str). The new save_dict_to_npy_dir passes anything to np.save without validation, which could silently produce pickle-based .npy files or unexpected behavior.

@claude

claude bot commented Mar 19, 2026

Claude: ## Summary

This PR reduces JsonlDataset memory by switching self.sampled from a Python list to numpy arrays, replacing NPZ cache with per-key NPY directories for mmap-friendly loading, and adding cross-rank mmap sharing via /tmp npy files. It also removes several validation checks from CustomPackDataset.

Issues

Critical

  • custom_pack.py:208-211: token_end=0 sentinel resolution removed. The module docstring still documents token_end == 0 as "take to end of sample", but the resolution code is gone. __getitem__ will slice input_ids[t_start:0] → empty list, silently producing zero-length sequences.

Warning

  • jsonl.py:82: _has_chunk guard removed from max_length filtering, violating documented architecture contract.
  • config.py:51: enable_mmap_shared defaults to True — behavior change for all existing users on upgrade.
  • jsonl.py:482-483: atexit cleanup race — rank 0 may delete mmap backing files before other ranks open them.

Nit

  • test_jsonl_dataset.py:7: import parametrize — third-party package; prefer pytest.mark.parametrize.
  • jsonl.py:68: save_dict_to_npy_dir lost input type validation from old save_dict_to_npz.

Verdict

REQUEST_CHANGES
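The Critical item could be fixed by restoring the sentinel resolution. A sketch, with the signature assumed rather than taken from the PR:

```python
def resolve_token_slice(input_ids, token_start, token_end):
    """Hypothetical sketch: token_end == 0 is documented as 'take to end
    of sample', so resolve it to len(input_ids) before slicing. Without
    this, input_ids[token_start:0] silently yields an empty sequence."""
    if token_end == 0:
        token_end = len(input_ids)
    return input_ids[token_start:token_end]
```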
