12 changes: 12 additions & 0 deletions swift/llm/argument/train_args.py
@@ -206,11 +206,23 @@ class TrainArguments(SwanlabArguments, TunerArguments, BaseArguments, Seq2SeqTra
# early_step
early_stop_interval: Optional[int] = None

# dataset progress tracking
track_dataset_progress: bool = False

# greedy packing (non-streaming alternative to binpacking)
greedy_packing: bool = False # Use greedy packing instead of binpacking to avoid preprocessing

def _check_padding_free(self):
# greedy_packing requires padding_free (like packing)
if self.greedy_packing:
self.padding_free = True

if self.padding_free or self.packing:
if self.packing:
feature = 'packing'
self.padding_free = True
elif self.greedy_packing:
feature = 'greedy_packing'
else:
feature = 'padding_free'
if self.attn_impl not in {'flash_attn', 'flash_attention_2', 'flash_attention_3'}:
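A minimal stand-alone sketch (not part of the diff) of the flag precedence that `_check_padding_free` implements above: `packing` takes priority over `greedy_packing`, and both force `padding_free` on. The function name `resolve_packing_feature` is hypothetical.

def resolve_packing_feature(packing: bool, greedy_packing: bool, padding_free: bool):
    if greedy_packing:  # greedy_packing implies padding_free, like packing
        padding_free = True
    if not (padding_free or packing):
        return None, padding_free
    if packing:  # packing wins and also forces padding_free
        return 'packing', True
    if greedy_packing:
        return 'greedy_packing', padding_free
    return 'padding_free', padding_free

print(resolve_packing_feature(packing=False, greedy_packing=True, padding_free=False))
# ('greedy_packing', True)
print(resolve_packing_feature(packing=True, greedy_packing=True, padding_free=False))
# ('packing', True)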
1 change: 1 addition & 0 deletions swift/llm/dataset/__init__.py
@@ -9,6 +9,7 @@
from .preprocessor import (AlpacaPreprocessor, AutoPreprocessor, MessagesPreprocessor, ResponsePreprocessor,
RowPreprocessor)
from .register import DATASET_MAPPING, DatasetMeta, SubsetDataset, register_dataset, register_dataset_info
from .collator import ProgressTrackingCollator
from .utils import (AddLengthPreprocessor, EncodePreprocessor, IterablePackingDataset, LazyLLMDataset, PackingDataset,
sample_dataset)

89 changes: 89 additions & 0 deletions swift/llm/dataset/collator.py
@@ -0,0 +1,89 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
"""Dataset progress tracking collator wrapper.

This module provides a wrapper collator that extracts dataset source information
for progress tracking during training.
"""
from typing import Any, Callable, Dict, List, Optional, Tuple


class ProgressTrackingCollator:
"""Wrapper collator that extracts dataset sources and token lengths for progress tracking.

This wrapper intercepts the collator output, extracts _dataset_source and length fields
from each sample, and passes them through to the main process via _batch_sources and
_batch_lengths fields for statistics collection.

This approach is non-invasive - it doesn't modify any template code, only
wraps the collator at the training level.

Args:
collator: The original collator function to wrap.
track_progress: Whether to track progress. If False, just removes
_dataset_source field without collecting statistics.

Example:
>>> original_collator = partial(template.data_collator, padding_to=None)
>>> wrapped = ProgressTrackingCollator(original_collator)
>>> batch_result = wrapped(batch) # Contains _batch_sources and _batch_lengths fields
"""

def __init__(self, collator: Callable, track_progress: bool = True):
self.collator = collator
self.track_progress = track_progress

def _extract_info(self, item: Any) -> Tuple[Optional[Any], Optional[int]]:
"""Extract and remove _dataset_source, extract length from item."""
if isinstance(item, dict):
sources = item.pop('_dataset_source', None)
length = item.get('length')
return sources, length
return None, None

def _collect_sources_and_lengths(
self,
sources: Optional[Any],
length: Optional[int],
batch_sources: List[str],
batch_lengths: List[int],
) -> None:
"""Collect sources and lengths into batch lists."""
if self.track_progress and sources:
if isinstance(sources, str):
batch_sources.append(sources)
elif isinstance(sources, list):
batch_sources.extend(sources)
if length is not None:
batch_lengths.append(length)

def __call__(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:
"""Process batch and extract dataset sources and token lengths.

Args:
batch: List of encoded samples, each may contain _dataset_source and length fields.

Returns:
Collated batch dict with optional _batch_sources and _batch_lengths fields.
"""
# 1. Collect sources and lengths before calling original collator
# (original collator may modify batch in place)
batch_sources: List[str] = []
batch_lengths: List[int] = []

for b in batch:
# Handle both Packing scenario (list) and normal scenario (dict)
items = b if isinstance(b, list) else [b]
for item in items:
sources, length = self._extract_info(item)
self._collect_sources_and_lengths(sources, length, batch_sources, batch_lengths)

# 2. Call original collator
result = self.collator(batch)

# 3. Attach sources and lengths for main process to collect
if batch_sources:
result['_batch_sources'] = batch_sources
if batch_lengths:
result['_batch_lengths'] = batch_lengths

return result
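A toy usage sketch of the wrapper above (not from the PR). The import path follows the __init__.py change in this diff; the dummy collator and sample dicts are placeholders standing in for real template output.

from swift.llm.dataset import ProgressTrackingCollator

def dummy_collator(batch):
    # Stand-in for template.data_collator: just gathers input_ids.
    return {'input_ids': [b['input_ids'] for b in batch]}

wrapped = ProgressTrackingCollator(dummy_collator)
batch = [
    {'input_ids': [1, 2, 3], 'length': 3, '_dataset_source': 'alpaca-zh'},
    {'input_ids': [4, 5], 'length': 2, '_dataset_source': 'alpaca-en'},
]
out = wrapped(batch)
assert out['_batch_sources'] == ['alpaca-zh', 'alpaca-en']
assert out['_batch_lengths'] == [3, 2]
assert '_dataset_source' not in batch[0]  # popped by the wrapper before collating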
66 changes: 66 additions & 0 deletions swift/llm/dataset/loader.py
@@ -145,6 +145,15 @@ def _get_matched_dataset_meta(self, dataset_meta_mapping):
return dataset_meta


class _AddDatasetSource:
def __init__(self, source):
self.source = source

def __call__(self, example):
example['_dataset_source'] = self.source
return example


class DatasetLoader:

@staticmethod
@@ -193,6 +202,16 @@ def _interleave_datasets(datasets, *args, **kwargs):
return datasets[0]
return interleave_datasets(datasets, *args, **kwargs)

@staticmethod
def _add_dataset_source(dataset, source: str, streaming: bool = False):
"""Add _dataset_source column for progress tracking."""
if streaming:
# For IterableDataset, add source via map
return dataset.map(_AddDatasetSource(source))
else:
# For regular Dataset, add column directly
return dataset.add_column('_dataset_source', [source] * len(dataset))

@staticmethod
def _load_dataset_path(
dataset_path: str,
@@ -519,6 +538,13 @@ def load_dataset(
use_hf_default = use_hf
if use_hf_default is None:
use_hf_default = True if use_hf_hub() else False

# Track original dataset sizes before mixing/resampling for accurate progress tracking
# This enables tracking of training progress relative to original dataset size,
# which is essential when using interleave_datasets with resampling strategies
original_train_dataset_sizes: Dict[str, int] = {}
original_val_dataset_sizes: Dict[str, int] = {}

for dataset in datasets:
dataset_syntax = DatasetSyntax.parse(dataset)
use_hf = dataset_syntax.use_hf or use_hf_default
@@ -542,9 +568,34 @@
shuffle=shuffle,
random_state=seed,
)
# Inject dataset source identifier for progress tracking
dataset_source = dataset_syntax.get_raw()
if train_dataset is not None:
# Record original dataset size before any mixing/resampling operations
# This size represents the actual number of unique samples in the dataset,
# which is crucial for calculating meaningful training progress (epochs)
if not streaming and hasattr(train_dataset, '__len__'):
try:
original_size = len(train_dataset)
if original_size > 0:
original_train_dataset_sizes[dataset_source] = original_size
except Exception as e:
logger.warning(f'Failed to get length of dataset {dataset_source}: {e}')

train_dataset = DatasetLoader._add_dataset_source(train_dataset, dataset_source, streaming)
train_datasets.append(train_dataset)

if val_dataset is not None:
# Record original validation dataset size
if not streaming and hasattr(val_dataset, '__len__'):
try:
original_size = len(val_dataset)
if original_size > 0:
original_val_dataset_sizes[dataset_source] = original_size
except Exception as e:
logger.warning(f'Failed to get length of val_dataset {dataset_source}: {e}')

val_dataset = DatasetLoader._add_dataset_source(val_dataset, dataset_source, streaming)
val_datasets.append(val_dataset)

if interleave_prob is None:
@@ -563,4 +614,19 @@
if val_datasets:
val_datasets = DatasetLoader.shuffle_dataset(
val_datasets, seed=get_seed(seed), buffer_size=shuffle_buffer_size)

# Attach original dataset sizes as metadata for accurate progress tracking
# When using interleave_datasets with resampling (e.g., all_exhausted strategy),
# the mixed dataset may contain duplicated samples. The original sizes allow
# progress callbacks to report training progress relative to unique sample counts,
# which provides more meaningful epoch-based progress metrics.
if original_train_dataset_sizes and train_datasets is not None:
train_datasets._original_dataset_sizes = original_train_dataset_sizes
if logger.isEnabledFor(20): # INFO level
size_summary = ', '.join([f'{k.split("/")[-1]}: {v}' for k, v in original_train_dataset_sizes.items()])
logger.info(f'Attached original dataset sizes for progress tracking: {size_summary}')

if original_val_dataset_sizes and val_datasets is not None:
val_datasets._original_dataset_sizes = original_val_dataset_sizes

return train_datasets, val_datasets
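A sketch (not part of the PR) of how downstream code, such as a progress-tracking callback, could read the metadata attached above. It assumes load_dataset as defined in this file is called directly; the dataset names are placeholders.

train_dataset, val_dataset = load_dataset(['alpaca-zh', 'alpaca-en'])
sizes = getattr(train_dataset, '_original_dataset_sizes', {})
for name, n_samples in sizes.items():
    # n_samples is the unique sample count before any interleave/resampling duplication
    print(f'{name}: {n_samples} original samples')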
90 changes: 89 additions & 1 deletion swift/llm/dataset/utils.py
@@ -99,7 +99,11 @@ def __getitem__(self, idx: int) -> Dict[str, Any]:
self._idx = (self._idx + 1) % len(self.dataset)
data = self.dataset[i]
try:
return self.encode_func(data, return_length=True)
result = self.encode_func(data, return_length=True)
# Preserve _dataset_source for greedy_packing progress statistics
if '_dataset_source' in data:
result['_dataset_source'] = data['_dataset_source']
return result
except Exception:
if n_try == self.n_try_fetch - 1 or self.strict:
if self.strict:
@@ -326,3 +330,87 @@ def preprocess(self, row: Dict[str, Any]) -> Optional[Dict[str, Any]]:
encoded = super().preprocess(row)
row['length'] = encoded['length']
return row


class GreedyPackingDataLoader:
"""在 DataLoader 之上包装贪心打包层。

特点:
- 复用 DataLoader 的多 workers 和 prefetch
- 贪心打包 O(1) 开销
- encode 只执行一次

适用场景:
- 非 streaming 数据集
- 希望避免 binpacking 预处理开销
- 可接受稍低的空间利用率(~85-90% vs ~95%)
"""

def __init__(
self,
dataloader,
packing_length: int,
packing_collator: Callable,
):
"""
Args:
dataloader: The original DataLoader; batch_size should be set to 1.
packing_length: Target length for each pack.
packing_collator: Collator function for packed samples.
"""
self.dataloader = dataloader
self.packing_length = packing_length
self.packing_collator = packing_collator
self._length = None

def __iter__(self):
buffer = []
buffer_length = 0

for batch in self.dataloader:
samples = batch if isinstance(batch, list) else [batch]

for sample in samples:
sample_length = sample.get('length') or len(sample['input_ids'])

if buffer_length + sample_length <= self.packing_length:
buffer.append(sample)
buffer_length += sample_length
else:
# The current pack is full; emit it
if buffer:
# Note: sources and lengths must be collected before calling packing_collator,
# because packing_collator modifies the buffer in place (batch[:] = [packing_row(batch)])
batch_lengths = [len(s['input_ids']) for s in buffer]
sources = [s.get('_dataset_source') for s in buffer if s.get('_dataset_source')]
packed = self.packing_collator(buffer)
packed['_batch_lengths'] = batch_lengths
if sources:
packed['_batch_sources'] = sources
yield packed

buffer = [sample]
buffer_length = sample_length

# Emit the last pack
if buffer:
# Note: collect sources and lengths before calling packing_collator
batch_lengths = [len(s['input_ids']) for s in buffer]
sources = [s.get('_dataset_source') for s in buffer if s.get('_dataset_source')]
packed = self.packing_collator(buffer)
packed['_batch_lengths'] = batch_lengths
if sources:
packed['_batch_sources'] = sources
yield packed

def __len__(self):
# Estimate the length (not exact, but sufficient for progress bars, etc.)
# The actual length depends on packing efficiency; return a reasonable estimate here
if self._length is not None:
return self._length
if hasattr(self.dataloader, '__len__'):
# Assume an average packing efficiency of ~85%
avg_samples_per_pack = max(1, self.packing_length // 512)  # assume an average sample length of 512
self._length = max(1, len(self.dataloader) // avg_samples_per_pack)
return self._length
return 0
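A self-contained sketch (not from the PR) of how the loader above behaves, importing it from the module where it is defined. The fake packing collator stands in for the real one (e.g. template.packing_row, shown in the base.py hunk below), and the sample data is synthetic.

from torch.utils.data import DataLoader
from swift.llm.dataset.utils import GreedyPackingDataLoader

samples = [{'input_ids': list(range(n)), 'length': n, '_dataset_source': 'toy'} for n in (300, 500, 900, 200)]
inner = DataLoader(samples, batch_size=1, collate_fn=lambda b: b)  # keep the encoded dicts as-is

def fake_packing_collator(buffer):
    # Concatenate input_ids the way a real packing collator would.
    return {'input_ids': [tok for s in buffer for tok in s['input_ids']]}

loader = GreedyPackingDataLoader(inner, packing_length=1024, packing_collator=fake_packing_collator)
for packed in loader:
    print(len(packed['input_ids']), packed['_batch_lengths'], packed.get('_batch_sources'))
# Greedy fill yields packs of 800 (300+500), 900 and 200 tokens here.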
4 changes: 4 additions & 0 deletions swift/llm/template/base.py
@@ -570,6 +570,10 @@ def encode(self,
encoded['template_inputs'] = chosen
if not self.remove_unused_columns:
encoded['_extra_kwargs'] = chosen.extra_kwargs
# Preserve dataset source for progress tracking (priority: dataset_name > _dataset_source)
dataset_source = chosen.extra_kwargs.get('dataset_name') or chosen.extra_kwargs.get('_dataset_source')
if dataset_source:
encoded['_dataset_source'] = dataset_source
return batched[0] if len(batched) == 1 else batched

def packing_row(self, row: List[Dict[str, Any]]) -> Dict[str, Any]: