45 changes: 45 additions & 0 deletions src/ntops/torch/utils.py
@@ -6,10 +6,55 @@
import ntops


class _CachedMakeDefaultConfig:
def __init__(self, num_warps=None, num_stages=None, max_num_configs=None):
self.num_warps = num_warps

self.num_stages = num_stages

self.max_num_configs = max_num_configs


_cached_make_default_config = _CachedMakeDefaultConfig()


def get_default_num_warps():
return _cached_make_default_config.num_warps


def set_default_num_warps(num_warps):
_cached_make_default_config.num_warps = num_warps


def get_default_num_stages():
return _cached_make_default_config.num_stages


def set_default_num_stages(num_stages):
_cached_make_default_config.num_stages = num_stages


def get_default_max_num_configs():
return _cached_make_default_config.max_num_configs


def set_default_max_num_configs(max_num_configs):
_cached_make_default_config.max_num_configs = max_num_configs


@functools.cache
def _cached_make(
premake, *args, num_warps=None, num_stages=None, max_num_configs=None, **keywords
):
if num_warps is None:
num_warps = _cached_make_default_config.num_warps

if num_stages is None:
num_stages = _cached_make_default_config.num_stages

if max_num_configs is None:
max_num_configs = _cached_make_default_config.max_num_configs

return ninetoothed.make(
*premake(*args, **keywords),
num_warps=num_warps,
Expand Down
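Note: the hunk above adds module-level defaults (num_warps, num_stages, max_num_configs) that `_cached_make` falls back to when a caller does not pass them explicitly. A minimal usage sketch follows, assuming the setters are imported from `ntops.torch.utils` as shown in the diff; the final op call is a hypothetical stand-in for any public `ntops.torch` operator built through `_cached_make`, and the specific values are illustrative only. Since `_cached_make` is wrapped in `functools.cache`, the defaults should be set before the first kernel is built.

```python
# Sketch only: configuring the new global defaults before any kernels are built.
# Setter names come from the diff above; the abs() call is a hypothetical example.
import torch

import ntops.torch
import ntops.torch.utils as nt_utils

nt_utils.set_default_num_warps(4)        # used when num_warps is not passed explicitly
nt_utils.set_default_num_stages(3)       # used when num_stages is not passed explicitly
nt_utils.set_default_max_num_configs(8)  # caps the number of autotuning configurations

x = torch.randn(1024, device="cuda")
y = ntops.torch.abs(x)  # ops built via _cached_make now pick up the defaults above
```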
7 changes: 7 additions & 0 deletions tests/conftest.py
@@ -4,11 +4,15 @@
import pytest
import torch

import ntops.torch.utils


def pytest_configure():
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False

ntops.torch.utils.set_default_max_num_configs(_DEFAULT_MAX_NUM_CONFIGS)


def pytest_collectstart(collector):
if isinstance(collector, pytest.Module):
@@ -25,6 +29,9 @@ def set_seed_per_test(request):
_set_random_seed(_hash(_test_case_path_from_request(request)))


_DEFAULT_MAX_NUM_CONFIGS = 3


def _set_random_seed(seed):
random.seed(seed)
torch.manual_seed(seed)
Expand Down
6 changes: 2 additions & 4 deletions tests/test_abs.py
@@ -8,12 +8,10 @@

@skip_if_cuda_not_available
@pytest.mark.parametrize(*generate_arguments())
def test_cuda(shape, dtype, atol, rtol):
device = "cuda"

def test_abs(shape, dtype, device, rtol, atol):
input = torch.randn(shape, dtype=dtype, device=device)

ninetoothed_output = ntops.torch.abs(input)
reference_output = torch.abs(input)

assert torch.allclose(ninetoothed_output, reference_output, atol=atol, rtol=rtol)
assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol)
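The refactored tests take `device` as a parameter rather than hard-coding `"cuda"`. How `device` is supplied is not shown in this diff; it may come from `generate_arguments()` or from a shared fixture. Purely as an assumption, a minimal fixture-based sketch is below.

```python
# Hypothetical conftest.py fixture -- not part of this diff.
import pytest
import torch


@pytest.fixture(params=["cuda"])
def device(request):
    # Skip device-parametrized tests on machines without a CUDA device.
    if request.param == "cuda" and not torch.cuda.is_available():
        pytest.skip("CUDA is not available")
    return request.param
```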
6 changes: 2 additions & 4 deletions tests/test_add.py
@@ -8,14 +8,12 @@

@skip_if_cuda_not_available
@pytest.mark.parametrize(*generate_arguments())
def test_cuda(shape, dtype, atol, rtol):
device = "cuda"

def test_add(shape, dtype, device, rtol, atol):
input = torch.randn(shape, dtype=dtype, device=device)
other = torch.randn(shape, dtype=dtype, device=device)
alpha = gauss()

ninetoothed_output = ntops.torch.add(input, other, alpha=alpha)
reference_output = torch.add(input, other, alpha=alpha)

assert torch.allclose(ninetoothed_output, reference_output, atol=atol, rtol=rtol)
assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol)
6 changes: 2 additions & 4 deletions tests/test_addmm.py
@@ -9,9 +9,7 @@

@skip_if_cuda_not_available
@pytest.mark.parametrize(*generate_arguments())
def test_cuda(m, n, k, dtype, atol, rtol):
device = "cuda"

def test_addmm(m, n, k, dtype, device, rtol, atol):
input = torch.randn((m, n), dtype=dtype, device=device)
x = torch.randn((m, k), dtype=dtype, device=device)
y = torch.randn((k, n), dtype=dtype, device=device)
@@ -21,4 +19,4 @@ def test_cuda(m, n, k, dtype, atol, rtol):
ninetoothed_output = ntops.torch.addmm(input, x, y, beta=beta, alpha=alpha)
reference_output = torch.addmm(input, x, y, beta=beta, alpha=alpha)

assert torch.allclose(ninetoothed_output, reference_output, atol=atol, rtol=rtol)
assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol)
4 changes: 1 addition & 3 deletions tests/test_bitwise_and.py
@@ -8,9 +8,7 @@

@skip_if_cuda_not_available
@pytest.mark.parametrize(*generate_arguments(False))
def test_cuda(shape, dtype, atol, rtol):
device = "cuda"

def test_bitwise_and(shape, dtype, device, rtol, atol):
if dtype == torch.bool:
prob = 0.5
input = torch.rand(shape, dtype=torch.float32, device=device) > prob
Expand Down
4 changes: 1 addition & 3 deletions tests/test_bitwise_not.py
@@ -8,9 +8,7 @@

@skip_if_cuda_not_available
@pytest.mark.parametrize(*generate_arguments(False))
def test_cuda(shape, dtype, atol, rtol):
device = "cuda"

def test_bitwise_not(shape, dtype, device, rtol, atol):
if dtype == torch.bool:
prob = 0.5
input = torch.rand(shape, dtype=torch.float32, device=device) > prob
Expand Down
4 changes: 1 addition & 3 deletions tests/test_bitwise_or.py
@@ -8,9 +8,7 @@

@skip_if_cuda_not_available
@pytest.mark.parametrize(*generate_arguments(False))
def test_cuda(shape, dtype, atol, rtol):
device = "cuda"

def test_bitwise_or(shape, dtype, device, rtol, atol):
if dtype == torch.bool:
prob = 0.5
input = torch.rand(shape, dtype=torch.float32, device=device) > prob
Expand Down
6 changes: 2 additions & 4 deletions tests/test_bmm.py
@@ -10,14 +10,12 @@

@skip_if_cuda_not_available
@pytest.mark.parametrize(*generate_arguments())
def test_cuda(m, n, k, dtype, atol, rtol):
device = "cuda"

def test_bmm(m, n, k, dtype, device, rtol, atol):
b = random.randint(4, 16)
input = torch.randn((b, m, k), dtype=dtype, device=device)
other = torch.randn((b, k, n), dtype=dtype, device=device)

ninetoothed_output = ntops.torch.bmm(input, other)
reference_output = torch.bmm(input, other)

assert torch.allclose(ninetoothed_output, reference_output, atol=atol, rtol=rtol)
assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol)
6 changes: 2 additions & 4 deletions tests/test_clamp.py
@@ -8,14 +8,12 @@

@skip_if_cuda_not_available
@pytest.mark.parametrize(*generate_arguments())
def test_cuda(shape, dtype, atol, rtol):
device = "cuda"

def test_clamp(shape, dtype, device, rtol, atol):
input = torch.randn(shape, dtype=dtype, device=device)
min = torch.randn(shape, dtype=dtype, device=device)
max = torch.randn(shape, dtype=dtype, device=device)

ninetoothed_output = ntops.torch.clamp(input, min, max)
reference_output = torch.clamp(input, min, max)

assert torch.allclose(ninetoothed_output, reference_output, atol=atol, rtol=rtol)
assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol)
7 changes: 2 additions & 5 deletions tests/test_cos.py
@@ -8,16 +8,13 @@

@skip_if_cuda_not_available
@pytest.mark.parametrize(*generate_arguments())
def test_cuda(shape, dtype, atol, rtol):
def test_cos(shape, dtype, device, rtol, atol):
# TODO: Test for `float16` later.
if dtype is torch.float16:
return

device = "cuda"

input = torch.randn(shape, dtype=dtype, device=device)

ninetoothed_output = ntops.torch.cos(input)
reference_output = torch.cos(input)

assert torch.allclose(ninetoothed_output, reference_output, atol=atol, rtol=rtol)
assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol)
27 changes: 14 additions & 13 deletions tests/test_div.py
@@ -7,21 +7,22 @@


@skip_if_cuda_not_available
@pytest.mark.parametrize(
"rounding_mode",
[
None,
pytest.param(
"trunc", marks=pytest.mark.skip(reason="TODO: Test for `trunc` mode later.")
),
"floor",
],
)
@pytest.mark.parametrize(*generate_arguments())
def test_cuda(shape, dtype, atol, rtol):
device = "cuda"

def test_div(shape, rounding_mode, dtype, device, rtol, atol):
input = torch.randn(shape, dtype=dtype, device=device)
other = torch.randn(shape, dtype=dtype, device=device)

for rounding_mode in (None, "trunc", "floor"):
# TODO: Test for `trunc` mode later.
if rounding_mode == "trunc":
continue

ninetoothed_output = ntops.torch.div(input, other, rounding_mode=rounding_mode)
reference_output = torch.div(input, other, rounding_mode=rounding_mode)
ninetoothed_output = ntops.torch.div(input, other, rounding_mode=rounding_mode)
reference_output = torch.div(input, other, rounding_mode=rounding_mode)

assert torch.allclose(
ninetoothed_output, reference_output, atol=atol, rtol=rtol
)
assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol)
4 changes: 1 addition & 3 deletions tests/test_dropout.py
@@ -11,9 +11,7 @@

@skip_if_cuda_not_available
@pytest.mark.parametrize(*generate_arguments())
def test_cuda(shape, dtype, atol, rtol):
device = "cuda"

def test_dropout(shape, dtype, device, rtol, atol):
input = torch.randn(shape, dtype=dtype, device=device)
p = random.uniform(0, 1)

Expand Down
4 changes: 1 addition & 3 deletions tests/test_eq.py
@@ -8,9 +8,7 @@

@skip_if_cuda_not_available
@pytest.mark.parametrize(*generate_arguments())
def test_cuda(shape, dtype, atol, rtol):
device = "cuda"

def test_eq(shape, dtype, device, rtol, atol):
input = torch.randn(shape, dtype=dtype, device=device)
other = torch.randn(shape, dtype=dtype, device=device)

Expand Down
7 changes: 2 additions & 5 deletions tests/test_exp.py
@@ -8,16 +8,13 @@

@skip_if_cuda_not_available
@pytest.mark.parametrize(*generate_arguments())
def test_cuda(shape, dtype, atol, rtol):
def test_exp(shape, dtype, device, rtol, atol):
# TODO: Test for `float16` later.
if dtype is torch.float16:
return

device = "cuda"

input = torch.randn(shape, dtype=dtype, device=device)

ninetoothed_output = ntops.torch.exp(input)
reference_output = torch.exp(input)

assert torch.allclose(ninetoothed_output, reference_output, atol=atol, rtol=rtol)
assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol)
4 changes: 1 addition & 3 deletions tests/test_ge.py
@@ -8,9 +8,7 @@

@skip_if_cuda_not_available
@pytest.mark.parametrize(*generate_arguments())
def test_cuda(shape, dtype, atol, rtol):
device = "cuda"

def test_ge(shape, dtype, device, rtol, atol):
input = torch.randn(shape, dtype=dtype, device=device)
other = torch.randn(shape, dtype=dtype, device=device)

Expand Down
22 changes: 13 additions & 9 deletions tests/test_gelu.py
@@ -8,16 +8,20 @@


@skip_if_cuda_not_available
@pytest.mark.parametrize(
"approximate",
(
"none",
pytest.param(
"tanh", marks=pytest.mark.skip(reason="TODO: Test for `tanh` mode later.")
),
),
)
@pytest.mark.parametrize(*generate_arguments())
def test_cuda(shape, dtype, atol, rtol):
device = "cuda"

def test_gelu(shape, approximate, dtype, device, rtol, atol):
input = torch.randn(shape, dtype=dtype, device=device)

for approximate in ("none", "tanh"):
ninetoothed_output = ntops.torch.gelu(input)
reference_output = F.gelu(input)
ninetoothed_output = ntops.torch.gelu(input, approximate=approximate)
reference_output = F.gelu(input, approximate=approximate)

assert torch.allclose(
ninetoothed_output, reference_output, atol=atol, rtol=rtol
)
assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol)
4 changes: 1 addition & 3 deletions tests/test_gt.py
@@ -8,9 +8,7 @@

@skip_if_cuda_not_available
@pytest.mark.parametrize(*generate_arguments())
def test_cuda(shape, dtype, atol, rtol):
device = "cuda"

def test_gt(shape, dtype, device, rtol, atol):
input = torch.randn(shape, dtype=dtype, device=device)
other = torch.randn(shape, dtype=dtype, device=device)

Expand Down
4 changes: 1 addition & 3 deletions tests/test_isinf.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,7 @@

@skip_if_cuda_not_available
@pytest.mark.parametrize(*generate_arguments())
def test_cuda(shape, dtype, atol, rtol):
device = "cuda"

def test_isinf(shape, dtype, device, rtol, atol):
def generate_inf_tensor(shape, dtype, device):
x = torch.randn(shape, dtype=dtype, device=device)

Expand Down
4 changes: 1 addition & 3 deletions tests/test_isnan.py
@@ -8,9 +8,7 @@

@skip_if_cuda_not_available
@pytest.mark.parametrize(*generate_arguments())
def test_cuda(shape, dtype, atol, rtol):
device = "cuda"

def test_isnan(shape, dtype, device, rtol, atol):
def generate_nan_tensor(shape, dtype, device):
nan_prob = 0.4
prob_tensor = torch.rand(shape, device=device)
Expand Down
8 changes: 4 additions & 4 deletions tests/test_layer_norm.py
@@ -13,9 +13,9 @@
@pytest.mark.parametrize("bias_is_none", (False, True))
@pytest.mark.parametrize("weight_is_none", (False, True))
@pytest.mark.parametrize(*generate_arguments())
def test_cuda(shape, dtype, atol, rtol, weight_is_none, bias_is_none, eps):
device = "cuda"

def test_layer_norm(
shape, dtype, device, rtol, atol, weight_is_none, bias_is_none, eps
):
input = torch.randn(shape, dtype=dtype, device=device)
normalized_shape = shape[-random.randint(1, len(shape)) :]
if weight_is_none:
@@ -34,4 +34,4 @@ def test_cuda(shape, dtype, atol, rtol, weight_is_none, bias_is_none, eps):
input, normalized_shape, weight=weight, bias=bias, eps=eps
)

assert torch.allclose(ninetoothed_output, reference_output, atol=atol, rtol=rtol)
assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol)
4 changes: 1 addition & 3 deletions tests/test_le.py
@@ -8,9 +8,7 @@

@skip_if_cuda_not_available
@pytest.mark.parametrize(*generate_arguments())
def test_cuda(shape, dtype, atol, rtol):
device = "cuda"

def test_le(shape, dtype, device, rtol, atol):
input = torch.randn(shape, dtype=dtype, device=device)
other = torch.randn(shape, dtype=dtype, device=device)

Expand Down