Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions fast_llm/data/document/language_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ class LanguageModelTargetInput(ModelInput):
class LanguageModelInput(TokenModelInput):
targets: list[LanguageModelTargetInput] = dataclasses.field(default_factory=list)
image_patches: PatchModelInput | None = None
# Number of documents with at least one response token in this micro-batch.
# Computed before any TP/SP/PP splitting; used to normalize per-document metrics.
num_docs: int | None = None

def set_children_attributes(self) -> None:
if self.image_patches is not None:
Expand All @@ -58,6 +61,7 @@ def to_kwargs(self) -> dict[str, typing.Any]:
LanguageModelKwargs.advantages: [target.advantages for target in self.targets],
LanguageModelKwargs.old_log_probabilities: [target.old_log_probabilities for target in self.targets],
LanguageModelKwargs.label_counts: [target.label_counts for target in self.targets],
LanguageModelKwargs.num_docs: self.num_docs,
}
if self.image_patches is not None:
out.update(self.image_patches.to_kwargs())
Expand Down Expand Up @@ -121,6 +125,12 @@ def get_model_inputs(self, config: LanguageModelBatchPreprocessingConfig) -> lis

model_inputs.append(model_input)

# num_docs is counted only on the first split (micro_sequence_index==0) to avoid
# double-counting documents that span a split boundary when micro_batch_splits > 1.
# The first split already has the correct count (or 0 for SDP rank > 0); clear the rest.
for model_input in model_inputs[1:]:
model_input.num_docs = None

return model_inputs

