-
Notifications
You must be signed in to change notification settings - Fork 494
Fix DiLoCo training compatibility issues #3471
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2643,8 +2643,32 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de | |
| self.dcn_parallelism = [dcn_map[axis] for axis in self.mesh_axes] | ||
|
|
||
| # Diloco params | ||
| # Resolve dcn_diloco_parallelism=-1 if left unspecified, using the same convention as dcn_data_parallelism. | ||
| # num_diloco_replicas must be computed after this resolution, so we resolve it here rather than | ||
| # relying on fill_unspecified_mesh_axes (which runs later during mesh creation). | ||
| if self.dcn_diloco_parallelism == -1: | ||
| other_dcn_product = prod(v for v in self.dcn_parallelism if v != -1) | ||
| assert other_dcn_product > 0 and self.num_slices % other_dcn_product == 0, ( | ||
| f"Cannot resolve dcn_diloco_parallelism=-1: num_slices={self.num_slices} is not divisible " | ||
| f"by the product of other DCN parallelism values ({other_dcn_product})." | ||
| ) | ||
| self.dcn_diloco_parallelism = self.num_slices // other_dcn_product | ||
| # Keep dcn_parallelism list consistent with the resolved value. | ||
| diloco_idx = self.dcn_parallelism.index(-1) | ||
| self.dcn_parallelism[diloco_idx] = self.dcn_diloco_parallelism | ||
| self.num_diloco_replicas = int(self.ici_diloco_parallelism * self.dcn_diloco_parallelism) | ||
|
|
||
| # use_tokamax_gmm is incompatible with enable_diloco: drjax.map_fn wraps the train step in | ||
| # jax.vmap over the diloco axis, which causes JAX to batch through lax.scan (layer scan). | ||
| # Tokamax's vmap_rule then tries to reconstruct GroupSizes with a batched 2-D value, but | ||
| # GroupSizes.__post_init__ requires exactly a 1-D shape. | ||
| if self.enable_diloco and self.use_tokamax_gmm: | ||
| raise ValueError( | ||
| "use_tokamax_gmm=True is not compatible with enable_diloco=True due to a known " | ||
| "incompatibility between tokamax's GroupSizes vmap_rule and JAX's scan batching. " | ||
| "Please set use_tokamax_gmm=False." | ||
| ) | ||
|
|
||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I wonder if DiLoCo works with all opt_type values, e.g. sgd, muon? Maybe disable them if they don't work. |
||
| # Final string-to-enum conversions if they haven't been coerced by pydantic yet. | ||
| if isinstance(self.decoder_block, str): | ||
| self.decoder_block = DecoderBlockType(self.decoder_block.lower()) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -490,19 +490,18 @@ def train_loop(config, recorder, state=None): | |
|
|
||
| params_shardings, state_mesh_shardings = sharding.maybe_update_params_sharding_with_opt(config, state_mesh_shardings) | ||
|
|
||
| p_train_step, p_eval_step = train_utils.jit_train_and_eval_step( | ||
| config, | ||
| model, | ||
| mesh, | ||
| state, | ||
| state_mesh_shardings, | ||
| train_step, | ||
| eval_step, | ||
| eval_data_iterator, | ||
| params_shardings, | ||
| ) | ||
|
|
||
| with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules): | ||
| with jax.set_mesh(mesh), mesh, nn_partitioning.axis_rules(config.logical_axis_rules): | ||
| p_train_step, p_eval_step = train_utils.jit_train_and_eval_step( | ||
| config, | ||
| model, | ||
| mesh, | ||
| state, | ||
| state_mesh_shardings, | ||
| train_step, | ||
| eval_step, | ||
| eval_data_iterator, | ||
| params_shardings, | ||
| ) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Our jit function was not in the mesh/logical-rule context before? What difference did you see by making this change?
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Oh, I see the comments in train_compile. I think JAX will deprecate the `with mesh` context. |
||
| shaped_batch = maxtext_utils.get_shaped_batch(config) | ||
| if config.shard_optimizer_over_data: | ||
| state = sharding.maybe_shard_with_name(state, state_mesh_shardings, config.shard_mode) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can we create a bug for this to track?