Skip to content
35 changes: 32 additions & 3 deletions superbench/benchmarks/model_benchmarks/megatron_gpt3.py
Original file line number Diff line number Diff line change
Expand Up @@ -630,7 +630,7 @@ def _init_distributed_setting(self):
f'--node_rank {node_rank} --master_addr {addr} --master_port {port}'
return True

def _generate_dataset(self):
def _generate_dataset(self): # noqa: C901
"""Generate dataset for benchmarking.

Return:
Expand All @@ -651,13 +651,42 @@ def _generate_dataset(self):
if self._args.dataset_url:
self._raw_data_path = str(Path(self._args.data_home) / 'data.json')
download_file(self._args.dataset_url, self._raw_data_path)

# Megatron's preprocess_data.py appends '_text_document' to --output-prefix
# when producing the .bin/.idx files. For the existence check below
# (which looks for {data_prefix}.bin/.idx) to pass, data_prefix must end
# with '_text_document' and have a non-empty stem when generation is needed.
suffix = '_text_document'
if not self._args.data_prefix.endswith(suffix) or self._args.data_prefix == suffix:
logger.error(
'data_prefix must end with "{}" and have a non-empty stem when '
'dataset generation is required (got "{}"). preprocess_data.py '
'always appends "{}" to --output-prefix.'.format(suffix, self._args.data_prefix, suffix)
)
self._result.set_return_code(ReturnCode.DATASET_GENERATION_FAILURE)
return False
output_prefix_basename = self._args.data_prefix[:-len(suffix)]
output_prefix = os.path.join(self._args.data_home, output_prefix_basename)

# num_workers=0 is valid for DataLoader (main process loads data),
# but preprocess_data.py requires workers>=1 for multiprocessing.Pool.
preprocess_workers = max(1, self._args.num_workers)
if preprocess_workers != self._args.num_workers:
logger.warning(
'preprocess_data.py requires --workers >= 1; '
'overriding num_workers={} to {} for dataset preprocessing only '
'(DataLoader still uses num_workers={}).'.format(
self._args.num_workers, preprocess_workers, self._args.num_workers
)
)

command = (
'python3 '
f'{os.path.join(self._args.code_base, "tools/preprocess_data.py")} '
f'--input {self._raw_data_path} '
f'--tokenizer-type {self._args.tokenizer_type} '
f'--output-prefix {os.path.join(self._args.data_home, "dataset")} '
f'--workers {str(self._args.num_workers)} '
f'--output-prefix {output_prefix} '
f'--workers {preprocess_workers} '
f'--vocab-file {self._vocab_path} '
Comment thread
polarG marked this conversation as resolved.
f'--merge-file {self._merges_path}'
)
Expand Down
94 changes: 94 additions & 0 deletions tests/benchmarks/model_benchmarks/test_megatron_gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,100 @@ def test_megatron_gpt_dataset(self):
ret = benchmark._generate_dataset()
assert (ret is True)

@mock.patch('superbench.benchmarks.model_benchmarks.megatron_gpt3.run_command')
@mock.patch('superbench.benchmarks.model_benchmarks.megatron_gpt3.download_file')
def test_megatron_gpt_dataset_generate_command(self, mock_download_file, mock_run_command):
"""Verify _generate_dataset clamps --workers to >=1 and derives --output-prefix from data_prefix."""
(benchmark_cls, _) = BenchmarkRegistry._BenchmarkRegistry__select_benchmark(self.benchmark_name, Platform.CUDA)
assert (benchmark_cls)
os.environ['OMPI_COMM_WORLD_SIZE'] = '1'
os.environ['OMPI_COMM_WORLD_LOCAL_SIZE'] = '1'
os.environ['OMPI_COMM_WORLD_RANK'] = '0'
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12345'

# Use a real, valid code_base so _preprocess() can validate it (avoid hardcoded /root path).
# Clean up after this test so the alphabetically-later test_megatron_gpt_preprocess
# (which expects pretrain_gpt.py to NOT exist initially) is not affected by leaked state.
self.createMockFiles(['pretrain_gpt.py'])
pretrain_path = Path(self._tmp_dir) / 'pretrain_gpt.py'
self.addCleanup(lambda: pretrain_path.unlink() if pretrain_path.is_file() else None)

# Helper: make run_command's side_effect create the expected .bin/.idx files
# so _generate_dataset() (invoked from within _preprocess()) succeeds.
created_files = []

def _make_dataset_files(prefix):
def _side_effect(*_args, **_kwargs):
for ext in ('.bin', '.idx'):
p = Path(self._tmp_dir) / f'{prefix}{ext}'
p.touch()
created_files.append(p)

return _side_effect

self.addCleanup(lambda: [p.unlink() for p in created_files if p.is_file()])

def _build_benchmark(extra_params):
return benchmark_cls(
self.benchmark_name,
parameters=(
f'--code_base {self._tmp_dir} --data_home {self._tmp_dir} '
f'--batch_size 2048 --dataset_url http://example.com/data.json '
f'{extra_params}'
),
)

def _run_case(extra_params, expected_workers, expected_prefix_basename, expected_data_prefix):
mock_run_command.reset_mock()
mock_run_command.side_effect = _make_dataset_files(expected_data_prefix)
benchmark = _build_benchmark(extra_params)
assert benchmark._preprocess() is True
assert mock_run_command.call_count >= 1
# Use tuple indexing instead of `.args` for Python 3.7 compatibility
# (mock.call.args was added in Python 3.8).
cmd = mock_run_command.call_args_list[0][0][0]
units = normalize_command(cmd)
assert f'--workers {expected_workers}' in units, units
expected_output_prefix = os.path.join(self._tmp_dir, expected_prefix_basename)
assert f'--output-prefix {expected_output_prefix}' in units, units

def _run_invalid_case(extra_params):
"""Assert _preprocess() fails fast (no run_command call) for invalid data_prefix."""
mock_run_command.reset_mock()
mock_run_command.side_effect = None
benchmark = _build_benchmark(extra_params)
assert benchmark._preprocess() is False
assert mock_run_command.call_count == 0
assert benchmark.return_code == ReturnCode.DATASET_GENERATION_FAILURE

# Case 1: num_workers=0 with default data_prefix should produce '--workers 1' (clamped)
# and '--output-prefix <data_home>/dataset' (default 'dataset_text_document' suffix stripped).
_run_case(
extra_params='--num_workers 0',
expected_workers=1,
expected_prefix_basename='dataset',
expected_data_prefix='dataset_text_document',
)

# Case 2: num_workers=4 with custom data_prefix='custom_text_document' should produce
# '--workers 4' and '--output-prefix <data_home>/custom'.
_run_case(
extra_params='--num_workers 4 --data_prefix custom_text_document',
expected_workers=4,
expected_prefix_basename='custom',
expected_data_prefix='custom_text_document',
)

# Case 3: data_prefix without the '_text_document' suffix is invalid for generation
# because preprocess_data.py would produce 'mydata_text_document.bin/.idx' but the
# existence check looks for 'mydata.bin/.idx'. _preprocess() must fail fast.
_run_invalid_case(extra_params='--num_workers 2 --data_prefix mydata')

# Case 4: data_prefix == '_text_document' has an empty stem after stripping the suffix,
# which would produce a malformed '--output-prefix <data_home>/'. Must fail fast.
_run_invalid_case(extra_params='--num_workers 1 --data_prefix _text_document')

@mock.patch('superbench.benchmarks.model_benchmarks.MegatronGPT._generate_dataset')
def test_megatron_gpt_command(self, mock_generate_dataset):
"""Test command generation."""
Expand Down
Loading