feat: refactor common data utilities of dtensor policy v2 #1710
base: main
Conversation
Force-pushed: 4f66b8f → 81174e5, 0dfbb1e → 2959be7, 2959be7 → be28e23
Signed-off-by: Hemil Desai <hemild@nvidia.com>
📝 Walkthrough
Introduces a new data processing module for automodel with dataclasses (ProcessedInputs, ProcessedMicrobatch).
Changes
Sequence Diagram(s)
```mermaid
sequenceDiagram
    participant Iterator as Raw Iterator
    participant MBIter as Microbatch Iterator
    participant Proc as process_microbatch()
    participant Tokenizer as Tokenizer
    participant DPMesh as DP Mesh (all_reduce)
    participant Inputs as ProcessedInputs
    Iterator->>MBIter: raw BatchedDataDict
    loop per item in iterator
        MBIter->>Proc: call with BatchedDataDict
        Proc->>Tokenizer: tokenize/prepare
        alt Sequence Packing Enabled
            Proc->>Proc: pack_sequences()
            Proc->>Proc: get_flash_attention_kwargs()
        end
        alt Context Parallel (cp_size > 1)
            Proc->>Proc: construct cp_buffers<br/>and seq_index
        end
        Proc->>Inputs: create ProcessedInputs
        MBIter->>MBIter: wrap in ProcessedMicrobatch
        MBIter->>Iterator: yield ProcessedMicrobatch
    end
    Iterator->>DPMesh: process_global_batch()<br/>extracts batch
    DPMesh->>DPMesh: all_reduce(global_valid_seqs,<br/>global_valid_toks)
```
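For orientation, here is a minimal, illustrative sketch of the shapes the walkthrough describes. The dataclass names follow ProcessedInputs/ProcessedMicrobatch as they appear in this review, but the field defaults, the `wrap_microbatches` helper, and its signature are assumptions, not the module's actual definitions.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Iterator, Optional

import torch


@dataclass
class ProcessedInputs:
    # Tensors prepared per microbatch; optional fields cover the
    # sequence-packing and context-parallel branches shown in the diagram.
    input_ids: torch.Tensor
    position_ids: torch.Tensor
    attention_mask: Optional[torch.Tensor] = None
    flash_attn_kwargs: Any = None
    vlm_kwargs: dict[str, Any] = field(default_factory=dict)
    cp_buffers: list[torch.Tensor] = field(default_factory=list)
    seq_index: Optional[torch.Tensor] = None
    seq_len: int = 0


@dataclass
class ProcessedMicrobatch:
    data_dict: Any                      # the raw BatchedDataDict for this microbatch
    processed_inputs: ProcessedInputs   # model-ready tensors


def wrap_microbatches(
    raw_iterator: Iterator[Any],
    process_fn: Callable[[Any], ProcessedInputs],
) -> Iterator[ProcessedMicrobatch]:
    """Hypothetical wrapper: pair each raw microbatch with its processed inputs."""
    for mb in raw_iterator:
        yield ProcessedMicrobatch(data_dict=mb, processed_inputs=process_fn(mb))
```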
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Possibly related PRs
Suggested labels
Suggested reviewers
🚥 Pre-merge checks: ✅ 2 | ❌ 2
❌ Failed checks (2 warnings)
✅ Passed checks (2 passed)
✨ Finishing touches
Actionable comments posted: 5
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py (1)
492-503: Avoid leaking metrics from dummy microbatches.
`num_valid_samples` is only set for real microbatches; dummy iterations can reuse the prior value and append duplicate metrics. Explicitly zero it (or guard the append) for dummy batches.
🐛 Suggested fix
```diff
-            else:
-                loss *= 0
+            else:
+                loss *= 0
+                num_valid_samples = 0
```
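A compact sketch of the guarded loop the comment describes; `compute_loss` and the helper itself are hypothetical stand-ins for the worker's actual train step, shown only to make the zero-on-dummy pattern explicit.

```python
def accumulate_metrics(processed_iterator, iterator_len, compute_loss):
    """Zero num_valid_samples on dummy iterations so stale values never leak into metrics."""
    metrics = []
    for mb_idx, mb in enumerate(processed_iterator):
        loss, num_valid_samples = compute_loss(mb)
        if mb_idx >= iterator_len:  # dummy microbatch padded in for collectives
            loss = loss * 0
            num_valid_samples = 0
        if num_valid_samples > 0:
            metrics.append(
                {"loss": float(loss), "num_valid_samples": num_valid_samples}
            )
    return metrics
```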
🧹 Nitpick comments (1)
nemo_rl/models/automodel/data.py (1)
217-220: Track the sequence‑packing workaround.
The TODO indicates a known workaround for min_seq_len; consider linking it to an issue or follow-up ticket so it doesn't linger. If you want, I can help draft a follow-up issue or propose a replacement implementation.
```python
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
```
Update the NVIDIA copyright year to 2026.
This is a new non-test file, so the header should reflect the current year.
🔧 Suggested fix
```diff
-# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
```
As per coding guidelines, the NVIDIA header should include the current year.
🤖 Prompt for AI Agents
In `@nemo_rl/models/automodel/data.py` around lines 1 - 13, Update the file header
year from 2025 to 2026 in the NVIDIA copyright block at the top of
nemo_rl/models/automodel/data.py; locate the existing copyright comment starting
with "# Copyright (c) 2025, NVIDIA CORPORATION." and change "2025" to "2026" so
the header reflects the current year.
```python
        lp_batch = processed_mb.data_dict
        processed_inputs = processed_mb.processed_inputs

        # Extract values from processed inputs
        input_ids = processed_inputs.input_ids
        attention_mask = processed_inputs.attention_mask
        position_ids = processed_inputs.position_ids
        flash_attn_kwargs = processed_inputs.flash_attn_kwargs
        vlm_kwargs = processed_inputs.vlm_kwargs
        cp_buffers = processed_inputs.cp_buffers
        seq_index = processed_inputs.seq_index
        seq_len = processed_inputs.seq_len
```
Use original batch size when unpacking packed logprobs.
With packing, input_ids.shape[0] is 1, so unpacked_logprobs only allocates a single row and drops the rest. Use the original batch size for the unpacking loop/shape.
🐛 Suggested fix
```diff
-        processed_inputs = processed_mb.processed_inputs
+        processed_inputs = processed_mb.processed_inputs
+        original_batch_size = processed_mb.original_batch_size
         ...
-        unpacked_logprobs = torch.zeros(
-            (batch_size, seq_dim_size),
+        unpacked_logprobs = torch.zeros(
+            (original_batch_size, seq_dim_size),
             dtype=token_logprobs.dtype,
             device=token_logprobs.device,
         )
         cu_seqlens = flash_attn_kwargs.cu_seqlens_q
-        for i in range(batch_size):
+        for i in range(original_batch_size):
             start = cu_seqlens[i].item() + 1
             end = cu_seqlens[i + 1].item()
             seq_len_actual = input_lengths[i].item()
```
Also applies to: 799-817
🤖 Prompt for AI Agents
In `@nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py` around lines 639 -
651, The unpacking of packed logprobs uses input_ids.shape[0] (which is 1 when
packed) and thus allocates/iterates the wrong batch size; instead read the
original batch size from the batch metadata stored in lp_batch (e.g.,
lp_batch.original_batch_size or lp_batch['original_batch_size']) and use that
value when allocating unpacked_logprobs and for the unpacking loop/range in the
unpacking logic (the code that creates unpacked_logprobs and iterates over batch
indices). Apply the same change to the other unpacking site referenced (the
block around the unpacking logic in the later section).
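The same pattern in a self-contained form. The variable names follow the diff above, but `unpack_packed_logprobs` itself is an illustrative helper, not the worker's actual code.

```python
import torch


def unpack_packed_logprobs(
    token_logprobs: torch.Tensor,  # [1, total_packed_len] logprobs of the packed batch
    cu_seqlens: torch.Tensor,      # cumulative sequence lengths, [original_batch_size + 1]
    input_lengths: torch.Tensor,   # per-sequence lengths, [original_batch_size]
    original_batch_size: int,
    seq_dim_size: int,
) -> torch.Tensor:
    """Scatter packed per-token logprobs back into a padded [B, S] tensor."""
    unpacked = torch.zeros(
        (original_batch_size, seq_dim_size),
        dtype=token_logprobs.dtype,
        device=token_logprobs.device,
    )
    for i in range(original_batch_size):
        start = cu_seqlens[i].item() + 1  # position 0 of each sequence has no logprob
        end = cu_seqlens[i + 1].item()
        seq_len_actual = input_lengths[i].item()
        unpacked[i, 1:seq_len_actual] = token_logprobs[0, start:end]
    return unpacked


# Example: two packed sequences of lengths 3 and 2, padded output width 4
lp = torch.log(torch.rand(1, 5))
out = unpack_packed_logprobs(lp, torch.tensor([0, 3, 5]), torch.tensor([3, 2]), 2, 4)
```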
```diff
         step = 0
         all_rm_scores = []
-        for batch_idx, generate_batch in enumerate(
-            itertools.chain(mb_iterator, dummy_iterator)
-        ):
+        for batch_idx, processed_mb in enumerate(processed_iterator):
             step += 1
-            input_ids = generate_batch.get("input_ids").cuda()
-            input_lengths = generate_batch.get("input_lengths")
-            batch_size, seq_len = input_ids.shape
-            if self.enable_seq_packing:
-                input_ids, position_ids, _ = pack_sequences(
-                    input_ids=input_ids,
-                    input_lengths=input_lengths,
-                    packed_sequence_size=[
-                        batch_size
-                    ],  # flash attention 2 expects flattened input
-                    padding_value=self.tokenizer.eos_token_id,
-                    return_attention_mask=False,
-                )
-                seq_len = input_ids.shape[1]
-                attention_mask = None
-                flash_attn_kwargs = get_flash_attention_kwargs(
-                    input_lengths=input_lengths,
-                )
-            else:
-                # Create attention mask for right-padded data
-                post_attention_mask = torch.zeros(
-                    (batch_size, seq_len), dtype=torch.bool, device=input_ids.device
-                )
-                for i, length in enumerate(input_lengths):
-                    # For right-padded sequence, set 1s at the beginning of the sequence
-                    post_attention_mask[i, :length] = 1
-                position_ids = torch.arange(
-                    seq_len, device=input_ids.device
-                ).repeat(batch_size, 1)
-
-                attention_mask = torch.ones(
-                    (batch_size, seq_len),
-                    dtype=torch.bool,
-                    device=input_ids.device,
-                )
-            if self.cp_size > 1:
-                seq_index = torch.arange(seq_len, device=input_ids.device).repeat(
-                    1, 1
-                )
-                cp_buffers = [input_ids, position_ids, seq_index]
-            else:
-                cp_buffers = []
-                seq_index = None
+            # Extract processed inputs
```
Skip dummy microbatches in score().
When sequence packing yields uneven batch counts, dummy batches are appended and currently produce extra scores. Break/continue once batch_idx >= iterator_len.
🐛 Suggested fix
```diff
-        for batch_idx, processed_mb in enumerate(processed_iterator):
+        for batch_idx, processed_mb in enumerate(processed_iterator):
+            if batch_idx >= iterator_len:
+                break
```
🧰 Tools
🪛 Ruff (0.14.13)
868-868: Loop control variable batch_idx not used within loop body
Rename unused batch_idx to _batch_idx
(B007)
🤖 Prompt for AI Agents
In `@nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py` around lines 866 -
870, In the score() loop inside dtensor_policy_worker_v2.py, skip any dummy
microbatches appended by sequence packing by checking the batch index against
iterator_len; inside the for batch_idx, processed_mb in
enumerate(processed_iterator) loop (which updates step and collects
all_rm_scores), add a guard like if batch_idx >= iterator_len: break or continue
so you do not process or append scores for padded/dummy batches (use the
existing batch_idx, iterator_len, processed_mb, all_rm_scores, step symbols to
locate the change).
```python
        for batch_idx, processed_mb in enumerate(processed_iterator):
            # Extract data dict and processed inputs
            lp_batch = processed_mb.data_dict
            processed_inputs = processed_mb.processed_inputs
            input_lengths = lp_batch.get("input_lengths")
```
Skip dummy microbatches in get_topk_logits().
Dummy batches will otherwise add extra outputs. Break/continue once batch_idx >= iterator_len.
🐛 Suggested fix
```diff
-        for batch_idx, processed_mb in enumerate(processed_iterator):
+        for batch_idx, processed_mb in enumerate(processed_iterator):
+            if batch_idx >= iterator_len:
+                break
```
🧰 Tools
🪛 Ruff (0.14.13)
969-969: Loop control variable batch_idx not used within loop body
Rename unused batch_idx to _batch_idx
(B007)
🤖 Prompt for AI Agents
In `@nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py` around lines 969 -
973, In get_topk_logits, skip dummy microbatches by checking batch_idx against
iterator_len inside the loop over processed_iterator—after obtaining batch_idx
and processed_mb (and before using lp_batch/processed_inputs/input_lengths), add
a guard like if batch_idx >= iterator_len: break (or continue if appropriate) so
dummy microbatches do not produce extra outputs; update any related logic in
get_topk_logits to rely on iterator_len instead of processing all entries from
processed_iterator.
```python
        mb_iterator, iterator_len, dummy_iterator = get_microbatch_iterator(
            data=data,
            cfg=cfg,
            enable_seq_packing=enable_seq_packing,
            mbs=mbs,
            dp_mesh=mock_dp_mesh,
        )

        # Verify iterator length
        assert iterator_len == 4  # 16 / 4 = 4

        # Verify we can iterate through the data
        batches = list(mb_iterator)
        assert len(batches) == 4
        assert batches[0]["input_ids"].shape[0] == 4

        # Verify dummy iterator is empty
        assert list(dummy_iterator) == []

    def test_dynamic_batching(self):
        # Create test data
        data = BatchedDataDict(
            {
                "input_ids": torch.randint(0, 1000, (8, 128)),
                "sample_mask": torch.ones(8, dtype=torch.bool),
            }
        )

        # Mock the microbatch iterator methods
        mock_iterator = iter([data, data, data])
        data.make_microbatch_iterator_with_dynamic_shapes = MagicMock(
            return_value=mock_iterator
        )
        data.get_microbatch_iterator_dynamic_shapes_len = MagicMock(return_value=3)

        cfg = {"dynamic_batching": {"enabled": True}}
        mbs = 4
        enable_seq_packing = False
        mock_dp_mesh = MagicMock()

        mb_iterator, iterator_len, dummy_iterator = get_microbatch_iterator(
            data=data,
            cfg=cfg,
            enable_seq_packing=enable_seq_packing,
            mbs=mbs,
            dp_mesh=mock_dp_mesh,
        )

        # Verify dynamic batching was used
        assert iterator_len == 3
        data.make_microbatch_iterator_with_dynamic_shapes.assert_called_once()
        data.get_microbatch_iterator_dynamic_shapes_len.assert_called_once()

        # Verify dummy iterator is empty
        assert list(dummy_iterator) == []
```
🧩 Analysis chain
🏁 Scripts executed against NVIDIA-NeMo/RL:
```shell
# First, let's explore the repository structure and locate the file
find . -name "test_automodel_data.py" -type f

# Check if Ruff is available and what version
ruff --version

# Let's read the test file to understand the context
head -200 tests/unit/models/automodel/test_automodel_data.py | cat -n

# Read lines 150-160 to see the context around lines 152-155
sed -n '145,165p' tests/unit/models/automodel/test_automodel_data.py | cat -n

# Read around lines 448-478
sed -n '440,485p' tests/unit/models/automodel/test_automodel_data.py | cat -n

# Read around lines 967-969
sed -n '960,975p' tests/unit/models/automodel/test_automodel_data.py | cat -n

# Now run Ruff on the file to see the actual warnings
ruff check tests/unit/models/automodel/test_automodel_data.py --select=ARG001,ARG002,RUF059

# Verify the line 967-969 reference mentioned in the review
sed -n '960,975p' tests/unit/models/automodel/test_automodel_data.py | cat -n

# Confirm the exact line numbers for lines 70 and 110 (test_regular_batching vs test_dynamic_batching)
sed -n '56,90p' tests/unit/models/automodel/test_automodel_data.py | cat -n

# Get a better view of the exact errors and their line numbers for comprehensive coverage
ruff check tests/unit/models/automodel/test_automodel_data.py --select=ARG001,ARG002,RUF059 -v 2>&1 | grep -E "^|help:" | head -50

# Let's get the total count and grouping of errors
ruff check tests/unit/models/automodel/test_automodel_data.py --select=ARG001,ARG002,RUF059 2>&1 | tail -1
```
Fix Ruff linter warnings for unused variables and parameters.
ARG001/ARG002/RUF059 will trigger on unused fixtures/locals (e.g., mb_iterator, mock_tokenizer, *args/**kwargs). Prefix with _ or add # noqa where intentional.
♻️ Example cleanup (apply similarly elsewhere)
```diff
-        mb_iterator, iterator_len, dummy_iterator = get_microbatch_iterator(
+        _mb_iterator, iterator_len, dummy_iterator = get_microbatch_iterator(
             data=data,
             cfg=cfg,
             enable_seq_packing=enable_seq_packing,
             mbs=mbs,
             dp_mesh=mock_dp_mesh,
         )
         ...
-        def side_effect(tensor, *args, **kwargs):
+        def side_effect(tensor, *_args, **_kwargs):
             tensor[0] = 2  # Simulate max batch count
         ...
-    def test_processed_microbatch_creation(self, mock_tokenizer):
+    def test_processed_microbatch_creation(self, _mock_tokenizer):
         """Test that ProcessedMicrobatch correctly stores all attributes."""
```
Also applies to: 152-155, 448-478, 858, 742, 796, 841, 913, 985.
Please re-run Ruff after cleanup to confirm the warnings are resolved.
🧰 Tools
🪛 Ruff (0.14.13)
110-110: Unpacked variable mb_iterator is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
🤖 Prompt for AI Agents
In `@tests/unit/models/automodel/test_automodel_data.py` around lines 70 - 125,
Tests contain unused variables/fixtures (e.g., mb_iterator, dummy_iterator,
mock_tokenizer, *args/**kwargs) that trigger Ruff ARG001/ARG002/RUF059; update
the tests (including functions like test_dynamic_batching and usages around
get_microbatch_iterator, BatchedDataDict,
make_microbatch_iterator_with_dynamic_shapes,
get_microbatch_iterator_dynamic_shapes_len) by renaming intentionally unused
locals/fixtures with a leading underscore (e.g., _mb_iterator, _dummy_iterator,
_mock_tokenizer or _args/_kwargs) or, where truly intentional, annotate the
binding with a trailing "# noqa" to suppress the linter; then re-run Ruff to
confirm all reported warnings (including the other listed test regions) are
resolved.
```python
)


def get_microbatch_iterator(
```
In the mcore data refactor PR, get_microbatch_iterator calls make_processed_microbatch_iterator directly, so the processed iterator is returned from this fn. Is that something we want to do here too? I made that decision because there's no instance where we'd want to call get_microbatch_iterator without make_processed_microbatch_iterator
Ok yeah I can make that change here too.
Done.
terrykong
left a comment
looks good to me, small comments. like the first PR can you run the dtensor v2 nightlies to make sure this change hasn't broken anything
```python
        # For right-padded sequence, set 1s at the beginning of the sequence
        post_attention_mask[i, :length] = 1

    # explicitly create position ids for the input, otherwise the sharding
```
this comment was dropped i think, could we add it back if needed?
another thing that i was reminded of @akoumpa mentioned something about position_ids being an issue if we're using sequence parallel. something about the sharding being incorrect but I don't recall the exact context and whether the guidance was to pass it or not. @hemildesai @akoumpa any idea on this?
```python
        ).repeat(batch_size, 1)
        flash_attn_kwargs = {}

    # DTensor requires the casual attention kernel to hit,
```
another comment worth keeping?
```python
            device=input_ids.device,
        )

    # if there are multimodal kwargs, we don't need to add position_ids (computed internally)
```
comment worth keeping?
```python
def make_processed_microbatch_iterator(
    raw_iterator: Iterator[BatchedDataDict[Any]],
    tokenizer: AutoTokenizer,
    enable_seq_packing: bool,
```
to mirror mcore, wdyt about just inferring this from the config like the mcore PR is?
cc @ashors1
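For reference, "inferring this from the config" could look roughly like the sketch below; the `sequence_packing.enabled` key is an assumption about the config schema, not a confirmed part of it.

```python
from typing import Any


def infer_seq_packing_enabled(cfg: dict[str, Any]) -> bool:
    # Hypothetical: read the flag from the policy config instead of threading
    # enable_seq_packing through the iterator helper's signature.
    return bool(cfg.get("sequence_packing", {}).get("enabled", False))
```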
Depends on #1709
Issues
#1589
Usage
```python
# Add a code snippet demonstrating how to use this
```
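A hedged sketch of how the refactored utilities might be exercised end to end; the attribute names come from this review's diffs, but the helper below and the exact return values of get_microbatch_iterator are assumptions.

```python
from typing import Any, Callable, Iterator


def forward_over_microbatches(
    processed_iterator: Iterator[Any],
    iterator_len: int,
    forward_fn: Callable[..., Any],
) -> list[Any]:
    """Run a forward pass per microbatch, skipping dummy batches past iterator_len."""
    outputs = []
    for batch_idx, processed_mb in enumerate(processed_iterator):
        if batch_idx >= iterator_len:
            break  # dummy microbatches appended for sequence packing
        inputs = processed_mb.processed_inputs
        outputs.append(
            forward_fn(
                input_ids=inputs.input_ids,
                attention_mask=inputs.attention_mask,
                position_ids=inputs.position_ids,
            )
        )
    return outputs
```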
Before your PR is "Ready for review"
Pre checks:
Additional Information
Summary by CodeRabbit
New Features
Tests
Chores