Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 34 additions & 30 deletions pytensor/tensor/subtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -787,6 +787,34 @@ def __hash__(self):
)
return hash((type(self), props_values))

@staticmethod
def str_from_slice(entry):
if entry.step is not None:
return ":".join(
(
"start" if entry.start is not None else "",
"stop" if entry.stop is not None else "",
"step",
)
)
if entry.stop is not None:
return f"{'start' if entry.start is not None else ''}:stop"
if entry.start is not None:
return "start:"
return ":"

@staticmethod
def str_from_indices(idx_list):
indices = []
letter_indexes = 0
for entry in idx_list:
if isinstance(entry, slice):
indices.append(BaseSubtensor.str_from_slice(entry))
else:
indices.append("ijk"[letter_indexes % 3] * (letter_indexes // 3 + 1))
letter_indexes += 1
return ", ".join(indices)


class Subtensor(BaseSubtensor, COp):
"""Basic NumPy indexing operator."""
Expand Down Expand Up @@ -907,34 +935,6 @@ def connection_pattern(self, node):

return rval

@staticmethod
def str_from_slice(entry):
if entry.step is not None:
return ":".join(
(
"start" if entry.start is not None else "",
"stop" if entry.stop is not None else "",
"step",
)
)
if entry.stop is not None:
return f"{'start' if entry.start is not None else ''}:stop"
if entry.start is not None:
return "start:"
return ":"

@staticmethod
def str_from_indices(idx_list):
indices = []
letter_indexes = 0
for entry in idx_list:
if isinstance(entry, slice):
indices.append(Subtensor.str_from_slice(entry))
else:
indices.append("ijk"[letter_indexes % 3] * (letter_indexes // 3 + 1))
letter_indexes += 1
return ", ".join(indices)

def __str__(self):
return f"{self.__class__.__name__}{{{self.str_from_indices(self.idx_list)}}}"

Expand Down Expand Up @@ -1407,7 +1407,7 @@ def __init__(

def __str__(self):
name = "SetSubtensor" if self.set_instead_of_inc else "IncSubtensor"
return f"{name}{{{Subtensor.str_from_indices(self.idx_list)}}}"
return f"{name}{{{super().str_from_indices(self.idx_list)}}}"

def make_node(self, x, y, *inputs):
"""
Expand Down Expand Up @@ -2275,6 +2275,9 @@ class AdvancedSubtensor(BaseSubtensor, COp):
__props__ = ("idx_list",)
__hash__ = BaseSubtensor.__hash__

def __str__(self):
return f"{self.__class__.__name__}{{{super().str_from_indices(self.idx_list)}}}"

def c_code_cache_version(self):
hv = Subtensor.helper_c_code_cache_version()
if hv:
Expand Down Expand Up @@ -2576,11 +2579,12 @@ def __init__(
self.ignore_duplicates = ignore_duplicates

def __str__(self):
return (
name = (
"AdvancedSetSubtensor"
if self.set_instead_of_inc
else "AdvancedIncSubtensor"
)
return f"{name}{{{super().str_from_indices(self.idx_list)}}}"

def make_node(self, x, y, *index_variables):
if len(index_variables) != self.n_index_vars:
Expand Down
2 changes: 1 addition & 1 deletion tests/link/numba/test_subtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,7 @@ def test_AdvancedIncSubtensor(
with (
pytest.warns(
UserWarning,
match="Numba will use object mode to run AdvancedIncSubtensor's perform method",
match="Numba will use object mode to run AdvancedIncSubtensor",
)
if duplicate_indices_require_obj_mode
else contextlib.nullcontext()
Expand Down