[don't merge] split train and val dataset in preference dataset #1763
base: main
Conversation
Signed-off-by: Yuki Huang <yukih@nvidia.com>
Signed-off-by: Rayen <ruit@nvidia.com>
0923975 to 2fb1777
2fb1777 to 6086b51
6086b51 to 994a15f
📝 Walkthrough
This PR refactors the dataset configuration and data loading pipeline across the NeMo RL framework. It introduces a hierarchical config structure with separate train and validation dataset sections.
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks: ✅ 4 passed | ❌ 2 failed (2 warnings)
Actionable comments posted: 17
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (9)
nemo_rl/data/datasets/response_datasets/oasst.py (1)
27-35: Minor typo in docstring. "converstaions" should be "conversations".
📝 Proposed fix
 def parse_conversations(tree_obj, first: bool = False):
-    """Recusive function that returns all the sub converstaions in a list starting from node tree_obj.
+    """Recursive function that returns all the sub conversations in a list starting from node tree_obj.

     Args:
         tree_obj (obj): current conversation node

     Returns:
-        a list of sub conversation threads including the current conversation node
+        A list of sub conversation threads including the current conversation node.
     """

nemo_rl/data/datasets/response_datasets/squad.py (1)
1-1: Update NVIDIA copyright year to 2026. The header year should match the current year for non-test Python files.
✅ Proposed fix
-# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.

nemo_rl/data/datasets/response_datasets/tulu3.py (1)
1-1: Update copyright year to 2026. The header still references 2025; it should match the current year.
🔧 Proposed fix
-# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.

As per coding guidelines, update the copyright header year.
nemo_rl/data/datasets/preference_datasets/binary_preference_dataset.py (1)
20-39: Update the docstring to match the new formatting path.
Line 24 still references `to_preference_data_format`, which is no longer present and can mislead users.
📝 Suggested docstring update
-    It will be converted to the format of PreferenceDataset through the `to_preference_data_format` function.
+    Each sample is converted to the preference format via `format_data`.

examples/run_sft.py (1)
1-1: Update NVIDIA header year to 2026. The header still references 2025.
✏️ Proposed fix
-# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.

As per coding guidelines, please keep the header year current.
nemo_rl/data/datasets/response_datasets/clevr.py (1)
1-1: Update NVIDIA header year to 2026. The header still references 2025.
✏️ Proposed fix
-## Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+## Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.

As per coding guidelines, please keep the header year current.
nemo_rl/data/datasets/response_datasets/__init__.py (1)
1-1: Update NVIDIA header year to 2026. The header still references 2025.
✏️ Proposed fix
-# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.

As per coding guidelines, please keep the header year current.
examples/run_grpo_rm.py (1)
228-230: Training environments are not shut down. The cleanup loop only shuts down environments in `val_task_to_env`, but training environments in `task_to_env` are not explicitly shut down. If training and validation use different environments, the training environments will not be cleaned up.
🔧 Proposed fix to shutdown all environments
 for task_name in val_task_to_env.keys():
     env = val_task_to_env[task_name]
     env.shutdown.remote()
+
+# Shutdown training environments (skip if already shutdown via val_task_to_env)
+for task_name, env in task_to_env.items():
+    if task_name not in val_task_to_env:
+        env.shutdown.remote()

docs/guides/sft.md (1)
33-36: Update documentation to reflect actual attribute names in dataset classes. Lines 34-35 reference outdated attribute names:
- `formatted_ds` does not exist in the codebase; dataset classes use `dataset` instead
- `task_spec` is used in eval dataset classes as a `TaskDataSpec` object, while response dataset classes use `task_name` (string) and `task_spec` is not present
Review actual implementations (e.g., nemo_rl/data/datasets/response_datasets/squad.py, raw_dataset.py) and update the documentation to reflect the correct attributes and their types.
🤖 Fix all issues with AI agents
In `@docs/guides/dpo.md`:
- Line 121: Update the documentation sentence describing BinaryPreferenceDataset
to use the hyphenated compound adjective: change "single turn completions" to
"single-turn completions" in the paragraph that references
BinaryPreferenceDataset (and mentions PreferenceDataset and the fields
prompt_key, chosen_key, rejected_key) so it reads "pairwise ranked preference
with single-turn completions."
In `@docs/guides/rm.md`:
- Line 110: Update the documentation string to hyphenate "single turn
completions" to "single-turn completions" where the BinaryPreferenceDataset is
described (reference: BinaryPreferenceDataset class and the sentence mentioning
"pairwise ranked preference with single turn completions"); edit the sentence to
read "pairwise ranked preference with single-turn completions" so the phrase is
consistently hyphenated for readability.
In `@examples/run_sft.py`:
- Around line 58-62: Replace the assert in setup_data with an explicit
validation that checks for "train" in data_config and raises a ValueError with a
clear message if missing; update the guard in the setup_data function to use an
if-not check and raise ValueError (preserving the existing explanatory text) so
the check cannot be skipped under Python -O optimization.
- Around line 86-126: Before calling concatenate_datasets on val_data_list,
validate that all datasets share a compatible schema (e.g., same column
names/types) and fail fast with a clear error if they don't; implement this by
inspecting each dataset's column names/features (use the dataset.column_names or
dataset.features/feature keys) and comparing to the first entry in
val_data_list, and if any mismatch is found raise a ValueError that names the
offending dataset types or their differing columns (this check should be added
in the block that builds merged_val_data just before calling
concatenate_datasets in run_sft.py, referencing val_data_list and
concatenate_datasets and mentioning the potential mismatch between
Tulu3SftMixtureDataset.format_data and ResponseDataset.format_data).
In `@examples/run_vlm_grpo.py`:
- Around line 72-77: The env creation loop is hardcoding env_name="vlm" so every
entry in envs uses the wrong name; update the dict comprehension that builds
envs to pass the loop variable into create_env (use env_name as the env_name
argument) and ensure it uses env_configs[env_name] for env_config so each
environment is created with its correct name and config (refer to env_name_list,
create_env, envs, and env_configs).
In `@nemo_rl/data/__init__.py`:
- Around line 18-62: The TypedDicts ResponseDatasetConfig,
PreferenceDatasetConfig, and DataConfig were extended with new keys (e.g.,
dataset_name, data_path, input_key, output_key, env_name, prompt_file,
system_prompt_file, split_validation_size, seed, processor, download_dir,
add_bos/add_eos/add_generation_prompt/add_system_prompt, shuffle, num_workers,
train/validation/default) but lack per-key documentation and example defaults;
add concise class docstrings to ResponseDatasetConfig, PreferenceDatasetConfig,
and DataConfig that describe the purpose of each config block, then add inline
per-field comments or docstrings describing each key’s purpose, valid values,
and recommended defaults (for example: dataset_name: string id of dataset,
data_path: local path or remote URI,
input_key/output_key/prompt_key/chosen_key/rejected_key: JSON field names,
split: train/val/test string, prompt_file/system_prompt_file: path or null,
split_validation_size: float proportion, seed: int, processor: legacy marker,
num_workers: int), and finally update the exemplar YAMLs in
examples/configs/*.yaml to include these keys with sensible default values and
notes so examples reflect the new schema; reference the TypedDict class names
above when making edits.
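For illustration, a minimal sketch of what the documented config types could look like; the field set, comments, and `total=False` usage here are illustrative, not the project's actual schema:

```python
from typing import Optional, TypedDict


class ResponseDatasetConfig(TypedDict, total=False):
    """Config block describing a single response-style dataset."""

    dataset_name: str  # registry id of the dataset (e.g. "squad")
    data_path: str  # local path or remote URI for file-based datasets
    input_key: str  # JSON field holding the prompt/input
    output_key: str  # JSON field holding the target/response
    split: str  # which split to load, e.g. "train"
    split_validation_size: float  # fraction of train held out for validation
    seed: int  # seed used for the train/validation split


class DataConfig(TypedDict, total=False):
    """Top-level data config with separate train/validation blocks plus shared defaults."""

    train: ResponseDatasetConfig
    validation: Optional[ResponseDatasetConfig]
    default: dict  # values merged into train/validation blocks when keys are missing
    max_input_seq_length: int
    shuffle: bool
    num_workers: int
```

The exemplar YAMLs would then mirror these keys under `data.train` / `data.validation`, with the recommended defaults noted next to each field.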
In `@nemo_rl/data/datasets/preference_datasets/helpsteer3.py`:
- Around line 49-56: The code sets chosen and rejected to the same completion
when overall_preference == 0 (variables overall_preference, chosen, rejected,
response_1), which yields useless/unstable DPO training; instead either filter
these examples out during dataset loading (skip adding the sample when
overall_preference == 0) or mark them with a per-example flag or loss_multiplier
= 0 so downstream training ignores them; update the branch handling
overall_preference == 0 in helpsteer3.py to not append identical chosen/rejected
pairs but to drop the example or set loss_multiplier = 0 and propagate that
field for the trainer to skip.
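For illustration, a sketch of the drop-the-tie option; the key names (`context`, `response_2`) and the sign convention of `overall_preference` are assumptions, not the file's actual code:

```python
def format_preference_example(data: dict) -> dict | None:
    """Build a chosen/rejected pair, skipping ties that carry no DPO signal."""
    preference = data["overall_preference"]
    if preference == 0:
        # Tie: chosen == rejected is useless/unstable for DPO, so drop the row.
        # (Alternative: keep it but set loss_multiplier = 0 so the trainer skips it.)
        return None
    chosen = data["response_1"] if preference < 0 else data["response_2"]
    rejected = data["response_2"] if preference < 0 else data["response_1"]
    return {
        "prompt": data["context"],
        "chosen": chosen,
        "rejected": rejected,
        "loss_multiplier": 1.0,
    }
```

Callers would then filter out `None` results when building the dataset.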
In `@nemo_rl/data/datasets/raw_dataset.py`:
- Around line 31-37: split_train_validation currently only sets self.val_dataset
when test_size > 0, so calling it with test_size <= 0 leaves a previous
validation split in self.val_dataset; update split_train_validation to
explicitly clear/None out self.val_dataset when test_size <= 0 (in the
split_train_validation method) so stale validation data cannot persist between
calls.
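A possible shape for the fix; this sketch assumes the method wraps `datasets.Dataset.train_test_split`, which may differ from the actual implementation:

```python
def split_train_validation(self, test_size: float, seed: int = 42) -> None:
    """Split self.dataset into train/validation, clearing any stale split."""
    if test_size is None or test_size <= 0:
        # Reset explicitly so a previous call's validation split cannot persist.
        self.val_dataset = None
        return
    split = self.dataset.train_test_split(test_size=test_size, seed=seed)
    self.dataset = split["train"]
    self.val_dataset = split["test"]
```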
- Around line 39-44: The method set_processor currently assigns a non-None
default ("default") to processor_name in code; remove this in-code default and
instead require "processor" to be present in self.data_config (or ensure it's
set via YAML defaults) by checking for the key and raising a clear error if
missing; update set_processor to read processor_name =
self.data_config["processor"] (with the existing pyrefly comment if needed) and
raise a ValueError with an explanatory message referencing set_processor,
processor_name and self.data_config when the key is absent so callers know to
provide the value via config/YAML.
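A sketch of the explicit check; it assumes the config is held in `self.data_config` and shows only the guard, leaving the actual processor lookup as in the existing method:

```python
def set_processor(self) -> None:
    """Resolve the processor name from config instead of an in-code default."""
    if "processor" not in self.data_config or self.data_config["processor"] is None:
        raise ValueError(
            "set_processor: data_config is missing 'processor'. Provide it via the "
            "dataset YAML (e.g. data.train.processor) rather than a code default."
        )
    processor_name = self.data_config["processor"]
    # ...existing lookup/assignment using processor_name stays unchanged...
```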
In `@nemo_rl/data/datasets/response_datasets/geometry3k.py`:
- Line 63: The __init__ signature declares an unused **kwargs which triggers
lint ARG002; to silence the lint without changing behavior, rename it to
**_kwargs (or remove it if not needed) in the __init__ method so the parameter
is recognized as intentionally unused; update the __init__ definition (the
constructor function) accordingly.
In `@nemo_rl/data/datasets/response_datasets/oai_format_dataset.py`:
- Around line 171-176: The loader currently assumes JSONL by reading per-line
into original_dataset; change it to detect and support both JSON array and
JSONL: open data_path, read the file (or peek first non-whitespace char) and if
it begins with '[' or parses as a JSON array, use json.load to get a list,
otherwise fall back to line-by-line json.loads; then continue to call
self.format_data on each item and pass the result into PreservingDataset as
before. Ensure you reference the existing symbols original_dataset, data_path,
self.format_data, and PreservingDataset when implementing the conditional
parsing.
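For illustration, the format detection could look roughly like this standalone helper; integration with `format_data` and `PreservingDataset` is assumed to stay as described above:

```python
import json


def load_json_or_jsonl(data_path: str) -> list[dict]:
    """Load either a JSON array file or a JSONL file into a list of records."""
    with open(data_path, encoding="utf-8") as f:
        text = f.read()
    if text.lstrip().startswith("["):
        return json.loads(text)  # whole file is a single JSON array
    # Fall back to JSONL: one JSON object per non-empty line.
    return [json.loads(line) for line in text.splitlines() if line.strip()]
```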
In `@nemo_rl/data/datasets/response_datasets/openmathinstruct2.py`:
- Line 38: The function signature currently includes an unused parameter kwargs
which triggers Ruff ARG002; either remove kwargs if truly unnecessary or silence
the lint by renaming it to _kwargs (or _unused_kwargs) or add a noqa comment.
Locate the function/method that contains the '**kwargs,' parameter in
openmathinstruct2.py (search for the signature containing **kwargs) and update
the signature accordingly and, if renaming, update any internal references or
callers if present; if choosing noqa, attach the specific noqa directive to the
parameter or function definition to suppress ARG002.
In `@nemo_rl/data/datasets/response_datasets/squad.py`:
- Around line 29-31: The constructor in the SQuAD dataset class (__init__ in
nemo_rl/data/datasets/response_datasets/squad.py) currently hard-codes split:
str = "train"; remove that default so split is required (change signature to
split: str) or accept Optional[str] and explicitly validate for None and raise a
clear ValueError instructing callers to provide the split via config/YAML;
update any internal uses of self.split or callers to reflect the required
parameter and ensure tests/configs supply the split from YAML rather than
relying on a code default.
In `@nemo_rl/data/datasets/response_datasets/tulu3.py`:
- Around line 33-36: Remove hard-coded non-None defaults for
split_validation_size and seed in the function/class signature (change
split_validation_size and seed to accept None by default) so defaults are
provided via YAML, and stop silently ignoring **kwargs by either forwarding them
to the RawDataset constructor (e.g., include **kwargs when calling
RawDataset.__init__/super().__init__) or removing **kwargs entirely; update the
signature that currently contains split_validation_size, seed, max_samples,
**kwargs and ensure any call to RawDataset (or RawDataset.__init__) passes
**kwargs if you keep it.
- Around line 63-70: The ValueError in format_data currently includes full
message payloads (messages) and may leak PII; change the exception to include a
redacted summary instead: report messages count and the list of roles (e.g.,
[m["role"] for m in messages]) or a truncated indicator, rather than embedding
the full messages dict; update the raise in format_data to construct and raise
the ValueError with that summary and keep the same check using the messages
variable and assistant role.
In `@nemo_rl/data/multimodal_utils.py`:
- Around line 199-203: The requests.get call that handles URL images (the branch
checking image_path_or_image.startswith(("http://", "https://"))) must include a
timeout to avoid indefinite hangs; update the call in multimodal_utils.py to
pass a timeout (e.g., timeout=10 or a configurable constant like
REQUEST_TIMEOUT) to requests.get(image_path_or_image, timeout=...), and
propagate this constant or parameter where appropriate so the image-loading
function (the block using image_path_or_image and requests.get) uses a bounded
timeout for network calls.
In `@tests/unit/data/datasets/test_response_dataset.py`:
- Around line 27-43: In create_sample_data, replace the unsafe tempfile.mktemp()
call used to set data_path with tempfile.mkdtemp() so a real temporary directory
is created before calling Dataset.from_list(...).save_to_disk(data_path); ensure
data_path (the value returned by tempfile.mkdtemp) is passed directly to
dataset.save_to_disk and remove any reliance on an unused filename; keep the
rest of the function (the NamedTemporaryFile branch and returned data_path)
unchanged.
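The save-to-disk branch could then look like this sketch (the fixture signature is illustrative):

```python
import tempfile

from datasets import Dataset


def create_sample_data(samples: list[dict]) -> str:
    """Persist sample records to a real temporary directory and return its path."""
    data_path = tempfile.mkdtemp()  # creates the directory, unlike the racy mktemp()
    Dataset.from_list(samples).save_to_disk(data_path)
    return data_path
```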
🧹 Nitpick comments (25)
nemo_rl/data/datasets/response_datasets/response_dataset.py (2)
48-48: Remove or document unused `**kwargs` parameter. The `kwargs` parameter is accepted but never used. If it's intended for forward compatibility or to absorb extra configuration keys from a parent class, document this in the docstring. Otherwise, remove it to avoid confusion.
Option 1: Remove if not needed
 def __init__(
     self,
     data_path: str,
     input_key: str = "input",
     output_key: str = "output",
     split: Optional[str] = None,
     split_validation_size: float = 0,
     seed: int = 42,
-    **kwargs,
 ):
Option 2: Document if intentional
Add to the docstring Args section:
**kwargs: Additional keyword arguments (absorbed for forward compatibility)
52-52: Task name extraction may be fragile for certain path formats. The current string manipulation `data_path.split("/")[-1].split(".")[0]` can produce unexpected results for:
- Windows paths with backslashes
- Paths ending with `/` (returns empty string)
- HuggingFace dataset names without file extensions
Consider using `pathlib` for more robust path handling:
Suggested improvement
+from pathlib import Path
+
 class ResponseDataset(RawDataset):
     ...
     def __init__(
         ...
     ):
         self.input_key = input_key
         self.output_key = output_key
-        self.task_name = data_path.split("/")[-1].split(".")[0]
+        self.task_name = Path(data_path).stem or Path(data_path).name

nemo_rl/data/datasets/response_datasets/oasst.py (1)
69-86: LGTM with a note on the TODO. The function logic is correct. There's a TODO comment (lines 70-71) about the multi-conversation format change; consider addressing it or creating an issue to track it.
Would you like me to open an issue to track the TODO about the multi-conversation format?
nemo_rl/data/datasets/response_datasets/squad.py (1)
33-38: Load only the requested split and make `kwargs` meaningful. `load_dataset(...)[split]` loads all splits. Prefer `split=split` to avoid extra I/O, and optionally forward `kwargs` so they are used (resolves the unused-kwargs lint).
♻️ Proposed fix
-        self.dataset = load_dataset("rajpurkar/squad")[split]
+        self.dataset = load_dataset("rajpurkar/squad", split=split, **kwargs)

nemo_rl/data/datasets/response_datasets/tulu3.py (1)
23-29: Use Google-style docstring with typed Args. The Args section should include types for Google-style docstrings (e.g., `split_validation_size (float): ...`).
✍️ Suggested docstring tweak
-    Args:
-        split_validation_size: Size of the validation data, default is 0.05
-        seed: Seed for train/validation split when split_validation_size > 0, default is 42
-        max_samples: Optional maximum number of samples to use from the dataset
+    Args:
+        split_validation_size (float): Size of the validation data; default is defined in YAML.
+        seed (int): Seed for train/validation split when split_validation_size > 0; default is defined in YAML.
+        max_samples (int | None): Optional maximum number of samples to use from the dataset.

As per coding guidelines, use Google-style docstrings.
nemo_rl/data/datasets/processed_dataset.py (1)
59-60: Note: TODO for preference dataset refactor. The TODO indicates that `default_task_data_spec` handling will be cleaned up once the preference dataset is refactored. This is acceptable transitional code.
Would you like me to open an issue to track this refactoring task?
nemo_rl/data/multimodal_utils.py (1)
207-209: Prefix unused variable with underscore. The `header` variable from the split is never used. Per Ruff hint RUF059, prefix it with an underscore to indicate it's intentionally unused.
Proposed fix
 # Handle base64 encoded image
 # Format: data:image/jpeg;base64,/9j/4AAQSkZJRg...
-header, encoded = image_path_or_image.split(",", 1)
+_header, encoded = image_path_or_image.split(",", 1)
 image_data = base64.b64decode(encoded)
 return Image.open(BytesIO(image_data)).convert("RGB")

nemo_rl/data/datasets/preference_datasets/tulu3.py (1)
40-52: Consider using a more efficient comparison than `json.dumps`. The `json.dumps` comparison for context validation works but has performance overhead and may fail on edge cases with non-JSON-serializable content. A direct list comparison could be more robust:
♻️ Suggested alternative
-    assert json.dumps(context, ensure_ascii=False) == json.dumps(
-        rejected_conversation[:-1], ensure_ascii=False
-    ), (
+    assert context == rejected_conversation[:-1], (
         f"Context mismatch.\n\nchosen: {chosen_conversation}\n\n rejected: {rejected_conversation}"
     )

If order-independent comparison is needed, the current approach is acceptable. Otherwise, direct equality should work for lists of dicts with the same structure.
nemo_rl/data/datasets/response_datasets/deepscaler.py (1)
25-37: Consider documenting or suppressing the unused `kwargs` parameter. The `**kwargs` parameter is unused but likely intentional for API consistency across dataset classes. Consider either:
- Adding a comment explaining it's for interface compatibility
- Using `**_` to explicitly indicate unused
- Adding `# noqa: ARG002` if you want to suppress the linter warning
♻️ Suggested fix
-    def __init__(self, **kwargs) -> None:
+    def __init__(self, **_kwargs) -> None:  # kwargs accepted for API consistency

nemo_rl/data/datasets/response_datasets/helpsteer3.py (1)
30-40: Consider documenting or suppressing the unused `kwargs` parameter. Same as other dataset classes: `kwargs` is accepted for API consistency but unused.
♻️ Suggested fix
-    def __init__(self, split: str = "train", **kwargs):
+    def __init__(self, split: str = "train", **_kwargs):  # kwargs for API consistency

nemo_rl/data/datasets/preference_datasets/helpsteer3.py (1)
29-39: Consider documenting or suppressing the unused `kwargs` parameter.
♻️ Suggested fix
-    def __init__(self, split: str = "train", **kwargs):
+    def __init__(self, split: str = "train", **_kwargs):  # kwargs for API consistency

examples/run_grpo.py (1)
77-82: Potential `KeyError` if `env_configs` is missing an environment. If an `env_name` extracted from `data_config` doesn't exist in `env_configs`, this will raise a `KeyError` with minimal context. Consider adding a check with a descriptive error message.
♻️ Suggested improvement
 env_name_list = extract_necessary_env_names(data_config)
+missing_envs = [name for name in env_name_list if name not in env_configs]
+if missing_envs:
+    raise ValueError(
+        f"Environment config(s) missing for: {missing_envs}. "
+        f"Available env configs: {list(env_configs.keys())}"
+    )
 envs = {
     env_name: create_env(env_name=env_name, env_config=env_configs[env_name])
     for env_name in env_name_list
 }

examples/run_grpo_math.py (1)
62-140: Near-identical implementation to `run_grpo.py`. The `setup_data` function in this file is essentially identical to `run_grpo.py`. Consider extracting this into a shared utility to avoid maintaining duplicate logic. This could be addressed in a follow-up PR.

examples/run_dpo.py (1)
53-68: Missing assertion for `data_config["train"]` presence. Unlike `run_grpo.py` and `run_grpo_math.py`, this file doesn't assert that `"train"` exists in `data_config` before accessing it. Consider adding a similar assertion with a helpful migration message for consistency.
♻️ Suggested addition
 def setup_data(tokenizer: AutoTokenizer, data_config: DataConfig):
     print("\n▶ Setting up data...")
+    assert "train" in data_config, (
+        "The dataset config structure is updated. Please refer to the docs "
+        "and the Migrate Guide to update the dataset config."
+    )
     # setup train dataset
     if "default" in data_config:

nemo_rl/data/datasets/raw_dataset.py (1)
23-29: Stale comment no longer matches the type. Line 23 says to change to a union "once preference dataset is refactored," but the union type is already in place. Consider removing or updating the comment to avoid confusion.
examples/configs/recipes/llm/sft-llama3.1-8b-1n8g-fsdp2tp1-dynamicbatch.yaml (1)
47-47: Prefer interpolating the dataset name in TensorBoard paths. Right now the log dir is hard-coded; using interpolation prevents stale paths if the dataset changes.
♻️ Suggested change
-  log_dir: tb_logs-sft-dev-openmathinstruct2
+  log_dir: tb_logs-sft-dev-${data.train.dataset_name}

nemo_rl/data/datasets/response_datasets/aime24.py (1)
22-42: Consider initializing `val_dataset` for consistency with the base class. The `RawDataset` base class defines a `val_dataset` attribute and `split_train_validation` method. Other datasets like `ResponseDataset` and `OpenMathInstruct2Dataset` initialize `self.val_dataset = None` and optionally call `self.split_train_validation(...)` to support using a single dataset for both train and validation. If this dataset should support the `split_validation_size` feature mentioned in the docs, consider adding:
self.val_dataset = None  # Optionally support train/val split if needed in the future

tests/unit/data/datasets/test_oai_format_dataset.py (1)
49-53: Tokenizer fixture downloads from network - consider marking test appropriately. The `get_tokenizer({"name": "Qwen/Qwen3-0.6B"})` call will download the tokenizer from HuggingFace on first run. This may cause test failures in CI environments without network access or slow down local test runs. Consider either:
- Using a mock tokenizer for unit tests
- Marking this test with `@pytest.mark.network` or similar to skip in offline environments (see the sketch below)
- Using a cached/local tokenizer path
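If the marker route is taken, a sketch of what it could look like; the `network` marker name is an assumption and would need to be registered in the pytest config:

```python
import pytest


@pytest.mark.network  # skip offline runs with: pytest -m "not network"
def test_tokenizer_loads():
    transformers = pytest.importorskip("transformers")
    tokenizer = transformers.AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
    assert tokenizer is not None
```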
nemo_rl/data/datasets/preference_datasets/preference_dataset.py (1)
45-45: Task name extraction may fail on certain path formats. The expression `data_path.split("/")[-1].split(".")[0]` assumes Unix-style paths with file extensions. This could produce unexpected results for:
- Windows paths: `C:\data\file.json` → task_name would be `C:\data\file`
- HuggingFace dataset IDs without extensions: `nvidia/helpsteer3` → task_name would be `helpsteer3` (likely acceptable)
- Paths without extensions: `/path/to/dataset` → task_name would be `dataset` (likely acceptable)
Consider using `pathlib.Path` or `os.path` for more robust handling:
from pathlib import Path
self.task_name = Path(data_path).stem or Path(data_path).name

nemo_rl/data/datasets/response_datasets/__init__.py (1)
37-64: Consider `G_` prefix for the global registry name. Guidelines call for a `G_` prefix on global variables; renaming here keeps convention consistent.
♻️ Suggested refactor
-DATASET_REGISTRY = {
+G_DATASET_REGISTRY = {
@@
-    if dataset_name in DATASET_REGISTRY:
-        dataset_class = DATASET_REGISTRY[dataset_name]
+    if dataset_name in G_DATASET_REGISTRY:
+        dataset_class = G_DATASET_REGISTRY[dataset_name]

As per coding guidelines, global variables should use a `G_` prefix.

nemo_rl/data/datasets/response_datasets/refcoco.py (1)
174-178: Consider documenting the `**kwargs` parameter. The `kwargs` parameter is unused but appears intentionally included for API consistency with other dataset classes that pass extra config options. Consider adding a brief docstring note explaining this is for interface compatibility.
📝 Suggested documentation
 def __init__(
     self,
     split: str = "train",
     download_dir: str = "./coco_images",
-    **kwargs,
+    **kwargs,  # Accepts extra config options for interface compatibility
 ):

nemo_rl/data/datasets/response_datasets/dapo_math.py (1)
53-65: Subclass overrides `__init__` without calling `super()`. `DAPOMathAIME2024Dataset` inherits from `DAPOMath17KDataset` but completely reimplements `__init__` without calling `super().__init__()`. This makes the inheritance relationship misleading since only `format_data` is actually inherited. Consider either:
- Using composition instead of inheritance
- Refactoring the base class to accept parameters for customization
♻️ Option 1: Refactor base class to accept parameters
 class DAPOMath17KDataset(RawDataset):
-    """Simple wrapper around the DAPO Math 17K dataset with train split."""
+    """Simple wrapper around the DAPO Math datasets."""

-    def __init__(self, **kwargs) -> None:
-        self.task_name = "DAPOMath17K"
-
-        # load from huggingface
-        self.dataset = load_dataset("BytedTsinghua-SIA/DAPO-Math-17k", split="train")
+    def __init__(
+        self,
+        task_name: str = "DAPOMath17K",
+        hf_dataset: str = "BytedTsinghua-SIA/DAPO-Math-17k",
+        split: str = "train",
+        **kwargs,
+    ) -> None:
+        self.task_name = task_name
+        self.dataset = load_dataset(hf_dataset, split=split)

         # format the dataset
         self.dataset = self.dataset.map(
             self.format_data,
             remove_columns=self.dataset.column_names,
         )

 class DAPOMathAIME2024Dataset(DAPOMath17KDataset):
     def __init__(self, **kwargs) -> None:
         """Initialize the DAPO Math AIME 2024 dataset with train split."""
-        self.task_name = "DAPOMathAIME2024"
-
-        # load from huggingface
-        self.dataset = load_dataset("BytedTsinghua-SIA/AIME-2024", split="train")
-
-        # format the dataset
-        self.dataset = self.dataset.map(
-            self.format_data,
-            remove_columns=self.dataset.column_names,
+        super().__init__(
+            task_name="DAPOMathAIME2024",
+            hf_dataset="BytedTsinghua-SIA/AIME-2024",
+            split="train",
+            **kwargs,
         )

examples/configs/rm.yaml (1)
157-161: Consider clarifying the deprecation timeline for `val_data_paths`. The comment mentions `val_data_paths` "will be removed after refactor." It would be helpful to add a reference to the tracking issue or provide guidance on what to use instead when this feature is deprecated.

docs/guides/sft.md (1)
122-132: Consider expanding the validation example. The validation block shows only `...`, which might leave users uncertain about how to configure validation for OpenAI format datasets. Consider adding a minimal example or a note referring to the general pattern shown earlier.

nemo_rl/data/datasets/response_datasets/openmathinstruct2.py (1)
32-37: Avoid hard-coded defaults for dataset config values. `output_key`, `split_validation_size`, and `seed` are config-driven; consider making them explicit in YAML and removing non-None defaults to avoid hidden behavior (e.g., auto 5% split). As per coding guidelines, defaults should live in YAML.
  <NameOfValidationDataset2>: /path/to/local/val_dataset_2.jsonl
```
We also provide a [BinaryPreferenceDataset](../../nemo_rl/data/datasets/preference_datasets/binary_preference_dataset.py) class, which is a simplified version of PreferenceDataset for pairwise ranked preference with single turn completions. You can use `prompt_key`, `chosen_key` and `rejected_key` to specify which fields in your data correspond to the question, chosen answer and rejected answer respectively. Here's an example configuration:
Minor grammar: use hyphen in compound adjective.
"single turn completions" should be "single-turn completions" when used as a compound adjective modifying "completions".
📝 Suggested fix
-We also provide a [BinaryPreferenceDataset](../../nemo_rl/data/datasets/preference_datasets/binary_preference_dataset.py) class, which is a simplified version of PreferenceDataset for pairwise ranked preference with single turn completions.
+We also provide a [BinaryPreferenceDataset](../../nemo_rl/data/datasets/preference_datasets/binary_preference_dataset.py) class, which is a simplified version of PreferenceDataset for pairwise ranked preference with single-turn completions.
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
We also provide a [BinaryPreferenceDataset](../../nemo_rl/data/datasets/preference_datasets/binary_preference_dataset.py) class, which is a simplified version of PreferenceDataset for pairwise ranked preference with single turn completions. You can use `prompt_key`, `chosen_key` and `rejected_key` to specify which fields in your data correspond to the question, chosen answer and rejected answer respectively. Here's an example configuration:
We also provide a [BinaryPreferenceDataset](../../nemo_rl/data/datasets/preference_datasets/binary_preference_dataset.py) class, which is a simplified version of PreferenceDataset for pairwise ranked preference with single-turn completions. You can use `prompt_key`, `chosen_key` and `rejected_key` to specify which fields in your data correspond to the question, chosen answer and rejected answer respectively. Here's an example configuration:
🧰 Tools
🪛 LanguageTool
[grammar] ~121-~121: Use a hyphen to join words.
Context: ...r pairwise ranked preference with single turn completions. You can use `prompt_ke...
(QB_NEW_EN_HYPHEN)
🤖 Prompt for AI Agents
In `@docs/guides/dpo.md` at line 121, Update the documentation sentence describing
BinaryPreferenceDataset to use the hyphenated compound adjective: change "single
turn completions" to "single-turn completions" in the paragraph that references
BinaryPreferenceDataset (and mentions PreferenceDataset and the fields
prompt_key, chosen_key, rejected_key) so it reads "pairwise ranked preference
with single-turn completions."
  <NameOfValidationDataset2>: /path/to/local/val_dataset_2.jsonl
```
We also provide a [BinaryPreferenceDataset](../../nemo_rl/data/datasets/preference_datasets/binary_preference_dataset.py) class, which is a simplified version of PreferenceDataset for pairwise ranked preference with single turn completions. You can use `prompt_key`, `chosen_key` and `rejected_key` to specify which fields in your data correspond to the question, chosen answer and rejected answer respectively. Here's an example configuration:
Hyphenate “single-turn” for readability.
✏️ Proposed fix
- pairwise ranked preference with single turn completions.
+ pairwise ranked preference with single-turn completions.
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
We also provide a [BinaryPreferenceDataset](../../nemo_rl/data/datasets/preference_datasets/binary_preference_dataset.py) class, which is a simplified version of PreferenceDataset for pairwise ranked preference with single turn completions. You can use `prompt_key`, `chosen_key` and `rejected_key` to specify which fields in your data correspond to the question, chosen answer and rejected answer respectively. Here's an example configuration:
We also provide a [BinaryPreferenceDataset](../../nemo_rl/data/datasets/preference_datasets/binary_preference_dataset.py) class, which is a simplified version of PreferenceDataset for pairwise ranked preference with single-turn completions. You can use `prompt_key`, `chosen_key` and `rejected_key` to specify which fields in your data correspond to the question, chosen answer and rejected answer respectively. Here's an example configuration:
🧰 Tools
🪛 LanguageTool
[grammar] ~110-~110: Use a hyphen to join words.
Context: ...r pairwise ranked preference with single turn completions. You can use `prompt_ke...
(QB_NEW_EN_HYPHEN)
🤖 Prompt for AI Agents
In `@docs/guides/rm.md` at line 110, Update the documentation string to hyphenate
"single turn completions" to "single-turn completions" where the
BinaryPreferenceDataset is described (reference: BinaryPreferenceDataset class
and the sentence mentioning "pairwise ranked preference with single turn
completions"); edit the sentence to read "pairwise ranked preference with
single-turn completions" so the phrase is consistently hyphenated for
readability.
def setup_data(tokenizer: AutoTokenizer, data_config: DataConfig):
    assert "train" in data_config, (
        "The dataset config structure is updated. Please refer to https://github.com/NVIDIA-NeMo/RL/blob/main/docs/guides/sft.md#datasets "
        "and the Migrate Guide in https://github.com/NVIDIA-NeMo/RL/pull/1649 to update the dataset config."
    )
🧩 Analysis chain
🏁 Script executed:
# First, let's examine the file at the specified lines
cat -n examples/run_sft.py | sed -n '55,65p'
Repository: NVIDIA-NeMo/RL
Length of output: 616
🏁 Script executed:
# Let's also check if there are other assert statements used for config validation in the codebase
rg "assert.*in.*config" --type py -A 2 -B 2Repository: NVIDIA-NeMo/RL
Length of output: 7952
Avoid assert for config validation.
Assertions are stripped with Python's -O flag, which would skip the guard and cause opaque failures later when the missing config key is accessed. Use an explicit check with raise ValueError instead.
🐛 Proposed fix
- assert "train" in data_config, (
- "The dataset config structure is updated. Please refer to https://github.com/NVIDIA-NeMo/RL/blob/main/docs/guides/sft.md#datasets "
- "and the Migrate Guide in https://github.com/NVIDIA-NeMo/RL/pull/1649 to update the dataset config."
- )
+ if "train" not in data_config:
+ raise ValueError(
+ "The dataset config structure is updated. Please refer to "
+ "https://github.com/NVIDIA-NeMo/RL/blob/main/docs/guides/sft.md#datasets "
+ "and the Migrate Guide in https://github.com/NVIDIA-NeMo/RL/pull/1649 "
+ "to update the dataset config."
+ )🤖 Prompt for AI Agents
In `@examples/run_sft.py` around lines 58 - 62, Replace the assert in setup_data
with an explicit validation that checks for "train" in data_config and raises a
ValueError with a clear message if missing; update the guard in the setup_data
function to use an if-not check and raise ValueError (preserving the existing
explanatory text) so the check cannot be skipped under Python -O optimization.
# setup validation dataset
val_task_data_processors = {}
val_data_list = []

# validation dataset from train dataset (when train dataset's split_validation_size > 0)
if hasattr(data, "val_dataset") and data.val_dataset is not None:
val_data_list.append(data.val_dataset)
val_task_data_processors = task_data_processors.copy()

# validation dataset from config
if "validation" in data_config and data_config["validation"] is not None:
if "default" in data_config:
update_single_dataset_config(
data_config["validation"], data_config["default"]
)
val_data = load_response_dataset(data_config["validation"])
val_data_list.append(val_data.dataset)
val_data_processor = partial(
val_data.processor,
add_bos=data_config["add_bos"],
add_eos=data_config["add_eos"],
add_generation_prompt=data_config["add_generation_prompt"],
datum_preprocessor=datum_preprocessor,
),
max_seq_length=data_config["max_input_seq_length"],
)
)
val_task_data_processors[val_data.task_name] = (
val_data.task_spec,
val_data_processor,
)

if val_dataset is not None:
val_dataset = None
if len(val_data_list) > 0:
merged_val_data = concatenate_datasets(val_data_list)
val_dataset = AllTaskProcessedDataset(
val_dataset,
merged_val_data,
tokenizer,
sft_task_spec,
partial(
sft_preprocessor,
add_bos=data_config.get("add_bos", True),
add_eos=data_config.get("add_eos", True),
add_generation_prompt=data_config["add_generation_prompt"],
datum_preprocessor=datum_preprocessor,
),
None,
val_task_data_processors,
max_seq_length=data_config["max_input_seq_length"],
)
print(f" ✓ Validation dataset loaded with {len(val_dataset)} samples.")

return train_dataset, val_dataset, sft_task_spec
return dataset, val_dataset
🧩 Analysis chain
🏁 Script executed:
# First, let's find the file and examine the relevant lines
fd examples/run_sft.py
Repository: NVIDIA-NeMo/RL
Length of output: 172
🏁 Script executed:
# Let's examine the specific lines mentioned in the review
sed -n '86,126p' examples/run_sft.py
Repository: NVIDIA-NeMo/RL
Length of output: 1640
🏁 Script executed:
# Find where concatenate_datasets is imported from
grep -n "concatenate_datasets" examples/run_sft.pyRepository: NVIDIA-NeMo/RL
Length of output: 168
🏁 Script executed:
# Search for the definition of concatenate_datasets in the codebase
rg "def concatenate_datasets" --type pyRepository: NVIDIA-NeMo/RL
Length of output: 40
🏁 Script executed:
# Let's also check if it's from a library
rg "from.*concatenate_datasets\|import.*concatenate_datasets" --type pyRepository: NVIDIA-NeMo/RL
Length of output: 40
🏁 Script executed:
# Find load_response_dataset function to understand what schema it provides
rg "def load_response_dataset" --type py -A 10Repository: NVIDIA-NeMo/RL
Length of output: 1047
🏁 Script executed:
# Search for where val_dataset is created/loaded to understand its schema
rg "val_dataset\s*=" --type py -B 3 -A 3 | head -60Repository: NVIDIA-NeMo/RL
Length of output: 2933
🏁 Script executed:
# Look at the entire run_sft.py file to understand data flow
wc -l examples/run_sft.py
Repository: NVIDIA-NeMo/RL
Length of output: 81
🏁 Script executed:
# Get more context around the validation dataset setup
sed -n '1,100p' examples/run_sft.py | head -80
Repository: NVIDIA-NeMo/RL
Length of output: 2921
🏁 Script executed:
# Get the rest of the setup_data function starting from line 60
sed -n '60,130p' examples/run_sft.py
Repository: NVIDIA-NeMo/RL
Length of output: 2740
🏁 Script executed:
# Check the schema/features of datasets returned by load_response_dataset
rg "class.*Dataset" nemo_rl/data/datasets/response_datasets/ --type py | head -20Repository: NVIDIA-NeMo/RL
Length of output: 1423
🏁 Script executed:
# Look at what load_response_dataset returns in detail
sed -n '1,50p' nemo_rl/data/datasets/response_datasets/__init__.py
Repository: NVIDIA-NeMo/RL
Length of output: 2329
🏁 Script executed:
# Search for the return statement in load_response_dataset
rg -A 20 "def load_response_dataset" nemo_rl/data/datasets/response_datasets/__init__.pyRepository: NVIDIA-NeMo/RL
Length of output: 928
🏁 Script executed:
# Check RawDataset base class to understand the dataset structure
rg "class RawDataset" --type py -A 30 | head -60Repository: NVIDIA-NeMo/RL
Length of output: 2659
🏁 Script executed:
# Look for val_dataset property/method in dataset classes
rg "val_dataset" nemo_rl/data/datasets/response_datasets/ --type py -B 3 -A 3Repository: NVIDIA-NeMo/RL
Length of output: 3106
🏁 Script executed:
# Check how datasets are structured - look at ResponseDataset
sed -n '1,100p' nemo_rl/data/datasets/response_datasets/response_dataset.py
Repository: NVIDIA-NeMo/RL
Length of output: 3067
🏁 Script executed:
# Check if there's a split_validation_size or similar mechanism
rg "split_validation_size\|val_dataset" --type py | head -30Repository: NVIDIA-NeMo/RL
Length of output: 40
🏁 Script executed:
# Check if different dataset types have different format_data implementations
rg "def format_data" nemo_rl/data/datasets/response_datasets/ --type py -A 10 | head -80Repository: NVIDIA-NeMo/RL
Length of output: 7154
🏁 Script executed:
# Check the schema of datasets after format_data is applied
rg "\"messages\"" nemo_rl/data/datasets/response_datasets/ --type py -B 3 -A 3 | head -50Repository: NVIDIA-NeMo/RL
Length of output: 3988
🏁 Script executed:
# Look at how AllTaskProcessedDataset expects datasets to be formatted
rg "class AllTaskProcessedDataset" --type py -A 30Repository: NVIDIA-NeMo/RL
Length of output: 2844
🏁 Script executed:
# Check if there's any existing validation for dataset schema compatibility
rg "features" nemo_rl/data/datasets/ --type py | grep -i "concat\|merge\|validate" | head -20Repository: NVIDIA-NeMo/RL
Length of output: 40
🏁 Script executed:
# Look at test files to see if this scenario is tested
fd "test.*\.py" examples/ nemo_rl/ | xargs grep -l "concatenate_datasets\|val.*dataset" | head -10Repository: NVIDIA-NeMo/RL
Length of output: 40
🏁 Script executed:
# Check if Tulu3 dataset is commonly used together with other validation datasets
rg "tulu3\|Tulu3" examples/ --type py -B 2 -A 2Repository: NVIDIA-NeMo/RL
Length of output: 40
🏁 Script executed:
# Verify what the actual error would be from HuggingFace
python3 << 'EOF'
# Check HuggingFace datasets library behavior
try:
from datasets import Dataset, concatenate_datasets
# Simulate two datasets with different schemas
ds1 = Dataset.from_dict({"task_name": ["task1"], "messages": [{"role": "user", "content": "hi"}]})
ds2 = Dataset.from_dict({"task_name": ["task2"]})
print("Dataset 1 features:", ds1.features)
print("Dataset 2 features:", ds2.features)
print("\nAttempting to concatenate...")
try:
result = concatenate_datasets([ds1, ds2])
print("Success!")
except Exception as e:
print(f"Error type: {type(e).__name__}")
print(f"Error message: {e}")
except ImportError:
print("datasets library not available in this environment")
EOF
Repository: NVIDIA-NeMo/RL
Length of output: 108
🏁 Script executed:
# Look more carefully at dataset format_data to understand the schema issue
cat nemo_rl/data/datasets/response_datasets/tulu3.py | grep -A 30 "def format_data"
Repository: NVIDIA-NeMo/RL
Length of output: 442
🏁 Script executed:
# Check the actual column names/features after format_data processing
rg "add_column\|remove_columns\|dataset.map" nemo_rl/data/datasets/response_datasets/ --type py -B 2 -A 2 | head -60Repository: NVIDIA-NeMo/RL
Length of output: 40
🏁 Script executed:
# Verify what columns are in val_dataset when it comes from split_train_validation
sed -n '1,50p' nemo_rl/data/datasets/raw_dataset.py
Repository: NVIDIA-NeMo/RL
Length of output: 2276
🏁 Script executed:
# Check if there are any existing guards or validations for schema compatibility elsewhere
rg "\.features\|schema" examples/ --type pyRepository: NVIDIA-NeMo/RL
Length of output: 40
🏁 Script executed:
# Let's examine AllTaskProcessedDataset more closely to see how it handles the dataset
sed -n '1,150p' nemo_rl/data/datasets/processed_dataset.py
Repository: NVIDIA-NeMo/RL
Length of output: 5248
🏁 Script executed:
# Look at how the data is actually processed in AllTaskProcessedDataset.__getitem__
rg "__getitem__" nemo_rl/data/datasets/processed_dataset.py -A 30Repository: NVIDIA-NeMo/RL
Length of output: 1348
🏁 Script executed:
# Check if there's any handling for mismatched schemas in the data loading path
rg "train_test_split" --type py -B 3 -A 5Repository: NVIDIA-NeMo/RL
Length of output: 708
🏁 Script executed:
# Final check: see if this exact scenario (different dataset types in val_data_list) is actually possible
sed -n '86,126p' examples/run_sft.py | cat -n
Repository: NVIDIA-NeMo/RL
Length of output: 1927
🏁 Script executed:
# Let's verify Tulu3 dataset processing to confirm schema difference
cat nemo_rl/data/datasets/response_datasets/tulu3.py | sed -n '1,100p'
Repository: NVIDIA-NeMo/RL
Length of output: 2693
🏁 Script executed:
# Check how Tulu3 applies format_data
rg -A 40 "class Tulu3SftMixtureDataset" nemo_rl/data/datasets/response_datasets/tulu3.py | head -60Repository: NVIDIA-NeMo/RL
Length of output: 1579
🏁 Script executed:
# Verify if Tulu3 dataset goes through the same format_data mapping as ResponseDataset
grep -A 15 "def __init__" nemo_rl/data/datasets/response_datasets/tulu3.py | head -30Repository: NVIDIA-NeMo/RL
Length of output: 577
🏁 Script executed:
# Let's check what columns each dataset type actually has after format_data
# First, check the exact columns that Tulu3 keeps
cat nemo_rl/data/datasets/response_datasets/tulu3.py | sed -n '45,65p'
Repository: NVIDIA-NeMo/RL
Length of output: 852
🏁 Script executed:
# Verify what columns ResponseDataset has
cat nemo_rl/data/datasets/response_datasets/response_dataset.py | sed -n '50,70p'
Repository: NVIDIA-NeMo/RL
Length of output: 895
🏁 Script executed:
# Check if there are any integration tests that cover merging different dataset types
fd "test.*\.py" | xargs grep -l "val_data\|validation" | head -5Repository: NVIDIA-NeMo/RL
Length of output: 236
🏁 Script executed:
# Let's trace through one more dataset type to confirm the pattern
cat nemo_rl/data/datasets/response_datasets/squad.py | sed -n '1,80p'
Repository: NVIDIA-NeMo/RL
Length of output: 1864
🏁 Script executed:
# Let's verify the exact schema issue by looking at what columns are kept in each format_data
# Tulu3 - check what columns remain after map with format_data
cat nemo_rl/data/datasets/response_datasets/tulu3.py | grep -A 15 "def format_data"
Repository: NVIDIA-NeMo/RL
Length of output: 442
🏁 Script executed:
# ResponseDataset - what columns it returns
cat nemo_rl/data/datasets/response_datasets/response_dataset.py | grep -A 15 "def format_data"
Repository: NVIDIA-NeMo/RL
Length of output: 373
🏁 Script executed:
# Let's check the test file to see if this scenario is covered
cat tests/unit/algorithms/test_sft.py | head -100
Repository: NVIDIA-NeMo/RL
Length of output: 3063
🏁 Script executed:
# Check if there are any comments in the code about this schema incompatibility
rg "schema\|feature\|incompatible" nemo_rl/data/ examples/ --type pyRepository: NVIDIA-NeMo/RL
Length of output: 40
🏁 Script executed:
# Final verification: Check if there's documentation or issues about this in the repo
rg "concatenate_datasets" --type py -B 2 -A 5 | head -40Repository: NVIDIA-NeMo/RL
Length of output: 2479
🏁 Script executed:
# Check if there are any error handling or validation patterns we should be aware of
rg "ValueError\|schema\|features" nemo_rl/data/datasets/ --type py | grep -i concatRepository: NVIDIA-NeMo/RL
Length of output: 40
🏁 Script executed:
# Verify the exact line numbers in the file one more time to be precise
wc -l examples/run_sft.py && sed -n '110,120p' examples/run_sft.py
Repository: NVIDIA-NeMo/RL
Length of output: 389
Guard validation dataset concatenation against schema mismatches.
Different dataset types produce incompatible schemas after processing. For example, Tulu3SftMixtureDataset.format_data preserves the original "messages" column while ResponseDataset.format_data replaces all columns with newly formatted ones. If the training dataset validation split (from data.val_dataset) and the config-specified validation dataset have different types, concatenate_datasets will fail at runtime with a confusing error. Add an explicit check before concatenation to fail fast with a clear message.
Proposed implementation
if len(val_data_list) > 0:
+ if len(val_data_list) > 1:
+ base_features = val_data_list[0].features
+ for ds in val_data_list[1:]:
+ if ds.features != base_features:
+ raise ValueError(
+ "Validation datasets must share identical features to be concatenated. "
+ "Consider normalizing schemas or loading them separately."
+ )
merged_val_data = concatenate_datasets(val_data_list)🤖 Prompt for AI Agents
In `@examples/run_sft.py` around lines 86 - 126, Before calling
concatenate_datasets on val_data_list, validate that all datasets share a
compatible schema (e.g., same column names/types) and fail fast with a clear
error if they don't; implement this by inspecting each dataset's column
names/features (use the dataset.column_names or dataset.features/feature keys)
and comparing to the first entry in val_data_list, and if any mismatch is found
raise a ValueError that names the offending dataset types or their differing
columns (this check should be added in the block that builds merged_val_data
just before calling concatenate_datasets in run_sft.py, referencing
val_data_list and concatenate_datasets and mentioning the potential mismatch
between Tulu3SftMixtureDataset.format_data and ResponseDataset.format_data).
    print("\n▶ Setting up envs...")
    env_name_list = extract_necessary_env_names(data_config)
    envs = {
        env_name: create_env(env_name="vlm", env_config=env_configs[env_name])
        for env_name in env_name_list
    }
Bug: Hardcoded env_name="vlm" ignores the actual environment name.
The loop iterates over env_name_list, but the create_env call hardcodes env_name="vlm" instead of using the loop variable. This means all environments will be created with the "vlm" type regardless of what's in the config.
Compare with examples/run_grpo_rm.py line 86 which correctly uses env_name=env_name.
🐛 Proposed fix
envs = {
- env_name: create_env(env_name="vlm", env_config=env_configs[env_name])
+ env_name: create_env(env_name=env_name, env_config=env_configs[env_name])
for env_name in env_name_list
}🤖 Prompt for AI Agents
In `@examples/run_vlm_grpo.py` around lines 72 - 77, The env creation loop is
hardcoding env_name="vlm" so every entry in envs uses the wrong name; update the
dict comprehension that builds envs to pass the loop variable into create_env
(use env_name as the env_name argument) and ensure it uses env_configs[env_name]
for env_config so each environment is created with its correct name and config
(refer to env_name_list, create_env, envs, and env_configs).
def __init__(self, split: str = "train", **kwargs) -> None:
    self.task_name = "squad"

Don’t hard-code a default split in code; keep defaults in YAML.
Guidelines state configuration defaults should live in YAML. Make split required (or explicitly validate None) and rely on config to supply it.
✅ Proposed fix
- def __init__(self, split: str = "train", **kwargs) -> None:
+ def __init__(self, split: str | None = None, **kwargs) -> None:
+ if split is None:
+ raise ValueError("split must be provided via config")
     self.task_name = "squad"
🧰 Tools
🪛 Ruff (0.14.13)
29-29: Unused method argument: kwargs
(ARG002)
🤖 Prompt for AI Agents
In `@nemo_rl/data/datasets/response_datasets/squad.py` around lines 29 - 31, The
constructor in the SQuAD dataset class (__init__ in
nemo_rl/data/datasets/response_datasets/squad.py) currently hard-codes split:
str = "train"; remove that default so split is required (change signature to
split: str) or accept Optional[str] and explicitly validate for None and raise a
clear ValueError instructing callers to provide the split via config/YAML;
update any internal uses of self.split or callers to reflect the required
parameter and ensure tests/configs supply the split from YAML rather than
relying on a code default.
    split_validation_size: float = 0.05,
    seed: int = 42,
    test_size: float = 0.05,
    prompt_file: str | None = None,
    max_samples: int | None = None,
    **kwargs,
Move non-None defaults to YAML; don’t swallow kwargs.
split_validation_size and seed have non-None defaults in code, but defaults should live in YAML. Also, **kwargs is currently unused—either forward to RawDataset or remove it to avoid silently ignoring config.
🔧 Proposed fix
- split_validation_size: float = 0.05,
- seed: int = 42,
+ split_validation_size: float,
+ seed: int,
max_samples: int | None = None,
**kwargs,
) -> None:
+ super().__init__(**kwargs)
print(
"WARNING: For reproducible experiments, preprocess the dataset once and define your own HfDataset subclass that directly uses the preprocessed datasets."
     )
As per coding guidelines, keep configuration defaults in YAML and avoid unused kwargs.
🧰 Tools
🪛 Ruff (0.14.13)
36-36: Unused method argument: kwargs
(ARG002)
🤖 Prompt for AI Agents
In `@nemo_rl/data/datasets/response_datasets/tulu3.py` around lines 33 - 36,
Remove hard-coded non-None defaults for split_validation_size and seed in the
function/class signature (change split_validation_size and seed to accept None
by default) so defaults are provided via YAML, and stop silently ignoring
**kwargs by either forwarding them to the RawDataset constructor (e.g., include
**kwargs when calling RawDataset.__init__/super().__init__) or removing **kwargs
entirely; update the signature that currently contains split_validation_size,
seed, max_samples, **kwargs and ensure any call to RawDataset (or
RawDataset.__init__) passes **kwargs if you keep it.
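As a standalone sketch of the suggested shape, assuming a parent class that accepts keyword arguments. The class names below are illustrative, not the actual Tulu3 or RawDataset implementations:

```python
from typing import Any


class ExampleRawDataset:
    # Hypothetical parent; the real RawDataset may differ.
    def __init__(self, **kwargs: Any) -> None:
        self.extra = kwargs


class ExampleTulu3Dataset(ExampleRawDataset):
    def __init__(
        self,
        split_validation_size: float,  # no code default; value comes from YAML
        seed: int,                     # no code default; value comes from YAML
        max_samples: int | None = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)  # forward instead of silently dropping config keys
        self.split_validation_size = split_validation_size
        self.seed = seed
        self.max_samples = max_samples


ds = ExampleTulu3Dataset(split_validation_size=0.05, seed=42, prompt_file=None)
```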
def format_data(self, data: dict[str, Any]) -> dict[str, Any]:
    messages = data["messages"]

    # Ensure last message is from assistant
    if not messages or messages[-1]["role"] != "assistant":
        raise ValueError(
            f"Expected last message to be from assistant, got: {messages}"
        )
Avoid logging full message contents in exceptions.
The error message includes full message payloads, which risks leaking PII into logs. Prefer a redacted summary (count/roles only).
🔒 Proposed fix
- if not messages or messages[-1]["role"] != "assistant":
- raise ValueError(
- f"Expected last message to be from assistant, got: {messages}"
- )
+ if not messages or messages[-1].get("role") != "assistant":
+ last_role = messages[-1].get("role") if messages else None
+ raise ValueError(
+ "Expected last message to be from assistant; "
+ f"message_count={len(messages) if messages else 0}, "
+ f"last_role={last_role}"
+ )
🧰 Tools
🪛 Ruff (0.14.13)
68-70: Avoid specifying long messages outside the exception class
(TRY003)
🤖 Prompt for AI Agents
In `@nemo_rl/data/datasets/response_datasets/tulu3.py` around lines 63 - 70, The
ValueError in format_data currently includes full message payloads (messages)
and may leak PII; change the exception to include a redacted summary instead:
report messages count and the list of roles (e.g., [m["role"] for m in
messages]) or a truncated indicator, rather than embedding the full messages
dict; update the raise in format_data to construct and raise the ValueError with
that summary and keep the same check using the messages variable and assistant
role.
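A small self-contained sketch of the redacted error path described above (the helper name is invented for illustration, not code from the PR):

```python
def check_last_message_is_assistant(messages: list[dict[str, str]]) -> None:
    # Summarize instead of dumping full message contents into the exception.
    if not messages or messages[-1].get("role") != "assistant":
        last_role = messages[-1].get("role") if messages else None
        raise ValueError(
            f"Expected last message to be from assistant; "
            f"message_count={len(messages)}, last_role={last_role}"
        )


try:
    check_last_message_is_assistant([{"role": "user", "content": "secret text"}])
except ValueError as err:
    print(err)  # reports count and role only, never the message body
```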
if image_path_or_image.startswith(("http://", "https://")):
    # Handle URL
    response = requests.get(image_path_or_image)
    response.raise_for_status()
    return Image.open(BytesIO(response.content)).convert("RGB")
Add timeout to requests.get to prevent indefinite hangs.
The HTTP request lacks a timeout, which can cause the process to hang indefinitely if the remote server is unresponsive. This is a reliability concern, especially in data loading pipelines.
Proposed fix
if image_path_or_image.startswith(("http://", "https://")):
# Handle URL
- response = requests.get(image_path_or_image)
+ response = requests.get(image_path_or_image, timeout=30)
response.raise_for_status()
return Image.open(BytesIO(response.content)).convert("RGB")
🧰 Tools
🪛 Ruff (0.14.13)
201-201: Probable use of requests call without timeout
(S113)
🤖 Prompt for AI Agents
In `@nemo_rl/data/multimodal_utils.py` around lines 199 - 203, The requests.get
call that handles URL images (the branch checking
image_path_or_image.startswith(("http://", "https://"))) must include a timeout
to avoid indefinite hangs; update the call in multimodal_utils.py to pass a
timeout (e.g., timeout=10 or a configurable constant like REQUEST_TIMEOUT) to
requests.get(image_path_or_image, timeout=...), and propagate this constant or
parameter where appropriate so the image-loading function (the block using
image_path_or_image and requests.get) uses a bounded timeout for network calls.
def create_sample_data(input_key, output_key, is_save_to_disk=False):
    data = [
        {input_key: "Hello", output_key: "Hi there!"},
        {input_key: "How are you?", output_key: "I'm good, thanks!"},
    ]
    val_data = [
        {input_key: "What's up?", output_key: "Not much!"},
        {input_key: "Bye", output_key: "Goodbye!"},
    ]

    # Create temporary files for train and validation data
    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".json", delete=False
    ) as train_file:
        json.dump(train_data, train_file)
        train_path = train_file.name
    # Create temporary dataset file
    if is_save_to_disk:
        data_path = tempfile.mktemp()
        dataset = Dataset.from_list(data)
        dataset.save_to_disk(data_path)
    else:
        with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
            json.dump(data, f)
        data_path = f.name

    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".json", delete=False
    ) as val_file:
        json.dump(val_data, val_file)
        val_path = val_file.name
    return data_path
🧩 Analysis chain
🏁 Script executed:
cd tests/unit/data/datasets && ls -la
Repository: NVIDIA-NeMo/RL
Length of output: 548
🏁 Script executed:
cat -n tests/unit/data/datasets/test_response_dataset.py | head -50
Repository: NVIDIA-NeMo/RL
Length of output: 2135
🏁 Script executed:
python3 -c "import tempfile; import warnings; warnings.simplefilter('always'); tempfile.mktemp()" 2>&1 || trueRepository: NVIDIA-NeMo/RL
Length of output: 40
🏁 Script executed:
python3 << 'EOF'
import tempfile
import inspect
# Check mktemp documentation
print("mktemp docstring:")
print(tempfile.mktemp.__doc__)
print("\n\nmkdtemp docstring:")
print(tempfile.mkdtemp.__doc__)
EOF
Repository: NVIDIA-NeMo/RL
Length of output: 971
Replace tempfile.mktemp() with tempfile.mkdtemp().
mktemp() is unsafe and deprecated. It only returns a filename without creating the file or directory, creating a race condition vulnerability. Use mkdtemp() to create an actual temporary directory securely.
Suggested fix
- data_path = tempfile.mktemp()
+ data_path = tempfile.mkdtemp()
🧰 Tools
🪛 ast-grep (0.40.5)
[warning] 34-34: The function mktemp is deprecated. When using this function, it is possible for an attacker to modify the created file before the filename is returned. Use NamedTemporaryFile() instead and pass it the delete=False parameter.
Context: tempfile.mktemp()
Note: [CWE-377]: Insecure Temporary File [OWASP A01:2021]: Broken Access Control [REFERENCES]
https://docs.python.org/3/library/tempfile.html#tempfile.mktemp
https://owasp.org/Top10/A01_2021-Broken_Access_Control
(avoid-mktemp-python)
🪛 Ruff (0.14.13)
35-35: Use of insecure and deprecated function (mktemp)
(S306)
🤖 Prompt for AI Agents
In `@tests/unit/data/datasets/test_response_dataset.py` around lines 27 - 43, In
create_sample_data, replace the unsafe tempfile.mktemp() call used to set
data_path with tempfile.mkdtemp() so a real temporary directory is created
before calling Dataset.from_list(...).save_to_disk(data_path); ensure data_path
(the value returned by tempfile.mkdtemp) is passed directly to
dataset.save_to_disk and remove any reliance on an unused filename; keep the
rest of the function (the NamedTemporaryFile branch and returned data_path)
unchanged.
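A sketch of the safer pattern, including cleanup, assuming the Hugging Face datasets API used in the test; the helper name is illustrative, not the test's actual code:

```python
import shutil
import tempfile

from datasets import Dataset


def save_sample_dataset_to_disk(rows: list[dict[str, str]]) -> str:
    # mkdtemp creates the directory up front, so there is no filename race.
    data_path = tempfile.mkdtemp()
    Dataset.from_list(rows).save_to_disk(data_path)
    return data_path


path = save_sample_dataset_to_disk([{"input": "Hello", "output": "Hi there!"}])
# ... load the dataset back and run assertions here ...
shutil.rmtree(path)  # clean up the temporary directory afterwards
```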
Closes #1050.
Summary by CodeRabbit
New Features
Documentation
Breaking Changes
Configuration Updates