Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 17 additions & 10 deletions pytensor/link/numba/dispatch/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def _filter_numba_warnings():


def numba_njit(
*args, fastmath=None, final_function: bool = False, **kwargs
*args, fastmath=None, inline=None, final_function: bool = False, **kwargs
) -> Callable:
"""A thin wrapper around `numba.njit`.

Expand Down Expand Up @@ -88,6 +88,9 @@ def numba_njit(
kwargs.setdefault("no_cpython_wrapper", True)
kwargs.setdefault("no_cfunc_wrapper", True)

if inline is not None:
kwargs["inline"] = inline

if len(args) > 0 and callable(args[0]):
return _njit(*args[1:], fastmath=fastmath, **kwargs)(args[0]) # type: ignore
else:
Expand Down Expand Up @@ -448,15 +451,19 @@ def numba_funcify_ensure_cache(op, *args, **kwargs) -> tuple[Callable, str | Non
if config.numba__cache and config.compiler_verbose:
print(f"{op} of type {type(op)} will not be cached by PyTensor.\n") # noqa: T201
return jitable_func, None
else:
op_name = jitable_func.__name__
cached_func = compile_numba_function_src(
src=f"def {op_name}(*args): return jitable_func(*args)",
function_name=op_name,
global_env=globals() | {"jitable_func": jitable_func},
cache_key=f"{cache_key}_fastmath{int(config.numba__fastmath)}",
)
return numba_njit(cached_func, cache=True), cache_key

# Inline functions get baked into the caller's cache entry and can't be independently cached
if getattr(jitable_func, "targetoptions", {}).get("inline") == "always":
return jitable_func, cache_key

op_name = jitable_func.__name__
cached_func = compile_numba_function_src(
src=f"def {op_name}(*args): return jitable_func(*args)",
function_name=op_name,
global_env=globals() | {"jitable_func": jitable_func},
cache_key=f"{cache_key}_fastmath{int(config.numba__fastmath)}",
)
return numba_njit(cached_func, cache=True), cache_key


def cache_key_for_constant(data):
Expand Down
6 changes: 3 additions & 3 deletions pytensor/link/numba/dispatch/elemwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,7 +456,7 @@ def numba_funcify_DimShuffle(op: DimShuffle, node, **kwargs):
if new_order == ():
# Special case needed because of https://github.com/numba/numba/issues/9933

@numba_basic.numba_njit
@numba_basic.numba_njit(inline="always")
def squeeze_to_0d(x):
if not x.size == 1:
raise ValueError(
Expand All @@ -473,13 +473,13 @@ def squeeze_to_0d(x):
new_shape = shape_template
new_strides = strides_template

@numba_basic.numba_njit
@numba_basic.numba_njit(inline="always")
def dimshuffle(x):
return as_strided(np.asarray(x), shape=new_shape, strides=new_strides)

else:

@numba_basic.numba_njit
@numba_basic.numba_njit(inline="always")
def dimshuffle(x):
old_shape = x.shape
old_strides = x.strides
Expand Down
46 changes: 23 additions & 23 deletions pytensor/link/numba/dispatch/scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,12 +136,12 @@ def {scalar_op_fn_name}({input_signature}):

# Functions that call a function pointer can't be cached
cache_key = None if cython_func else scalar_op_cache_key(op)
return numba_basic.numba_njit(scalar_op_fn), cache_key
return numba_basic.numba_njit(scalar_op_fn, inline="always"), cache_key


@register_funcify_and_cache_key(Switch)
def numba_funcify_Switch(op, node, **kwargs):
@numba_basic.numba_njit
@numba_basic.numba_njit(inline="always")
def switch(condition, x, y):
if condition:
return x
Expand Down Expand Up @@ -174,34 +174,34 @@ def numba_funcify_Pow(op, node, **kwargs):
def pow(x, y):
return x**y

# Numba power fails when exponents are discrete integers and fastmath=True
# https://github.com/numba/numba/issues/9554
fastmath = False if np.dtype(pow_dtype).kind in "ibu" else None

return numba_basic.numba_njit(pow, fastmath=fastmath), scalar_op_cache_key(
op, cache_version=1
)
# Numba's power fails for integer exponents when fastmath or inline is enabled (numba#9554)
integer_exp = np.dtype(pow_dtype).kind in "ibu"
return numba_basic.numba_njit(
pow,
fastmath=False if integer_exp else None,
inline=None if integer_exp else "always",
), scalar_op_cache_key(op, cache_version=1)


@register_funcify_and_cache_key(Add)
def numba_funcify_Add(op, node, **kwargs):
nary_add_fn = binary_to_nary_func(node.inputs, "add", "+")

return numba_basic.numba_njit(nary_add_fn), scalar_op_cache_key(op)
return numba_basic.numba_njit(nary_add_fn, inline="always"), scalar_op_cache_key(op)


@register_funcify_and_cache_key(Mul)
def numba_funcify_Mul(op, node, **kwargs):
nary_mul_fn = binary_to_nary_func(node.inputs, "mul", "*")

return numba_basic.numba_njit(nary_mul_fn), scalar_op_cache_key(op)
return numba_basic.numba_njit(nary_mul_fn, inline="always"), scalar_op_cache_key(op)


@register_funcify_and_cache_key(Cast)
def numba_funcify_Cast(op, node, **kwargs):
dtype = np.dtype(op.o_type.dtype)

@numba_basic.numba_njit
@numba_basic.numba_njit(inline="always")
def cast(x):
return numba_basic.direct_cast(x, dtype)

Expand All @@ -210,7 +210,7 @@ def cast(x):

@register_funcify_and_cache_key(Identity)
def numba_funcify_type_casting(op, **kwargs):
@numba_basic.numba_njit
@numba_basic.numba_njit(inline="always")
def identity(x):
return x

Expand All @@ -219,7 +219,7 @@ def identity(x):

@register_funcify_and_cache_key(Clip)
def numba_funcify_Clip(op, **kwargs):
@numba_basic.numba_njit
@numba_basic.numba_njit(inline="always")
def clip(x, min_val, max_val):
if x < min_val:
return min_val
Expand Down Expand Up @@ -247,7 +247,7 @@ def numba_funcify_Composite(op, node, **kwargs):

@register_funcify_and_cache_key(Second)
def numba_funcify_Second(op, node, **kwargs):
@numba_basic.numba_njit
@numba_basic.numba_njit(inline="always")
def second(x, y):
return y

Expand All @@ -256,7 +256,7 @@ def second(x, y):

@register_funcify_and_cache_key(Reciprocal)
def numba_funcify_Reciprocal(op, node, **kwargs):
@numba_basic.numba_njit
@numba_basic.numba_njit(inline="always")
def reciprocal(x):
# This is how the C-backend implementation works
return np.divide(np.float32(1.0), x)
Expand All @@ -275,15 +275,15 @@ def numba_funcify_Sigmoid(op, node, **kwargs):
"uint64": np.float64,
}[inp_dtype]

@numba_basic.numba_njit
@numba_basic.numba_njit(inline="always")
def sigmoid(x):
# Can't negate uint
float_x = numba_basic.direct_cast(x, upcast_uint_dtype)
return 1 / (1 + np.exp(-float_x))

else:

@numba_basic.numba_njit
@numba_basic.numba_njit(inline="always")
def sigmoid(x):
return 1 / (1 + np.exp(-x))

Expand All @@ -292,7 +292,7 @@ def sigmoid(x):

@register_funcify_and_cache_key(GammaLn)
def numba_funcify_GammaLn(op, node, **kwargs):
@numba_basic.numba_njit
@numba_basic.numba_njit(inline="always")
def gammaln(x):
return math.lgamma(x)

Expand All @@ -301,7 +301,7 @@ def gammaln(x):

@register_funcify_and_cache_key(Log1mexp)
def numba_funcify_Log1mexp(op, node, **kwargs):
@numba_basic.numba_njit
@numba_basic.numba_njit(inline="always")
def logp1mexp(x):
if x < np.log(0.5):
return np.log1p(-np.exp(x))
Expand All @@ -317,7 +317,7 @@ def numba_funcify_Erf(op, node, **kwargs):
# Complex not supported by numba
return numba_funcify_ScalarOp(op, node=node, **kwargs)

@numba_basic.numba_njit
@numba_basic.numba_njit(inline="always")
def erf(x):
return math.erf(x)

Expand All @@ -326,7 +326,7 @@ def erf(x):

@register_funcify_and_cache_key(Erfc)
def numba_funcify_Erfc(op, **kwargs):
@numba_basic.numba_njit
@numba_basic.numba_njit(inline="always")
def erfc(x):
return math.erfc(x)

Expand All @@ -347,7 +347,7 @@ def numba_funcify_Softplus(op, node, **kwargs):
upcast_uint_dtype = None
out_dtype = np.dtype(node.outputs[0].type.dtype)

@numba_basic.numba_njit
@numba_basic.numba_njit(inline="always")
def softplus(x):
if x < -37.0:
value = np.exp(x)
Expand Down
12 changes: 6 additions & 6 deletions pytensor/link/numba/dispatch/shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

@register_funcify_default_op_cache_key(Shape)
def numba_funcify_Shape(op, **kwargs):
@numba_basic.numba_njit
@numba_basic.numba_njit(inline="always")
def shape(x):
return np.asarray(np.shape(x))

Expand All @@ -23,7 +23,7 @@ def shape(x):
def numba_funcify_Shape_i(op, **kwargs):
i = op.i

@numba_basic.numba_njit
@numba_basic.numba_njit(inline="always")
def shape_i(x):
return np.asarray(np.shape(x)[i])

Expand All @@ -36,7 +36,7 @@ def numba_funcify_SpecifyShape(op, node, **kwargs):
shape_input_names = ["shape_" + str(i) for i in range(len(shape_inputs))]

func_conditions = [
f"assert x.shape[{i}] == {eval_dim_name}, f'SpecifyShape: dim {{{i}}} of input has shape {{x.shape[{i}]}}, expected {{{eval_dim_name}.item()}}.'"
f"assert x.shape[{i}] == {eval_dim_name}, 'SpecifyShape: shape mismatch in dim {i}'"
for i, (node_dim_input, eval_dim_name) in enumerate(
zip(shape_inputs, shape_input_names, strict=True)
)
Expand All @@ -52,7 +52,7 @@ def specify_shape(x, {", ".join(shape_input_names)}):
)

specify_shape = compile_function_src(func, "specify_shape", globals())
return numba_basic.numba_njit(specify_shape)
return numba_basic.numba_njit(specify_shape, inline="always")


@register_funcify_default_op_cache_key(Reshape)
Expand All @@ -61,13 +61,13 @@ def numba_funcify_Reshape(op, **kwargs):

if ndim == 0:

@numba_basic.numba_njit
@numba_basic.numba_njit(inline="always")
def reshape(x, shape):
return np.asarray(x.item())

else:

@numba_basic.numba_njit
@numba_basic.numba_njit(inline="always")
def reshape(x, shape):
# TODO: Use this until https://github.com/numba/numba/issues/7353 is closed.
return np.reshape(
Expand Down
10 changes: 6 additions & 4 deletions pytensor/link/numba/dispatch/tensor_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def allocempty({", ".join(shape_var_names)}):
alloc_def_src, "allocempty", globals() | {"np": np, "dtype": np.dtype(op.dtype)}
)

return numba_basic.numba_njit(alloc_fn)
return numba_basic.numba_njit(alloc_fn, inline="always")


@register_funcify_and_cache_key(Alloc)
Expand Down Expand Up @@ -221,12 +221,14 @@ def makevector({", ".join(input_names)}):
globals() | {"np": np, "dtype": dtype},
)

return numba_basic.numba_njit(makevector_fn)
# Numba can't inline closures with more than 30 arguments
inline = "always" if len(input_names) <= 30 else None
return numba_basic.numba_njit(makevector_fn, inline=inline)


@register_funcify_default_op_cache_key(TensorFromScalar)
def numba_funcify_TensorFromScalar(op, **kwargs):
@numba_basic.numba_njit
@numba_basic.numba_njit(inline="always")
def tensor_from_scalar(x):
return np.array(x)

Expand All @@ -235,7 +237,7 @@ def tensor_from_scalar(x):

@register_funcify_default_op_cache_key(ScalarFromTensor)
def numba_funcify_ScalarFromTensor(op, **kwargs):
@numba_basic.numba_njit
@numba_basic.numba_njit(inline="always")
def scalar_from_tensor(x):
return x.item()

Expand Down
4 changes: 3 additions & 1 deletion pytensor/link/numba/dispatch/vectorize_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,9 @@ def store_core_outputs({inp_signature}, {out_signature}):
"store_core_outputs",
{**globals(), **global_env},
)
return numba_basic.numba_njit(func)
# Numba can't inline closures with more than 30 arguments
inline = "always" if nin + nout <= 30 else None
return numba_basic.numba_njit(func, inline=inline)


_jit_options = {
Expand Down
Loading
Loading