Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 32 additions & 21 deletions src/maxtext/utils/max_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,8 @@ def maybe_initialize_jax_distributed_system(raw_keys):

For CPUs, we call jax.distributed.initialize() explicitly, with the specified arguments.
"""

# Early exit for cases where we don't need to initialize the jax distributed system.
if raw_keys["skip_jax_distributed_system"]:
max_logging.log("Skipping jax distributed system due to skip_jax_distributed_system=True flag.")
return
Expand All @@ -248,45 +250,54 @@ def maybe_initialize_jax_distributed_system(raw_keys):
if jax.distributed.is_initialized():
max_logging.log("Jax distributed system is already initialized.")
return
if raw_keys["inference_benchmark_test"]:
# Disable initialization for inference benmark test.
return
if raw_keys["compile_topology"]:
# Don't initialize jax distributed with AOT compilation
if raw_keys["inference_benchmark_test"] or raw_keys["compile_topology"]:
max_logging.log("Skipping jax distributed system initialization.")
return

# Initialization for gpu backend
if is_gpu_backend(raw_keys):
max_logging.log("Attempting to initialize the jax distributed system for GPU backend...")
initialize_jax_for_gpu(raw_keys)
max_logging.log("Jax distributed system initialized on GPU!")
elif is_cpu_backend(raw_keys):
return

# Initialization for cpu backend
if is_cpu_backend(raw_keys):
max_logging.log("Attempting to initialize the jax distributed system for CPU backend...")
initialize_jax_for_cpu(raw_keys)
max_logging.log("Jax distributed system initialized on CPUs!")
elif raw_keys["enable_multi_tier_checkpointing"]:
max_logging.log("Attempting to initialize the jax distributed system for multi-tier " "checkpointing...")
return

# Initialization for gpu_multiprocess hardware
if raw_keys["hardware"] == "gpu_multiprocess":
max_logging.log("Attempting to initialize the jax distributed system for gpu_multiprocess hardware...")
if not raw_keys["enable_emergency_checkpoint"]:
jax.distributed.initialize(initialization_timeout=raw_keys["jax_distributed_initialization_timeout"])
else:
max_logging.log("Initializing jax distributed to support local checkpointing with GPUs...")
jax.distributed.initialize(initialization_timeout=raw_keys["jax_distributed_initialization_timeout"])
ocp.multihost.initialize_runtime_to_distributed_ids()
ocp.multihost.initialize_distributed_to_device_ids()
max_logging.log("Jax distributed system initialized!")
return

# Initialization for tpu backend
max_logging.log("Attempting to initialize the jax distributed system for TPU backend...")
if raw_keys["enable_multi_tier_checkpointing"]:
initialize_multi_tier_checkpointing(
local_checkpoint_directory=raw_keys["local_checkpoint_directory"],
backup_interval_minutes=raw_keys["multi_tier_checkpointing_backup_interval_minutes"],
run_name=raw_keys["run_name"],
jax_initialization_timeout_seconds=raw_keys["jax_distributed_initialization_timeout"],
data_parallelism=raw_keys["mtc_data_parallelism"],
)
max_logging.log("Jax distributed system initialized for multi-tier checkpointing!")
elif (raw_keys["enable_checkpointing"] and raw_keys["compile_topology_num_slices"] == -1) or raw_keys[
"hardware"
] == "gpu_multiprocess":
max_logging.log("Attempting to initialize the jax distributed system...")
max_logging.log("Jax distributed system initialized on TPUs for multi-tier checkpointing!")
elif raw_keys["enable_checkpointing"] and raw_keys["compile_topology_num_slices"] == -1:
if not raw_keys["enable_emergency_checkpoint"]:
jax.distributed.initialize(initialization_timeout=raw_keys["jax_distributed_initialization_timeout"])
else:
if raw_keys["hardware"] == "gpu_multiprocess":
max_logging.log("Initializing jax distribtued to support local checkpointing with" " GPUs...")
jax.distributed.initialize(initialization_timeout=raw_keys["jax_distributed_initialization_timeout"])
ocp.multihost.initialize_runtime_to_distributed_ids()
ocp.multihost.initialize_distributed_to_device_ids()
else:
initialize_jax_for_tpu_with_emergency_checkpointing(raw_keys)
max_logging.log("Jax distributed system initialized!")
initialize_jax_for_tpu_with_emergency_checkpointing(raw_keys)
max_logging.log("Jax distributed system initialized on TPUs!")


def initialize_jax_for_gpu(raw_keys):
Expand Down
117 changes: 117 additions & 0 deletions tests/unit/max_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,5 +344,122 @@ def test_regular_shape_unpadded(self):
self.assertEqual(padding_amount, target_padding_amount)


class TestMaybeInitializeJaxDistributedSystem(unittest.TestCase):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there may be weird interactions with pytest, - it looks like this is heavily mocked, so jax.distributed.initialize is never really called? Could you include a comment if this is the case. If you don't mock I am not sure this will work

"""Tests for maybe_initialize_jax_distributed_system."""

def setUp(self):
patcher = mock.patch("jax.distributed.is_initialized", return_value=False)
self.mock_is_initialized = patcher.start()
self.addCleanup(patcher.stop)

def _base_keys(self, **overrides):
"""Return a minimal raw_keys dict with all required fields, accepting per-test overrides."""
keys = {
"skip_jax_distributed_system": False,
"enable_single_controller": False,
"inference_benchmark_test": False,
"compile_topology": False,
"hardware": "tpu",
"enable_emergency_checkpoint": False,
"jax_distributed_initialization_timeout": 300,
"enable_multi_tier_checkpointing": False,
"local_checkpoint_directory": "/tmp/ckpt",
"multi_tier_checkpointing_backup_interval_minutes": 5,
"run_name": "test_run",
"mtc_data_parallelism": 1,
"enable_checkpointing": True,
"compile_topology_num_slices": -1,
}
keys.update(overrides)
return keys

@mock.patch("jax.distributed.initialize")
def test_skip_flag_exits_early(self, mock_init):
max_utils.maybe_initialize_jax_distributed_system(self._base_keys(skip_jax_distributed_system=True))
mock_init.assert_not_called()

@mock.patch("jax.distributed.initialize")
def test_single_controller_exits_early(self, mock_init):
max_utils.maybe_initialize_jax_distributed_system(self._base_keys(enable_single_controller=True))
mock_init.assert_not_called()

@mock.patch("jax.distributed.is_initialized", return_value=True)
@mock.patch("jax.distributed.initialize")
def test_already_initialized_exits_early(self, mock_init, _is_init):
max_utils.maybe_initialize_jax_distributed_system(self._base_keys())
mock_init.assert_not_called()

@mock.patch("jax.distributed.initialize")
def test_inference_benchmark_exits_early(self, mock_init):
max_utils.maybe_initialize_jax_distributed_system(self._base_keys(inference_benchmark_test=True))
mock_init.assert_not_called()

@mock.patch("jax.distributed.initialize")
def test_compile_topology_exits_early(self, mock_init):
max_utils.maybe_initialize_jax_distributed_system(self._base_keys(compile_topology=True))
mock_init.assert_not_called()

@mock.patch("maxtext.utils.max_utils.initialize_jax_for_gpu")
def test_gpu_backend_calls_initialize_jax_for_gpu(self, mock_gpu_init):
raw_keys = self._base_keys(hardware="gpu")
max_utils.maybe_initialize_jax_distributed_system(raw_keys)
mock_gpu_init.assert_called_once_with(raw_keys)

@mock.patch("maxtext.utils.max_utils.initialize_jax_for_cpu")
def test_cpu_backend_calls_initialize_jax_for_cpu(self, mock_cpu_init):
raw_keys = self._base_keys(hardware="cpu")
max_utils.maybe_initialize_jax_distributed_system(raw_keys)
mock_cpu_init.assert_called_once_with(raw_keys)

@mock.patch("jax.distributed.initialize")
def test_gpu_multiprocess_no_emergency_calls_jax_init(self, mock_init):
raw_keys = self._base_keys(hardware="gpu_multiprocess")
max_utils.maybe_initialize_jax_distributed_system(raw_keys)
mock_init.assert_called_once_with(initialization_timeout=self._base_keys()["jax_distributed_initialization_timeout"])

# create=True: initialize_distributed_to_device_ids only exists in multi-host orbax builds, not single-host.
@mock.patch("orbax.checkpoint.multihost.initialize_distributed_to_device_ids", create=True)
@mock.patch("orbax.checkpoint.multihost.initialize_runtime_to_distributed_ids")
@mock.patch("jax.distributed.initialize")
def test_gpu_multiprocess_with_emergency_calls_ocp_multihost(self, mock_init, mock_runtime, mock_device):
raw_keys = self._base_keys(hardware="gpu_multiprocess", enable_emergency_checkpoint=True)
max_utils.maybe_initialize_jax_distributed_system(raw_keys)
mock_init.assert_called_once_with(initialization_timeout=self._base_keys()["jax_distributed_initialization_timeout"])
mock_runtime.assert_called_once()
mock_device.assert_called_once()

@mock.patch("maxtext.utils.max_utils.initialize_multi_tier_checkpointing")
def test_tpu_multi_tier_checkpointing(self, mock_mtc):
raw_keys = self._base_keys(enable_multi_tier_checkpointing=True)
max_utils.maybe_initialize_jax_distributed_system(raw_keys)
mock_mtc.assert_called_once_with(
local_checkpoint_directory=self._base_keys()["local_checkpoint_directory"],
backup_interval_minutes=self._base_keys()["multi_tier_checkpointing_backup_interval_minutes"],
run_name=self._base_keys()["run_name"],
jax_initialization_timeout_seconds=self._base_keys()["jax_distributed_initialization_timeout"],
data_parallelism=self._base_keys()["mtc_data_parallelism"],
)

@mock.patch("jax.distributed.initialize")
def test_tpu_checkpointing_no_emergency_calls_jax_init(self, mock_init):
raw_keys = self._base_keys(enable_checkpointing=True, compile_topology_num_slices=-1)
max_utils.maybe_initialize_jax_distributed_system(raw_keys)
mock_init.assert_called_once_with(initialization_timeout=self._base_keys()["jax_distributed_initialization_timeout"])

@mock.patch("maxtext.utils.max_utils.initialize_jax_for_tpu_with_emergency_checkpointing")
def test_tpu_checkpointing_with_emergency(self, mock_tpu_emergency):
raw_keys = self._base_keys(
enable_checkpointing=True, compile_topology_num_slices=-1, enable_emergency_checkpoint=True
)
max_utils.maybe_initialize_jax_distributed_system(raw_keys)
mock_tpu_emergency.assert_called_once_with(raw_keys)

@mock.patch("jax.distributed.initialize")
def test_tpu_no_checkpointing_does_not_call_jax_init(self, mock_init):
raw_keys = self._base_keys(enable_checkpointing=False, enable_multi_tier_checkpointing=False)
max_utils.maybe_initialize_jax_distributed_system(raw_keys)
mock_init.assert_not_called()


if __name__ == "__main__":
unittest.main()
Loading