59 changes: 35 additions & 24 deletions tests/test_add.py
@@ -4,15 +4,39 @@

from tests.utils import Payload, empty_strided, randint_strided, randn_strided

-_INT_DTYPES = (
-    torch.int16,
-    torch.uint16,
-    torch.int32,
-    torch.uint32,
-    torch.int64,
-    torch.uint64,
+_INT_DTYPES = tuple(
+    d
+    for d in (
+        torch.int16,
+        torch.int32,
+        torch.int64,
+    )
+    if d is not None
)
+
+_UINT_DTYPES = tuple(
+    d
+    for d in (
+        getattr(torch, "uint16", None),
+        getattr(torch, "uint32", None),
+        getattr(torch, "uint64", None),
+    )
+    if d is not None
+)
+
+def _dtype_parametrize():
+    candidates = [
+        (torch.float32, 1e-7, 1e-7),
+        (torch.float16, 1e-3, 1e-3),
+        (torch.bfloat16, 1e-2, 5e-3),
+        (torch.int16, 0, 0),
+        (torch.int32, 0, 0),
+        (getattr(torch, "uint32", None), 0, 0),
+        (torch.int64, 0, 0),
+        (getattr(torch, "uint64", None), 0, 0),
+    ]
+    return tuple((d, r, a) for (d, r, a) in candidates if d is not None)


@pytest.mark.auto_act_and_assert
@pytest.mark.parametrize(
@@ -32,22 +32,9 @@
        ((4, 4, 5632), (45056, 5632, 1), (45056, 5632, 1), (45056, 5632, 1)),
    ),
)
-@pytest.mark.parametrize(
-    ("dtype", "rtol", "atol"),
-    (
-        (torch.float32, 1e-7, 1e-7),
-        (torch.float16, 1e-3, 1e-3),
-        (torch.bfloat16, 1e-2, 5e-3),
-        (torch.int16, 0, 0),
-        (torch.uint16, 0, 0),
-        (torch.int32, 0, 0),
-        (torch.uint32, 0, 0),
-        (torch.int64, 0, 0),
-        (torch.uint64, 0, 0),
-    ),
-)
+@pytest.mark.parametrize(("dtype", "rtol", "atol"), _dtype_parametrize())
def test_add(shape, input_strides, other_strides, out_strides, dtype, device, rtol, atol):
-    if dtype in _INT_DTYPES:
+    if dtype in _INT_DTYPES or dtype in _UINT_DTYPES:
        input = randint_strided(0, 100, shape, input_strides, dtype=dtype, device=device)
        other = randint_strided(0, 100, shape, other_strides, dtype=dtype, device=device)
    else:
@@ -66,10 +77,10 @@ def _add(input, other, out):


def _torch_add(input, other, out):
-    if input.dtype in (torch.uint16, torch.uint32, torch.uint64):
+    if input.dtype in _UINT_DTYPES:
        input = input.to(torch.int64)

-    if other.dtype in (torch.uint16, torch.uint32, torch.uint64):
+    if other.dtype in _UINT_DTYPES:
        other = other.to(torch.int64)

    res = torch.add(input, other)
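Both new helpers lean on one version guard: getattr(torch, name, None) yields None on PyTorch builds that predate the unsigned integer dtypes, and the comprehension filters those entries out, so the test module stays importable everywhere. A minimal standalone sketch of the same pattern (the helper name probe_dtypes is illustrative, not part of this diff):

import torch

def probe_dtypes(names):
    # Keep only the dtypes this torch build actually provides; missing
    # attributes resolve to None and are dropped by the filter.
    return tuple(d for d in (getattr(torch, n, None) for n in names) if d is not None)

UINT_DTYPES = probe_dtypes(["uint16", "uint32", "uint64"])
print(UINT_DTYPES)  # () on older builds; (torch.uint16, torch.uint32, torch.uint64) on newer ones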
11 changes: 10 additions & 1 deletion tests/test_rms_norm.py
@@ -59,4 +59,13 @@ def _rms_norm(input, weight, *, eps=1e-6, out=None):


def _torch_rms_norm(input, weight, *, eps=1e-6, out=None):
-    return torch.nn.functional.rms_norm(input, input.shape[-1:], weight=weight, eps=eps)
+    rms_norm_fn = getattr(torch.nn.functional, "rms_norm", None)
+    if rms_norm_fn is not None:
+        return rms_norm_fn(input, input.shape[-1:], weight=weight, eps=eps)
+    # Fallback for older PyTorch without F.rms_norm: x / sqrt(mean(x^2) + eps) * weight
+    rms = torch.sqrt(torch.mean(input * input, dim=-1, keepdim=True) + eps)
+    result = (input / rms) * weight
+    if out is not None:
+        out.copy_(result)
+        return out
+    return result
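On builds where torch.nn.functional.rms_norm exists, the fallback can be sanity-checked against the built-in directly; a quick sketch under assumed shapes and eps (none of these values come from the test suite):

import torch
import torch.nn.functional as F

x = torch.randn(2, 8, 16)
w = torch.randn(16)
eps = 1e-6

# Manual RMS norm: scale by the reciprocal root-mean-square over the last dim.
manual = x / torch.sqrt(torch.mean(x * x, dim=-1, keepdim=True) + eps) * w

if hasattr(F, "rms_norm"):
    torch.testing.assert_close(manual, F.rms_norm(x, x.shape[-1:], weight=w, eps=eps))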