
Confusion between grad_undefined / grad_disconnected #1827

@ricardoV94

Description

The docs for gradient https://pytensor.readthedocs.io/en/latest/extending/op.html#gradient mention a couple of special return types that are available for L_op:

  • grad_not_implemented and grad_undefined (both are NullType())
  • grad_disconnected (there's no helper with this name; it's obtained by calling disconnected_type())
  • float zeros (TensorType)

In #1806 there was some discussion on the meaning of these gradients. It focused a bit on integer inputs, and the difference between mul(int_tensor, int_tensor), mul(float_tensor, int_tensor) and specify_shape(float_tensor, shape=int_tensor). From the docs and the codebase I think the following behaviors for the gradient with respect to the integer inputs are consistent (a rough check follows the list):

  • mul(int_tensor, int_tensor) -> zeros_like(int_tensor, dtype=float)
  • mul(int_tensor, float_tensor) -> float_tensor
  • specify_shape(float_tensor, shape=int_tensor) -> grad_disconnected
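
For concreteness, here is a rough check of the last two bullets using the user-facing grad (a sketch; the disconnected_inputs argument is used to silence the DisconnectedInputError that grad raises for disconnected inputs by default, and the int*int case is left out so the cost stays a float scalar):

import pytensor
import pytensor.tensor as pt

xi = pt.vector("xi", dtype="int64")
xf = pt.vector("xf")
s = pt.lscalar("s")

# mul(int_tensor, float_tensor): expect a gradient based on the float input (second bullet)
pytensor.grad((xi * xf).sum(), wrt=xi)

# specify_shape(float_tensor, shape=int_scalar): the shape input is disconnected,
# so grad raises DisconnectedInputError unless told otherwise (third bullet)
y = pt.specify_shape(xf, (s,))
pytensor.grad(y.sum(), wrt=s, disconnected_inputs="ignore")  # zeros instead of an error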

So far so good. The confusion starts with grad_undefined vs grad_disconnected. My summary of the docs is:

  • If a fractional argument doesn't make sense, the gradient is undefined.
  • If the values of an input (as opposed to its dtype, shape, or other non-numerical meta-info) don't affect the values of the output (again, as opposed to its shape, dtype, or other non-numerical meta-info), the gradient is disconnected.
    • There's a second case where this shows up: if an op has unused inputs, or inputs that are only used to compute a subset of the outputs and those outputs are themselves disconnected when L_op is called, those inputs are also disconnected.

A gradient may be both undefined and disconnected, in which case disconnected wins? This is not clearly stated, but I think it's the logical conclusion: in specify_shape the specified shapes don't make sense as fractions, yet the gradient is disconnected.

But PyTensor is inconsistent. The confusion in #1806 came from using the helper io_connection_pattern, which should return the "grad connected-ness" of some inputs to some outputs. It concludes that the split_sizes (and axis) arguments in Split are connected to out:

from pytensor.graph.op import io_connection_pattern
import pytensor.tensor as pt

x = pt.vector("x")
d0 = pt.scalar("d0", dtype=int)
d1 = pt.scalar("d1", dtype=int)
x0, x1 = pt.split(x, [d0, d1], axis=0)
out = x0.sum() + x1.sum()

io_connection_pattern([x, d0, d1], [out])  # [[True], [True], [True]]

But consider a logically equivalent graph (which no longer has a symbolic axis):

x0, x1 = x[:d0], x[d0:d0 + d1]
out = x0.sum() + x1.sum()
io_connection_pattern([x, d0, d1], [out])  # [[True], [False], [False]]

Here it concludes something else.

I agree that they are all undefined, but the question of connectedness is a fuzzier one.

Is the end of a slice connected to the values? It certainly determines which numbers will show up, but it doesn't influence their "value". For this reason the stop argument in Arange is disconnected, while start and step are connected:

def L_op(self, inputs, outputs, grads):
    start, _stop, step = inputs
    (gz,) = grads
    # `start` and `step` affect the output values
    # but the outputs are integers so there's
    # no gradient through them.
    # When they are not integers, the gradients are
    # as expressed below.
    # `stop` does not affect the output values,
    # just the output shape, so it is disconnected.
    if self.dtype in discrete_dtypes:
        return [
            start.zeros_like(dtype=config.floatX),
            DisconnectedType()(),
            step.zeros_like(dtype=config.floatX),
        ]
    else:
        num_steps_taken = outputs[0].shape[0]
        return [
            gz.sum(),
            DisconnectedType()(),
            (gz * arange(num_steps_taken, dtype=self.dtype)).sum(),
        ]

Correction to the comment: the outputs aren't necessarily integers, but it's still not a continuous function.
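
The same behavior can be checked from the user side (a sketch; it relies on grad raising DisconnectedInputError for disconnected inputs by default):

import pytensor.tensor as pt
from pytensor.gradient import DisconnectedInputError, grad

start, stop, step = pt.scalars("start", "stop", "step")
out = pt.arange(start, stop, step).sum()

grad(out, wrt=start)  # defined: gz.sum()
grad(out, wrt=step)   # defined: (gz * arange(...)).sum()
try:
    grad(out, wrt=stop)  # disconnected, raises by default
except DisconnectedInputError:
    pass
grad(out, wrt=stop, disconnected_inputs="ignore")  # zeros instead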

By this logic the stop in a Slice index doesn't influence the contents, and neither should the split sizes in Split.

Other cases

Is the shape argument in Alloc connected? PyTensor says no, and it does make sense to me. It doesn't influence the values of the output, only how many there are.
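
A quick check with the same helper as above (the commented result is what I expect from Alloc's connection pattern, not a transcript):

from pytensor.graph.op import io_connection_pattern
import pytensor.tensor as pt

v = pt.scalar("v")
n = pt.lscalar("n")
out = pt.alloc(v, n).sum()

io_connection_pattern([v, n], [out])  # expected: [[True], [False]]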

But what about the axis argument (which is the one actually discussed in the docs)? I would maybe argue it influences the values in an operation like sum, but not in an operation like split/join, where it only influences where they show up. I could also be convinced it doesn't influence the values at all.

Note: I'm not too worried about symbolic axis; I think we should move away from it, as it doesn't really buy us much and makes graph analysis harder (#1528).

Current uses of grad_undefined

Here is a list of the current uses of grad_undefined in the codebase:

  • PolyGamma: gradient order (makes sense to me)
  • Scan: for untraceable sit-sot
  • Sparse GetItemList, GetItemList2: index inputs (doesn't make sense if indices are disconnected in Subtensor)
  • Nonzero
  • Eye: all inputs (doesn't make sense for m, n if the shape is disconnected in Alloc; doesn't make sense for offset if it's disconnected in set_subtensor)
  • Join: axis
  • Split: axis and split_sizes (doesn't make sense for split_sizes)
  • PermuteRowElements: permutation indices (doesn't make sense if indices are disconnected in Subtensor)
  • FillDiagonalOffset: offset (doesn't make sense if axis in set_subtensor is disconnected)
  • Sort: axis
  • ArgSort: axis
  • RVs: everything (makes sense, barring Stochastic gradients in pytensor #1419)

Taking a step back

Note these are all just typed errors in the end, to explain to the developer/user why a certain gradient can't be taken: not implemented, disconnected from cost, not well defined mathematically. Docstrings confirm this:

class DisconnectedType(Type):
    """A type indicating that a variable is the result of taking the gradient of
    ``c`` with respect to ``x`` when ``c`` is not a function of ``x``.
    It serves as a symbolic placeholder for ``0``, but conveys the extra
    information that this gradient is ``0`` because it is disconnected.
    """

The practical issue in #1806 is that io_connection_pattern only cares about disconnected_grad, and split is doing the wrong thing in my opinion (fixed in #1828). The even more practical issue is that io_connection_pattern is maybe not what we needed (or all we needed).

The long-term issue is that we want a clear API that is useful for us. Maybe the distinction between disconnected_grad and undefined_grad is not the right one, and it doesn't make sense for the L_op to have to prioritize one over the other. Perhaps it should return a type-union, as a gradient may be both disconnected and undefined, or only one of the two.

Example:

def f(x, y):
  return x + 1, polygamma(y, 1)

grad(f(x, y)[0], wrt=x)  # connected and defined
grad(f(x, y)[1], wrt=x)  # disconnected (would be defined if connected)
grad(f(x, y)[0], wrt=y)  # disconnected (would be undefined if connected)
grad(f(x, y)[1], wrt=y)  # connected and undefined
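
A purely hypothetical sketch of what such a union could carry (none of these names exist in PyTensor; the point is only that connectedness and definedness could be orthogonal flags instead of competing return types):

from dataclasses import dataclass

@dataclass(frozen=True)
class GradInfo:
    # hypothetical: report both properties independently, instead of L_op
    # having to pick between grad_undefined and grad_disconnected
    connected: bool
    defined: bool

GradInfo(connected=True, defined=True)    # grad(f(x, y)[0], wrt=x)
GradInfo(connected=False, defined=True)   # grad(f(x, y)[1], wrt=x)
GradInfo(connected=False, defined=False)  # grad(f(x, y)[0], wrt=y)
GradInfo(connected=True, defined=False)   # grad(f(x, y)[1], wrt=y)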

How does PyTensor use disconnected_grad

pytensor.gradient.grad skips calling L_op altogether if all outputs are disconnected, or if op.connection_pattern states that none of the inputs are connected to the remaining outputs.

pytensor/pytensor/gradient.py, lines 1217 to 1237 in 90c6f98:

if not any(inputs_connected):
    # All outputs of this op are disconnected so we can skip
    # Calling the op's grad method and report that the inputs
    # are disconnected
    # (The op's grad method could do this too, but this saves the
    # implementer the trouble of worrying about this case)
    input_grads = [disconnected_type() for ipt in inputs]
elif all(only_connected_to_nan):
    # All inputs are only connected to nan gradients, so we don't
    # need to bother calling the grad method. We know the gradient
    # with respect to all connected inputs is nan.
    input_grads = []
    for connected in inputs_connected:
        if connected:
            input_grads.append(null_type())
        else:
            input_grads.append(disconnected_type())
else:
    # At least one input of this op is connected to the cost so and
    # not all output gradients are undefined so we must
    # call the op's grad method

It also raises or warns if the connection pattern advertised by the op doesn't match what its grad method actually returned:

pytensor/pytensor/gradient.py, lines 1452 to 1478 in 90c6f98:

# Check that op.connection_pattern matches the connectivity
# logic driving the op.grad method
for i, (ipt, ig, connected) in enumerate(
    zip(inputs, input_grads, inputs_connected, strict=True)
):
    actually_connected = not isinstance(ig.type, DisconnectedType)
    if actually_connected and not connected:
        msg = (
            f"{node.op}.grad returned {ig} of type {ig.type} for input {i}."
            " Expected DisconnectedType instance based on "
            " the output of the op's connection_pattern "
            "method."
        )
        raise TypeError(msg)
    elif connected and not actually_connected:
        msg = f"{node.op}.grad returned DisconnectedType for input {i}."
        if hasattr(node.op, "connection_pattern"):
            msg += " Its connection_pattern method does not allow this."
            raise TypeError(msg)
        else:
            msg += (
                " You may want to implement a "
                "connection_pattern method for it."
            )
            warnings.warn(msg)
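
To see one of these checks in action, here is a toy Op (hypothetical, for illustration only) whose connection_pattern claims connectivity while its L_op returns a disconnected gradient; grad should then hit the TypeError quoted above:

import pytensor
import pytensor.tensor as pt
from pytensor.gradient import DisconnectedType
from pytensor.graph.basic import Apply
from pytensor.graph.op import Op

class LyingOp(Op):
    """Toy op: identity in perform, but with inconsistent gradient metadata."""

    def make_node(self, x):
        x = pt.as_tensor_variable(x)
        return Apply(self, [x], [x.type()])

    def perform(self, node, inputs, output_storage):
        output_storage[0][0] = inputs[0].copy()

    def connection_pattern(self, node):
        return [[True]]  # advertises the input as connected ...

    def L_op(self, inputs, outputs, output_grads):
        return [DisconnectedType()()]  # ... but returns a disconnected gradient

x = pt.vector("x")
pytensor.grad(LyingOp()(x).sum(), wrt=x)
# expect roughly: TypeError: LyingOp.grad returned DisconnectedType for input 0.
#  Its connection_pattern method does not allow this.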