def _get_model_input(
Expand Down Expand Up @@ -161,6 +171,17 @@ def _get_model_input(
length_cumsum = torch.tensor([0] + cropped_lengths, device=self.device).cumsum(0)
label_count_cumsum = mask_cumsum[length_cumsum]
labels_per_document = label_count_cumsum[1:] - label_count_cumsum[:-1]
# Track documents with at least one response token for per-document metric
# normalization. Only counted on sequence_data_rank==0 so that when the runner
# all_reduces the denominator across the data group (which includes SDP ranks),
# each document is counted exactly once even though all SDP ranks see the same
# documents (but process different token slices of them).
if model_input.num_docs is None:
model_input.num_docs = (
int((labels_per_document > 0).sum().item())
if config.distributed.sequence_data_rank == 0
else 0
)
# Expand to one entry per token: find each token's document index via the sorted
# length cumsum, then look up that document's label count.
# TODO: Document index already computed in `LengthModelInputPreprocessor`.
Expand Down
5 changes: 4 additions & 1 deletion fast_llm/engine/base_model/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,9 @@ class LossDef:
name: str
formatted_name: str
# The number of times this loss is evaluated by the model for each micro-batch. Used as a denominator for averaging.
# TODO: Allow variable count? Would need a reduction across PP devices.
count: int = 1
dtype: DataType = DataType.float32
# If set, normalize this metric by summing values from context.batch[i][denominator_batch_field]
# across micro-batches and DP ranks, instead of using count * data_parallel * num_inputs.
# The field must be a scalar (int or float) pre-computed before any TP/SP/PP splitting.
denominator_batch_field: str | None = None
41 changes: 30 additions & 11 deletions fast_llm/engine/schedule/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,20 +287,39 @@ def run_step(
def _reduce_losses(self, context: BatchContext) -> dict[str, float | int]:
reduced_losses = {}
for name, losses in context.losses.items():
loss_def = self._loss_definitions[name]
if losses or self._distributed.pipeline_group:
if losses:
loss_count = (
self._loss_definitions[name].count
* self._distributed_config.data_parallel
* context.schedule.config.num_inputs
)
reduced_loss = torch.stack(losses).sum() / loss_count
if self._distributed.data_group:
all_reduce(reduced_loss, group=self._distributed.data_group)
denom_field = loss_def.denominator_batch_field
if denom_field is not None:
# Normalize by a per-micro-batch scalar from the batch data (e.g. num_docs),
# computed before any TP/SP/PP splitting. Sum numerator and denominator
# independently across DP ranks so the result is a true global average.
numerator = torch.stack(losses).sum()
denominator = torch.tensor(
sum(
batch_kwargs[denom_field] or 0
for batch_kwargs in context.batch.values()
if denom_field in batch_kwargs
),
dtype=numerator.dtype,
device=numerator.device,
)
if self._distributed.data_group:
all_reduce(numerator, group=self._distributed.data_group)
all_reduce(denominator, group=self._distributed.data_group)
reduced_loss = numerator / denominator.clamp(min=1)
else:
loss_count = (
loss_def.count
* self._distributed_config.data_parallel
* context.schedule.config.num_inputs
)
reduced_loss = torch.stack(losses).sum() / loss_count
if self._distributed.data_group:
all_reduce(reduced_loss, group=self._distributed.data_group)
else:
reduced_loss = torch.zeros(
[1], dtype=self._loss_definitions[name].dtype.torch, device=self._distributed.device
)
reduced_loss = torch.zeros([1], dtype=loss_def.dtype.torch, device=self._distributed.device)
if self._distributed.pipeline_group:
all_reduce(reduced_loss, group=self._distributed.pipeline_group)
else:
Expand Down
1 change: 1 addition & 0 deletions fast_llm/layers/language_model/loss/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class LanguageModelLossKwargs(BlockKwargs):
advantages = "advantages"
old_log_probabilities = "old_log_probabilities"
label_counts = "num_labels_in_seq"
num_docs = "num_docs"


@config_class(registry=True)
Expand Down
6 changes: 5 additions & 1 deletion fast_llm/layers/language_model/loss/grpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,12 @@ def get_loss_definitions(self, count: int = 1) -> list[LossDef]:
LossDef(
self._logprob_metric_name,
formatted_name=self._logprob_metric_name,
count=1, # This is an additive metric over the sequence.
count=1,
dtype=DataType.float32,
# Normalize by the number of documents with response tokens in each micro-batch,
# giving a true per-document average regardless of variable document lengths.
# num_docs is computed before any TP/SP/PP splitting in language_model.py.
denominator_batch_field=LanguageModelLossKwargs.num_docs,
)
]

Expand Down
162 changes: 162 additions & 0 deletions tests/data/test_preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from fast_llm.data.document.config import LanguageModelBatchPreprocessingConfig
from fast_llm.data.document.language_model import LanguageModelBatch, LanguageModelDocument
from fast_llm.data.document.range import RangeDocument
from fast_llm.data.document.token_data import TokenDataDocument
from fast_llm.engine.distributed.config import DistributedConfig
from fast_llm.utils import Assert


Expand Down Expand Up @@ -53,3 +55,163 @@ def test_preprocessing(tokens, loss_masking_spans):

Assert.eq(len(model_input.targets), 1)
Assert.all_equal(model_input.targets[0].tokens, torch.cat(label_tokens)[1:])


def _make_grpo_document(tokens, loss_masking_spans=None):
    """Build a LanguageModelDocument populated with the GRPO-specific fields.

    The advantages and old log-probabilities are zero-filled placeholders with
    one entry per token; *loss_masking_spans*, when given, is wrapped in a
    RangeDocument.
    """
    token_tensor = torch.tensor(tokens, dtype=torch.int64)
    num_tokens = len(token_tensor)
    spans = RangeDocument(ranges=loss_masking_spans) if loss_masking_spans is not None else None
    return LanguageModelDocument(
        tokens=token_tensor,
        loss_masking_spans=spans,
        advantages=TokenDataDocument(data=torch.zeros(num_tokens)),
        old_log_probabilities=TokenDataDocument(data=torch.zeros(num_tokens)),
    )


@pytest.mark.parametrize(
    ("token_lists", "loss_masking_spans_list", "expected_num_docs"),
    (
        # One document, no masking: it has response tokens, so it counts.
        ([[1, 2, 3, 4, 5]], [None], 1),
        # One document entirely covered by a masking span: zero response tokens.
        ([[1, 2, 3, 4, 5]], [[(0, 5)]], 0),
        # Two unmasked documents: both count.
        ([[1, 2, 3], [4, 5, 6]], [None, None], 2),
        # One of two documents fully masked: only the other counts.
        ([[1, 2, 3], [4, 5, 6]], [[(0, 3)], None], 1),
        # Both documents fully masked: nothing counts.
        ([[1, 2, 3], [4, 5, 6]], [[(0, 3)], [(0, 3)]], 0),
        # Short single document (padding scenario is covered by a dedicated test below).
        ([[1, 2, 3]], [None], 1),
    ),
)
def test_num_docs_computation(token_lists, loss_masking_spans_list, expected_num_docs):
    """num_docs counts only documents that have at least one non-masked response token."""
    docs = [
        _make_grpo_document(tokens, spans)
        for tokens, spans in zip(token_lists, loss_masking_spans_list)
    ]
    preprocessing = LanguageModelBatchPreprocessingConfig(use_grpo_data=True, return_label_counts=True)
    (model_input,) = LanguageModelBatch.from_documents(docs).get_model_inputs(preprocessing)
    Assert.eq(model_input.num_docs, expected_num_docs)


def test_num_docs_excludes_padding():
    """A padding segment added by pad_to_size carries no labels and must not inflate num_docs."""
    # Requesting more tokens than the single 4-token document forces a padding segment.
    batch = LanguageModelBatch.from_documents([_make_grpo_document([1, 2, 3, 4])], pad_to_size=10)
    preprocessing = LanguageModelBatchPreprocessingConfig(use_grpo_data=True, return_label_counts=True)
    (model_input,) = batch.get_model_inputs(preprocessing)
    # Only the genuine document is counted; the all-masked padding segment is not.
    Assert.eq(model_input.num_docs, 1)


def test_num_docs_none_without_label_counts():
    """num_docs stays None unless return_label_counts requests GRPO preprocessing."""
    batch = LanguageModelBatch.from_documents([_make_grpo_document([1, 2, 3, 4])])
    (model_input,) = batch.get_model_inputs(LanguageModelBatchPreprocessingConfig())
    assert model_input.num_docs is None


def _make_sdp_config(sdp_rank: int, sdp_size: int = 2) -> LanguageModelBatchPreprocessingConfig:
    """Preprocessing config simulating one sequence-data-parallel rank out of *sdp_size*."""
    distributed = DistributedConfig(world_size=sdp_size, rank=sdp_rank, sequence_data_parallel=sdp_size)
    return LanguageModelBatchPreprocessingConfig(
        use_grpo_data=True,
        return_label_counts=True,
        distributed=distributed,
    )


def test_num_docs_sdp_only_counted_on_rank0():
    """With SDP=2, num_docs is counted only on sequence_data_rank=0.

    The runner all_reduces the denominator across the data group (which includes all SDP
    ranks). If both SDP ranks reported num_docs=1 for the same document, the all_reduce
    SUM would produce denominator=2 and halve the metric. Only rank 0 must contribute
    to avoid this double-counting.
    """
    # 9 tokens yield total_input_length = 8, evenly divisible across SDP=2 ranks.
    batch = LanguageModelBatch.from_documents([_make_grpo_document([1, 2, 3, 4, 5, 6, 7, 8, 9])])

    per_rank_counts = []
    for rank in (0, 1):
        (model_input,) = batch.get_model_inputs(_make_sdp_config(sdp_rank=rank))
        per_rank_counts.append(model_input.num_docs)

    # Only rank 0 contributes the document; rank 1 reports zero.
    Assert.eq(per_rank_counts[0], 1)
    Assert.eq(per_rank_counts[1], 0)

    # Summing across SDP ranks (as the all_reduce does) gives the true document count.
    Assert.eq(sum(per_rank_counts), 1)


def test_num_docs_sdp_fully_masked_excluded_on_rank0():
    """A fully-masked document is excluded from num_docs even on SDP rank 0."""
    # First document fully masked, second has response tokens; 9 tokens total
    # yield 8 input tokens, evenly divisible across SDP=2 ranks.
    batch = LanguageModelBatch.from_documents(
        [
            _make_grpo_document([1, 2, 3, 4], loss_masking_spans=[(0, 4)]),
            _make_grpo_document([5, 6, 7, 8, 9]),
        ]
    )

    per_rank_counts = []
    for rank in (0, 1):
        (model_input,) = batch.get_model_inputs(_make_sdp_config(sdp_rank=rank))
        per_rank_counts.append(model_input.num_docs)

    # The masked document never counts; rank 1 always reports zero.
    Assert.eq(per_rank_counts[0], 1)
    Assert.eq(per_rank_counts[1], 0)
    Assert.eq(sum(per_rank_counts), 1)


def test_num_docs_sdp_two_docs_counted_once():
    """Two documents on SDP=2 are counted once in total (not once per SDP rank)."""
    # 9 tokens across two documents yield 8 input tokens, divisible by SDP=2.
    batch = LanguageModelBatch.from_documents(
        [_make_grpo_document([1, 2, 3, 4]), _make_grpo_document([5, 6, 7, 8, 9])]
    )

    per_rank_counts = []
    for rank in (0, 1):
        (model_input,) = batch.get_model_inputs(_make_sdp_config(sdp_rank=rank))
        per_rank_counts.append(model_input.num_docs)

    Assert.eq(per_rank_counts[0], 2)
    Assert.eq(per_rank_counts[1], 0)
    Assert.eq(sum(per_rank_counts), 2)


def test_num_docs_micro_batch_splits_only_first_split_counts():
    """With micro_batch_splits=2, num_docs is non-None only on the first split.

    A document that spans the split boundary would be visible in both splits'
    cropped_lengths, so both would count it without this guard. The runner sums
    num_docs across all splits in context.batch; only the first split must contribute
    to avoid double-counting.
    """
    preprocessing = LanguageModelBatchPreprocessingConfig(
        use_grpo_data=True,
        return_label_counts=True,
        micro_batch_splits=2,
    )
    # 9 tokens yield 8 input tokens, split evenly into two chunks of 4.
    batch = LanguageModelBatch.from_documents([_make_grpo_document([1, 2, 3, 4, 5, 6, 7, 8, 9])])
    first_split, second_split = batch.get_model_inputs(preprocessing)

    # The count lives on the first split only; the runner treats None as 0.
    Assert.eq(first_split.num_docs, 1)
    assert second_split.num_docs is None

    # Emulate the runner's summation: 1 + 0 gives the correct denominator.
    Assert.eq((first_split.num_docs or 0) + (second_split.num_docs or 0), 1)


def test_num_docs_micro_batch_splits_two_docs():
    """With micro_batch_splits=2 and two documents, only the first split counts both docs."""
    preprocessing = LanguageModelBatchPreprocessingConfig(
        use_grpo_data=True,
        return_label_counts=True,
        micro_batch_splits=2,
    )
    # 9 tokens across two documents yield 8 input tokens, divisible by 2 splits.
    batch = LanguageModelBatch.from_documents(
        [_make_grpo_document([1, 2, 3, 4]), _make_grpo_document([5, 6, 7, 8, 9])]
    )
    first_split, second_split = batch.get_model_inputs(preprocessing)

    Assert.eq(first_split.num_docs, 2)
    assert second_split.num_docs is None
    Assert.eq((first_split.num_docs or 0) + (second_split.num_docs or 0), 2)
Loading