592 changes: 0 additions & 592 deletions src/ntops/torch.py

This file was deleted.

77 changes: 77 additions & 0 deletions src/ntops/torch/__init__.py
@@ -0,0 +1,77 @@
from ntops.torch.abs import abs
from ntops.torch.add import add
from ntops.torch.addmm import addmm
from ntops.torch.bitwise_and import bitwise_and
from ntops.torch.bitwise_not import bitwise_not
from ntops.torch.bitwise_or import bitwise_or
from ntops.torch.bmm import bmm
from ntops.torch.clamp import clamp
from ntops.torch.cos import cos
from ntops.torch.div import div
from ntops.torch.dropout import dropout
from ntops.torch.eq import eq
from ntops.torch.exp import exp
from ntops.torch.ge import ge
from ntops.torch.gelu import gelu
from ntops.torch.gt import gt
from ntops.torch.isinf import isinf
from ntops.torch.isnan import isnan
from ntops.torch.layer_norm import layer_norm
from ntops.torch.le import le
from ntops.torch.lt import lt
from ntops.torch.mm import mm
from ntops.torch.mul import mul
from ntops.torch.ne import ne
from ntops.torch.neg import neg
from ntops.torch.pow import pow
from ntops.torch.relu import relu
from ntops.torch.rms_norm import rms_norm
from ntops.torch.rotary_position_embedding import rotary_position_embedding
from ntops.torch.rsqrt import rsqrt
from ntops.torch.scaled_dot_product_attention import scaled_dot_product_attention
from ntops.torch.sigmoid import sigmoid
from ntops.torch.silu import silu
from ntops.torch.sin import sin
from ntops.torch.softmax import softmax
from ntops.torch.sub import sub
from ntops.torch.tanh import tanh

__all__ = [
"abs",
"add",
"addmm",
"bitwise_and",
"bitwise_not",
"bitwise_or",
"bmm",
"clamp",
"cos",
"div",
"dropout",
"eq",
"exp",
"ge",
"gelu",
"gt",
"isinf",
"isnan",
"layer_norm",
"le",
"lt",
"mm",
"mul",
"ne",
"neg",
"pow",
"relu",
"rms_norm",
"rotary_position_embedding",
"rsqrt",
"scaled_dot_product_attention",
"sigmoid",
"silu",
"sin",
"softmax",
"sub",
"tanh",
]
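
Based on the re-exports above, the package looks intended as a drop-in, torch-like namespace. A minimal usage sketch (the tensor shapes, the CUDA device, and the alias name are illustrative assumptions, not part of this PR):

import torch

import ntops.torch as ntops_torch

a = torch.randn(4, 8, device="cuda")  # assumes a CUDA device is available
b = torch.randn(4, 8, device="cuda")

# Same keyword-style signature as torch.add(input, other, alpha=..., out=...).
c = ntops_torch.add(a, b, alpha=2)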
15 changes: 15 additions & 0 deletions src/ntops/torch/abs.py
@@ -0,0 +1,15 @@
import torch

import ntops
from ntops.torch.utils import _cached_make


def abs(input, *, out=None):
if out is None:
out = torch.empty_like(input)

kernel = _cached_make(ntops.kernels.abs.premake, input.ndim)

kernel(input, out)

return out
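
Every wrapper in this PR obtains its kernel through _cached_make from ntops.torch.utils, which is not shown in this section. A plausible sketch of that helper, assuming it simply memoizes premake(*args) per argument tuple (the caching strategy here is an assumption, not the PR's actual implementation):

import functools


@functools.lru_cache(maxsize=None)
def _cached_make(premake, *args):
    # Build (and compile) the kernel once per unique (premake, args)
    # combination; later calls with the same arguments reuse it.
    return premake(*args)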
15 changes: 15 additions & 0 deletions src/ntops/torch/add.py
@@ -0,0 +1,15 @@
import torch

import ntops
from ntops.torch.utils import _cached_make


def add(input, other, *, alpha=1, out=None):
if out is None:
out = torch.empty_like(input)

kernel = _cached_make(ntops.kernels.add.premake, input.ndim)

kernel(input, other, alpha, out)

return out
18 changes: 18 additions & 0 deletions src/ntops/torch/addmm.py
@@ -0,0 +1,18 @@
import torch

import ntops
from ntops.torch.utils import _cached_make, _get_matmul_input_precision


def addmm(input, mat1, mat2, *, beta=1, alpha=1, out=None):
m, _ = mat1.shape
_, n = mat2.shape

if out is None:
out = torch.empty((m, n), dtype=input.dtype, device=input.device)

kernel = _cached_make(ntops.kernels.addmm.premake)

kernel(input, mat1, mat2, beta, alpha, out, _get_matmul_input_precision())

return out
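
The matmul-style wrappers (addmm here, and bmm and mm below) additionally pass _get_matmul_input_precision() to their kernels. That helper is also not part of this diff section; one plausible reading is that it maps PyTorch's global float32 matmul precision setting to a kernel-level precision flag (the mapping and return values below are guesses, not the PR's code):

import torch


def _get_matmul_input_precision():
    # torch.get_float32_matmul_precision() returns "highest", "high", or "medium".
    # "highest" keeps full float32 inputs; the lower settings allow TF32-style inputs.
    if torch.get_float32_matmul_precision() == "highest":
        return "ieee"

    return "tf32"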
15 changes: 15 additions & 0 deletions src/ntops/torch/bitwise_and.py
@@ -0,0 +1,15 @@
import torch

import ntops
from ntops.torch.utils import _cached_make


def bitwise_and(input, other, *, out=None):
if out is None:
out = torch.empty_like(input)

kernel = _cached_make(ntops.kernels.bitwise_and.premake, input.ndim)

kernel(input, other, out)

return out
17 changes: 17 additions & 0 deletions src/ntops/torch/bitwise_not.py
@@ -0,0 +1,17 @@
import torch

import ntops
from ntops.torch.utils import _cached_make


def bitwise_not(input, *, out=None):
if out is None:
out = torch.empty_like(input)

kernel = _cached_make(
ntops.kernels.bitwise_not.premake, input.ndim, input.dtype == torch.bool
)

kernel(input, out)

return out
15 changes: 15 additions & 0 deletions src/ntops/torch/bitwise_or.py
@@ -0,0 +1,15 @@
import torch

import ntops
from ntops.torch.utils import _cached_make


def bitwise_or(input, other, *, out=None):
if out is None:
out = torch.empty_like(input)

kernel = _cached_make(ntops.kernels.bitwise_or.premake, input.ndim)

kernel(input, other, out)

return out
18 changes: 18 additions & 0 deletions src/ntops/torch/bmm.py
@@ -0,0 +1,18 @@
import torch

import ntops
from ntops.torch.utils import _cached_make, _get_matmul_input_precision


def bmm(input, mat2, *, out=None):
b, m, _ = input.shape
_, _, n = mat2.shape

if out is None:
out = torch.empty((b, m, n), dtype=input.dtype, device=input.device)

kernel = _cached_make(ntops.kernels.bmm.premake)

kernel(input, mat2, out, _get_matmul_input_precision())

return out
15 changes: 15 additions & 0 deletions src/ntops/torch/clamp.py
@@ -0,0 +1,15 @@
import torch

import ntops
from ntops.torch.utils import _cached_make


def clamp(input, min=None, max=None, *, out=None):
if out is None:
out = torch.empty_like(input)

kernel = _cached_make(ntops.kernels.clamp.premake, input.ndim)

kernel(input, min, max, out)

return out
15 changes: 15 additions & 0 deletions src/ntops/torch/cos.py
@@ -0,0 +1,15 @@
import torch

import ntops
from ntops.torch.utils import _cached_make


def cos(input, *, out=None):
if out is None:
out = torch.empty_like(input)

kernel = _cached_make(ntops.kernels.cos.premake, input.ndim)

kernel(input, out)

return out
15 changes: 15 additions & 0 deletions src/ntops/torch/div.py
@@ -0,0 +1,15 @@
import torch

import ntops
from ntops.torch.utils import _cached_make


def div(input, other, *, rounding_mode=None, out=None):
if out is None:
out = torch.empty_like(input)

kernel = _cached_make(ntops.kernels.div.premake, input.ndim, rounding_mode)

kernel(input, other, out)

return out
27 changes: 27 additions & 0 deletions src/ntops/torch/dropout.py
@@ -0,0 +1,27 @@
import random

import torch

import ntops
from ntops.torch.utils import _cached_make


def dropout(input, p=0.5, training=True, inplace=False):
if not training or p == 0:
if inplace:
return input
else:
return input.clone()

seed = random.randrange(0, 2**31)

if inplace:
output = input
else:
output = torch.empty_like(input)

kernel = _cached_make(ntops.kernels.dropout.premake, input.ndim)

kernel(input, p, seed, output)

return output
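
Note that dropout draws its seed from Python's random module rather than torch's RNG, so reproducing a mask would presumably require seeding random rather than calling torch.manual_seed. A usage sketch under that assumption (device and shapes are illustrative):

import random

import torch

import ntops.torch as ntops_torch

random.seed(0)  # controls the seed drawn inside ntops_torch.dropout
x = torch.randn(2, 3, device="cuda")
y = ntops_torch.dropout(x, p=0.5, training=True)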
15 changes: 15 additions & 0 deletions src/ntops/torch/eq.py
@@ -0,0 +1,15 @@
import torch

import ntops
from ntops.torch.utils import _cached_make


def eq(input, other, *, out=None):
if out is None:
out = torch.empty_like(input)

kernel = _cached_make(ntops.kernels.eq.premake, input.ndim)

kernel(input, other, out)

return out
15 changes: 15 additions & 0 deletions src/ntops/torch/exp.py
@@ -0,0 +1,15 @@
import torch

import ntops
from ntops.torch.utils import _cached_make


def exp(input, *, out=None):
if out is None:
out = torch.empty_like(input)

kernel = _cached_make(ntops.kernels.exp.premake, input.ndim)

kernel(input, out)

return out
15 changes: 15 additions & 0 deletions src/ntops/torch/ge.py
@@ -0,0 +1,15 @@
import torch

import ntops
from ntops.torch.utils import _cached_make


def ge(input, other, *, out=None):
if out is None:
out = torch.empty_like(input)

kernel = _cached_make(ntops.kernels.ge.premake, input.ndim)

kernel(input, other, out)

return out
14 changes: 14 additions & 0 deletions src/ntops/torch/gelu.py
@@ -0,0 +1,14 @@
import torch

import ntops
from ntops.torch.utils import _cached_make


def gelu(input, approximate="none"):
output = torch.empty_like(input)

kernel = _cached_make(ntops.kernels.gelu.premake, input.ndim, approximate)

kernel(input, output)

return output
15 changes: 15 additions & 0 deletions src/ntops/torch/gt.py
@@ -0,0 +1,15 @@
import torch

import ntops
from ntops.torch.utils import _cached_make


def gt(input, other, *, out=None):
if out is None:
out = torch.empty_like(input)

kernel = _cached_make(ntops.kernels.gt.premake, input.ndim)

kernel(input, other, out)

return out
14 changes: 14 additions & 0 deletions src/ntops/torch/isinf.py
@@ -0,0 +1,14 @@
import torch

import ntops
from ntops.torch.utils import _cached_make


def isinf(input):
output = torch.empty_like(input)

kernel = _cached_make(ntops.kernels.isinf.premake, input.ndim)

kernel(input, output)

return output
14 changes: 14 additions & 0 deletions src/ntops/torch/isnan.py
@@ -0,0 +1,14 @@
import torch

import ntops
from ntops.torch.utils import _cached_make


def isnan(input):
output = torch.empty_like(input)

kernel = _cached_make(ntops.kernels.isnan.premake, input.ndim)

kernel(input, output)

return output