1 change: 1 addition & 0 deletions docs/src/python/nn.rst
@@ -175,6 +175,7 @@ In detail:
value_and_grad
quantize
average_gradients
fsdp_apply_gradients

.. toctree::

6 changes: 5 additions & 1 deletion python/mlx/nn/__init__.py
@@ -2,4 +2,8 @@

from mlx.nn import init, losses
from mlx.nn.layers import *
from mlx.nn.utils import average_gradients, value_and_grad
from mlx.nn.utils import (
average_gradients,
fsdp_apply_gradients,
value_and_grad,
)
198 changes: 179 additions & 19 deletions python/mlx/nn/utils.py
@@ -5,7 +5,7 @@

import mlx.core as mx

from ..utils import tree_flatten, tree_map, tree_unflatten
from ..utils import tree_flatten, tree_map, tree_reduce, tree_unflatten
from .layers.base import Module


@@ -71,6 +71,31 @@ def wrapped_checkpointed_fn(*args, **kwargs):
return wrapped_checkpointed_fn


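# Helpers shared by average_gradients and fsdp_apply_gradients for working with
# flattened gradient/parameter lists.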
def _extract_info(flat):
keys = [k for k, _ in flat]
shapes = [g.shape for _, g in flat]
sizes = [g.size for _, g in flat]
dtypes = [g.dtype for _, g in flat]
return keys, shapes, sizes, dtypes


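# Group consecutive array indices until each group's byte size reaches the
# communication threshold; the trailing remainder, if any, forms its own group.
# Illustrative example: with an itemsize of 4 bytes, element counts
# [1 << 20, 2 << 20, 6 << 20, 1 << 20] and a 32 MiB threshold, the result is
# [[0, 1, 2], [3]] (4 + 8 + 24 MiB crosses the threshold, leaving 4 MiB behind).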
def _group_by_size(keys, sizes, itemsize, communication_size):
grad_groups = []
grad_group = []
grad_group_size = 0
for i in range(len(keys)):
grad_group.append(i)
grad_group_size += sizes[i] * itemsize
if grad_group_size >= communication_size:
grad_groups.append(grad_group)
grad_group = []
grad_group_size = 0
if grad_group:
grad_groups.append(grad_group)
grad_group = []
return grad_groups


def average_gradients(
gradients: Any,
group: Optional[mx.distributed.Group] = None,
@@ -95,7 +120,7 @@ def _average(x):
communication_type (Optional[mlx.core.Dtype]): If provided cast to this
type before performing the communication. Typically cast to a
smaller float to reduce the communication size. Default: ``None``.
communication_stream (Optional[mlx.core.Stream]): The stream to usse
communication_stream (Optional[mlx.core.Stream]): The stream to use
for the communication. If unspecified the default communication
stream is used which can vary by back-end. Default: ``None``.
"""
@@ -119,10 +144,7 @@ def _average(x):
return gradients

# Extract some info for the gradient
keys = [k for k, _ in flat_grads]
shapes = [v.shape for _, v in flat_grads]
sizes = [v.size for _, v in flat_grads]
dtypes = [v.dtype for _, v in flat_grads]
keys, shapes, sizes, dtypes = _extract_info(flat_grads)

# We can't group them if they have mixed types
if not all(dt == dtypes[0] for dt in dtypes):
@@ -134,19 +156,7 @@ def _average(x):
)

# Gather the gradients in groups that are just above or equal to all_reduce_size
grad_groups = []
grad_group = []
grad_group_size = 0
for i in range(len(keys)):
grad_group.append(i)
grad_group_size += sizes[i] * itemsize
if grad_group_size >= all_reduce_size:
grad_groups.append(grad_group)
grad_group = []
grad_group_size = 0
if grad_group:
grad_groups.append(grad_group)
grad_group = []
grad_groups = _group_by_size(keys, sizes, itemsize, all_reduce_size)

# Concatenate-reduce-split
new_flat_grads = []
@@ -163,3 +173,153 @@ def _average(x):
)

return tree_unflatten(new_flat_grads)


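# Clip a (possibly sharded) gradient tree by its global L2 norm: the squared
# norm is accumulated locally, all-reduced so every rank sees the same total,
# and each shard is then scaled by min(max_norm / (norm + 1e-6), 1).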
def _clip_grads_fsdp(grads_slice, max_norm):
local_norm_sq = tree_reduce(lambda acc, g: acc + g.square().sum(), grads_slice, 0.0)
global_norm_sq = mx.distributed.all_sum(local_norm_sq)
grad_norm = mx.sqrt(global_norm_sq)
normalizer = mx.minimum(max_norm / (grad_norm + 1e-6), 1.0)
grads_slice = tree_map(lambda g: g * normalizer, grads_slice)

return grads_slice, grad_norm


def fsdp_apply_gradients(
gradients,
parameters,
optimizer,
group=None,
communication_size=32 * 1024**2,
communication_type=None,
communication_stream=None,
max_norm=None,
):
"""Perform a distributed optimizer step by sharding gradients and optimizer states across ranks.

This helper function performs the following steps:
1. Reduce-scatter the gradients across ranks so each rank gets a shard of the averaged gradients.
2. Optionally clip the sharded gradients by global norm.
3. Apply the optimizer update on the local parameter slice using the sharded gradients.
4. All-gather the updated parameter slices from all ranks to reconstruct the full parameters tree.

This is similar to PyTorch's FSDP with ``reshard_after_forward=False``.

Args:
gradients (Any): The Python tree containing the full gradients (it should
have the same structure as ``parameters``). Each gradient's first
dimension must be divisible by the world size.
parameters (Any): The Python tree containing the full parameters (it should
have the same structure across processes). Each parameter's first
dimension must be divisible by the world size.
optimizer: Optimizer with an ``apply_gradients`` method.
group (Optional[mlx.core.distributed.Group]): The group of processes for
communication. If ``None``, the global group is used.
Default: ``None``.
communication_size (int): Group arrays until their size in bytes exceeds
this number. Perform one communication step per group of arrays. If
less than or equal to 0, array grouping is disabled. Default: ``32MiB``.
communication_type (Optional[mlx.core.Dtype]): If provided cast to this
type before performing the communication. Typically cast to a
smaller float to reduce the communication size. Default: ``None``.
communication_stream (Optional[mlx.core.Stream]): The stream to use
for the communication. If unspecified the default communication
stream is used which can vary by back-end. Default: ``None``.
max_norm (Optional[float]): If provided, clip gradients to this
maximum global norm before applying the optimizer update.
Default: ``None``.

Returns:
If ``max_norm`` is ``None``, returns the updated full-parameter tree.
Otherwise returns ``(parameters, grad_norm)``, where ``grad_norm`` is
the global gradient norm before clipping.

Example:

>>> optimizer = optim.SGD(learning_rate=0.01)
>>> # Without gradient clipping
>>> updated_params = fsdp_apply_gradients(grads, params, optimizer)
>>> model.update(updated_params)
>>>
>>> # With gradient clipping
>>> updated_params, grad_norm = fsdp_apply_gradients(
... grads, params, optimizer, max_norm=1.0
... )
>>> model.update(updated_params)
"""
group = group or mx.distributed.init()
N = group.size()
rank = group.rank()

if N == 1:
if max_norm is not None:
gradients, grad_norm = _clip_grads_fsdp(gradients, max_norm)
return optimizer.apply_gradients(gradients, parameters), grad_norm
return optimizer.apply_gradients(gradients, parameters)

flat_grads = tree_flatten(gradients)
flat_params = tree_flatten(parameters)

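# Communication helpers: optionally cast to communication_type for the
# collective, then cast back. _sum_scatter also divides by N so each rank
# ends up with its shard of the *averaged* gradients.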
def _sum_scatter(x):
dt = x.dtype
x = x.astype(communication_type) if communication_type is not None else x
return (
mx.distributed.sum_scatter(
x, group=group, stream=communication_stream
).astype(dt)
/ N
)

def _all_gather(x):
dt = x.dtype
x = x.astype(communication_type) if communication_type is not None else x
return mx.distributed.all_gather(
x, group=group, stream=communication_stream
).astype(dt)

keys, shapes, sizes, dtypes = _extract_info(flat_grads)
itemsize = dtypes[0].size

groups = _group_by_size(keys, sizes, itemsize, communication_size)

# Reduce-scatter gradients and shard parameters. Each array is reshaped to
# (N, -1) and concatenated per group, so rank r owns the r-th row slice.
grad_slices = {}
param_slices = {}
for group_idx, arr_group in enumerate(groups):
big_grad = mx.concatenate(
[flat_grads[i][1].reshape(N, -1) for i in arr_group], axis=1
)
grad_slices[group_idx] = _sum_scatter(big_grad)
big_param = mx.concatenate(
[flat_params[i][1].reshape(N, -1) for i in arr_group], axis=1
)
param_slices[group_idx] = big_param[rank]

# clip gradients if needed
grad_norm = None
if max_norm is not None:
grad_slices, grad_norm = _clip_grads_fsdp(grad_slices, max_norm)

# optimizer step
updated_param_slices = optimizer.apply_gradients(grad_slices, param_slices)

# all-gather and reconstruct
new_flat = []
for group_idx, arr_group in enumerate(groups):
big_gathered = _all_gather(updated_param_slices[group_idx].reshape(1, -1))

split_sizes = [sizes[i] // N for i in arr_group]
split_indices = []
acc = 0
for s in split_sizes:
acc += s
split_indices.append(acc)

parts = mx.split(big_gathered, split_indices[:-1], axis=1)
for idx_in_group, i in enumerate(arr_group):
new_flat.append((keys[i], parts[idx_in_group].reshape(shapes[i])))

result = tree_unflatten(new_flat)
if max_norm is not None:
return result, grad_norm
return result
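
A minimal end-to-end sketch (not part of this diff) of how fsdp_apply_gradients could be wired into a distributed training step. The model, loss, batch shapes, and learning rate below are placeholder assumptions; only fsdp_apply_gradients itself comes from this change:

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from mlx.nn.utils import fsdp_apply_gradients

world = mx.distributed.init()
# Placeholder model; parameter leading dimensions are assumed divisible by world.size().
model = nn.Linear(512, 512)
optimizer = optim.Adam(learning_rate=1e-3)

def loss_fn(model, x, y):
    return nn.losses.mse_loss(model(x), y)

loss_and_grad = nn.value_and_grad(model, loss_fn)

def step(x, y):
    loss, grads = loss_and_grad(model, x, y)
    # Reduce-scatter grads, update this rank's shard, then all-gather the new parameters.
    params, grad_norm = fsdp_apply_gradients(
        grads, model.trainable_parameters(), optimizer, group=world, max_norm=1.0
    )
    model.update(params)
    return loss, grad_norm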
126 changes: 126 additions & 0 deletions python/tests/nccl_test_distributed.py
@@ -1,8 +1,10 @@
# Copyright © 2024 Apple Inc.

import mlx.core as mx
import mlx.optimizers as optim
import mlx_distributed_tests
import mlx_tests
from mlx.nn.utils import average_gradients, fsdp_apply_gradients


class TestNCCLDistributed(mlx_distributed_tests.MLXDistributedCommonTestCase):
@@ -114,6 +116,130 @@ def test_all_gather_split(self):
self.assertEqual(y.shape, (sub.size() * 2, 2, 4))
self.assertTrue(mx.all(y == 1))

def test_fsdp_apply_gradients(self):
world = mx.distributed.init()
N = world.size()

params = {
"w1": mx.ones((N * 10, 8)),
"w2": mx.ones((N * 20,)),
}
grads = {
"w1": mx.ones((N * 10, 8)) * 0.1,
"w2": mx.ones((N * 20,)) * 0.1,
}

optimizer = optim.SGD(learning_rate=0.1)
updated_params_fsdp = fsdp_apply_gradients(grads, params, optimizer)
mx.eval(updated_params_fsdp)

self.assertEqual(updated_params_fsdp["w1"].shape, (N * 10, 8))
self.assertEqual(updated_params_fsdp["w2"].shape, (N * 20,))

self.assertTrue(
mx.allclose(
updated_params_fsdp["w1"], mx.ones((N * 10, 8)) * 0.99, atol=1e-6
)
)
self.assertTrue(
mx.allclose(updated_params_fsdp["w2"], mx.ones((N * 20,)) * 0.99, atol=1e-6)
)

grads = {
"w1": mx.ones((N * 10, 8)) * 10.0,
"w2": mx.ones((N * 20,)) * 10.0,
}

new_params_clipped, grad_norm = fsdp_apply_gradients(
grads, params, optimizer, max_norm=1.0
)
mx.eval(new_params_clipped, grad_norm)

self.assertIsNotNone(grad_norm)
expected_norm = mx.sqrt((N * 10 * 8 + N * 20) * 100.0)
self.assertTrue(mx.allclose(grad_norm, expected_norm, atol=1e-4, rtol=1e-4))
self.assertEqual(new_params_clipped["w1"].shape, (N * 10, 8))
self.assertEqual(new_params_clipped["w2"].shape, (N * 20,))

scale = 1.0 / expected_norm
expected_update = 1.0 - 0.1 * 10.0 * scale
self.assertTrue(
mx.allclose(
new_params_clipped["w1"],
mx.ones((N * 10, 8)) * expected_update,
atol=1e-4,
rtol=1e-4,
)
)
self.assertTrue(
mx.allclose(
new_params_clipped["w2"],
mx.ones((N * 20,)) * expected_update,
atol=1e-4,
rtol=1e-4,
)
)
params = {"w": mx.ones((N * 4,))}
grads = {"w": mx.ones((N * 4,)) * 0.5}

optimizer_fsdp = optim.SGD(learning_rate=0.1)
updated_params_fsdp = fsdp_apply_gradients(grads, params, optimizer_fsdp)

optimizer_ddp = optim.SGD(learning_rate=0.1)
avg_grads = average_gradients(grads)
updated_params_ddp = optimizer_ddp.apply_gradients(avg_grads, params)
mx.eval(updated_params_ddp, updated_params_fsdp)

self.assertTrue(
mx.allclose(
updated_params_fsdp["w"], updated_params_ddp["w"], atol=1e-6, rtol=1e-4
),
)

def test_fsdp_peak_memory(self):
world = mx.distributed.init()
N = world.size()
mx.random.seed(42)
params = {
"w1": mx.random.normal((N * 1024, 1024)),
"w2": mx.random.normal((N * 2048, 512)),
}
grads = {
"w1": mx.random.normal((N * 1024, 1024)),
"w2": mx.random.normal((N * 2048, 512)),
}
mx.eval(params, grads)
optimizer_ddp = optim.Adam(learning_rate=0.01)
optimizer_fsdp = optim.Adam(learning_rate=0.01)

def pseudo_step_ddp(grads, params, optimizer):
grads = average_gradients(grads)
grads, grad_norm = optim.clip_grad_norm(grads, max_norm=1.0)
params = optimizer.apply_gradients(grads, params)
return grad_norm, params

def pseudo_step_fsdp(grads, params, optimizer):
params, grad_norm = fsdp_apply_gradients(
grads, params, optimizer, max_norm=1.0
)
return grad_norm, params

mx.reset_peak_memory()

for i in range(10):
grad_norm, params = pseudo_step_ddp(grads, params, optimizer_ddp)
mx.eval(grad_norm, params)

ddp_peak_memory = mx.get_peak_memory()
mx.reset_peak_memory()

for i in range(10):
grad_norm, params = pseudo_step_fsdp(grads, params, optimizer_fsdp)
mx.eval(grad_norm, params)

fsdp_peak_memory = mx.get_peak_memory()
self.assertTrue(fsdp_peak_memory < ddp_peak_memory)


if __name__ == "__main__":
mlx_tests.MLXTestRunner()