Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 67 additions & 12 deletions pytensor/compile/builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -963,24 +963,66 @@ class SymbolicOp(OpFromGraph):
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
if "__props__" in cls.__dict__:
# MetaType installs props-only __hash__ and __eq__ which ignores the inner graph
# override with fgraph-aware version
cls.__hash__ = OpFromGraph.__hash__
cls.__eq__ = OpFromGraph.__eq__
# MetaType installs props-only __hash__/__eq__ that ignore the inner graph.
# Restore the SymbolicOp versions (fgraph-aware, and deferred-op-aware).
cls.__hash__ = SymbolicOp.__hash__
cls.__eq__ = SymbolicOp.__eq__

def __hash__(self):
# A deferred SymbolicOp has no inner graph yet, so identify it by its type,
# props and static params rather than the (absent) frozen fgraph.
if getattr(self, "fgraph", None) is None:
props = tuple(
getattr(self, p) for p in getattr(type(self), "__props__", ())
)
return hash((type(self), props, self.static_params))
return OpFromGraph.__hash__(self)

def __eq__(self, other):
if self is other:
return True
if type(self) is not type(other):
return False
self_built = getattr(self, "fgraph", None) is not None
other_built = getattr(other, "fgraph", None) is not None
if self_built and other_built:
return OpFromGraph.__eq__(self, other)
if self_built != other_built:
return False
props = getattr(type(self), "__props__", ())
return self.static_params == other.static_params and all(
getattr(self, p) == getattr(other, p) for p in props
)

@staticmethod
def filter_inputs(*inputs):
return inputs

def build_static_params(self, inputs):
"""Hashable static information extracted from the actual input *values*.

Some inner graphs depend on input information that is not captured by the
input *types* — most notably the concrete dimensions encoded by a ``size``
vector, which determine the static (broadcastable) shape of the outputs.
Subclasses may override this to return such information from the actual
``inputs``. The returned value is stored as ``self.static_params`` (so it is
available to :meth:`build_inner_graph`) and participates in the decision of
whether the Op must be rebuilt for a new set of inputs.

The default returns ``None`` (the inner graph depends only on input types).
"""
return None

def build_inner_graph(self, *inputs) -> list[Variable]:
raise NotImplementedError

def __init__(self, input_types=None, **kwargs):
def __init__(self, input_types=None, static_params=None, **kwargs):
"""Construct op for the given input Types.

When input_types is None, construction is deferred until the first
__call__, which inspects the actual input types and builds the graph.
"""
self.static_params = static_params
for prop in getattr(type(self), "__props__", ()):
if prop in kwargs:
setattr(self, prop, kwargs.pop(prop))
Expand All @@ -992,15 +1034,28 @@ def __init__(self, input_types=None, **kwargs):
outputs = self.build_inner_graph(*dummy_inputs)
super().__init__(dummy_inputs, outputs, **kwargs)

def __call__(self, *inputs, **kwargs):
inputs = self.filter_inputs(*inputs)
input_types = tuple(inp.type for inp in inputs)

if hasattr(self, "fgraph") and input_types == tuple(self.input_types):
return super().__call__(*inputs, **kwargs)
def _resolve_op(self, inputs) -> SymbolicOp:
"""Return the concrete (built) Op matching the given inputs.

Reuses ``self`` when its inner graph already matches the inputs' types and
static params; otherwise builds a new Op for them.
"""
input_types = tuple(inp.type for inp in inputs)
static_params = self.build_static_params(inputs)
if (
hasattr(self, "fgraph")
and input_types == tuple(self.input_types)
and static_params == self.static_params
):
return self
init_kwargs = dict(self._init_kwargs)
for prop in getattr(type(self), "__props__", ()):
init_kwargs[prop] = getattr(self, prop)
op = type(self)(input_types=list(input_types), **init_kwargs)
return type(self)(
input_types=list(input_types), static_params=static_params, **init_kwargs
)

def __call__(self, *inputs, **kwargs):
inputs = self.filter_inputs(*inputs)
op = self._resolve_op(inputs)
return super(SymbolicOp, op).__call__(*inputs, **kwargs)
55 changes: 36 additions & 19 deletions pytensor/link/numba/dispatch/elemwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -893,36 +893,51 @@ def numba_funcify_Dot(op, node, **kwargs):
if x_dtype == numba_dot_dtype and y_dtype == numba_dot_dtype:

@numba_basic.numba_njit
def dot(x, y):
return np.asarray(np.dot(x, y))
def dot(x, y, out=None):
if out is None:
return np.asarray(np.dot(x, y))
np.dot(x, y, out)
return out

elif x_dtype == numba_dot_dtype and y_dtype != numba_dot_dtype:

@numba_basic.numba_njit
def dot(x, y):
return np.asarray(np.dot(x, y.astype(numba_dot_dtype)))
def dot(x, y, out=None):
if out is None:
return np.asarray(np.dot(x, y.astype(numba_dot_dtype)))
np.dot(x, y.astype(numba_dot_dtype), out)
return out

elif x_dtype != numba_dot_dtype and y_dtype == numba_dot_dtype:

@numba_basic.numba_njit
def dot(x, y):
return np.asarray(np.dot(x.astype(numba_dot_dtype), y))
def dot(x, y, out=None):
if out is None:
return np.asarray(np.dot(x.astype(numba_dot_dtype), y))
np.dot(x.astype(numba_dot_dtype), y, out)
return out

else:

@numba_basic.numba_njit
def dot(x, y):
return np.asarray(
np.dot(x.astype(numba_dot_dtype), y.astype(numba_dot_dtype))
)
def dot(x, y, out=None):
if out is None:
return np.asarray(
np.dot(x.astype(numba_dot_dtype), y.astype(numba_dot_dtype))
)
np.dot(x.astype(numba_dot_dtype), y.astype(numba_dot_dtype), out)
return out

cache_version = 1
cache_version = 2

if out_dtype == numba_dot_dtype:
# np.dot can write straight into the pre-allocated batch output slice.
dot.handles_out = True
return dot, cache_version

else:

# Output needs a dtype cast np.dot can't do in place, so fall back to
# the copying store_core_outputs wrapper.
@numba_basic.numba_njit
def dot_with_cast(x, y):
return dot(x, y).astype(out_dtype)
Expand All @@ -935,14 +950,16 @@ def numba_funcify_BatchedDot(op, node, **kwargs):
dtype = node.outputs[0].type.numpy_dtype

@numba_basic.numba_njit
def batched_dot(x, y):
def batched_dot(x, y, out=None):
# Numba does not support 3D matmul
# https://github.com/numba/numba/issues/3804
shape = x.shape[:-1] + y.shape[2:]
z0 = np.empty(shape, dtype=dtype)
for i in range(z0.shape[0]):
z0[i] = np.dot(x[i], y[i])
if out is None:
shape = x.shape[:-1] + y.shape[2:]
out = np.empty(shape, dtype=dtype)
for i in range(out.shape[0]):
out[i] = np.dot(x[i], y[i])

return z0
return out

return batched_dot
batched_dot.handles_out = True
return batched_dot, 1
22 changes: 17 additions & 5 deletions pytensor/link/numba/dispatch/linalg/constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,13 @@ def numba_funcify_BlockDiagonal(op, node, **kwargs):

The generated code looks something like:

def block_diagonal(arr0, arr1, arr2):
def block_diagonal(arr0, arr1, arr2, out=None):
out_r = arr0.shape[0] + arr1.shape[0] + arr2.shape[0]
out_c = arr0.shape[1] + arr1.shape[1] + arr2.shape[1]
out = np.zeros((out_r, out_c), dtype=np.float64)
if out is None:
out = np.zeros((out_r, out_c), dtype=np.float64)
else:
out[:] = 0

r, c = 0, 0
rr, cc = arr0.shape
Expand All @@ -46,11 +49,18 @@ def block_diagonal(arr0, arr1, arr2):

arg_names = [f"arr{i}" for i in range(n_inp)]
code = [
f"def block_diagonal({', '.join(arg_names)}):",
f"def block_diagonal({', '.join(arg_names)}, out=None):",
CODE_TOKEN.INDENT,
f"out_r = {' + '.join(f'{a}.shape[0]' for a in arg_names)}",
f"out_c = {' + '.join(f'{a}.shape[1]' for a in arg_names)}",
"if out is None:",
CODE_TOKEN.INDENT,
f"out = np.zeros((out_r, out_c), dtype=np.{dtype})",
CODE_TOKEN.DEDENT,
"else:",
CODE_TOKEN.INDENT,
"out[:] = 0",
CODE_TOKEN.DEDENT,
CODE_TOKEN.EMPTY_LINE,
"r, c = 0, 0",
]
Expand All @@ -73,5 +83,7 @@ def block_diagonal(arr0, arr1, arr2):
globals() | {"np": np},
)

cache_version = 1
return numba_basic.numba_njit(block_diag), cache_version
block_diag = numba_basic.numba_njit(block_diag)
block_diag.handles_out = True
cache_version = 2
return block_diag, cache_version
30 changes: 18 additions & 12 deletions pytensor/link/numba/dispatch/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,22 +173,25 @@ def core_MultinomialRV(op, node):
dtype = op.dtype

@numba_basic.numba_njit
def random_fn(rng, n, p):
def random_fn(rng, n, p, out=None):
n_cat = p.shape[0]
draws = np.zeros(n_cat, dtype=dtype)
if out is None:
out = np.empty(n_cat, dtype=dtype)
out[:] = 0
remaining_p = np.float64(1.0)
remaining_n = n
for i in range(n_cat - 1):
draws[i] = rng.binomial(remaining_n, p[i] / remaining_p)
remaining_n -= draws[i]
out[i] = rng.binomial(remaining_n, p[i] / remaining_p)
remaining_n -= out[i]
if remaining_n <= 0:
break
remaining_p -= p[i]
if remaining_n > 0:
draws[n_cat - 1] = remaining_n
return draws
out[n_cat - 1] = remaining_n
return out

return random_fn
random_fn.handles_out = True
return random_fn, 1


@numba_core_rv_funcify.register(ptr.MvNormalRV)
Expand Down Expand Up @@ -220,13 +223,16 @@ def core_DirichletRV(op, node):
dtype = op.dtype

@numba_basic.numba_njit
def random_fn(rng, alpha):
y = np.empty_like(alpha, dtype=dtype)
def random_fn(rng, alpha, out=None):
if out is None:
out = np.empty_like(alpha, dtype=dtype)
for i in range(len(alpha)):
y[i] = rng.gamma(alpha[i], 1.0)
return y / y.sum()
out[i] = rng.gamma(alpha[i], 1.0)
out /= out.sum()
return out

return random_fn, 1
random_fn.handles_out = True
return random_fn, 2


@numba_core_rv_funcify.register(ptr.GumbelRV)
Expand Down
Loading
Loading