-
Notifications
You must be signed in to change notification settings - Fork 159
Refactor advanced subtensor #1756
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
71d5e35
5850fcc
c3a81e2
006b738
fe1188b
97ed8ab
5d47e62
9148eb7
aee4873
a88c68b
58a932c
3f77d83
5f60074
f94da4a
eba99ae
525003c
a2065c2
e89da66
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
| Original file line number | Diff line number | Diff line change | ||
|---|---|---|---|---|
|
|
@@ -9,7 +9,7 @@ | |||
| Subtensor, | ||||
| indices_from_subtensor, | ||||
| ) | ||||
| from pytensor.tensor.type_other import MakeSlice, SliceType | ||||
| from pytensor.tensor.type_other import MakeSlice | ||||
|
|
||||
|
|
||||
| def check_negative_steps(indices): | ||||
|
|
@@ -64,6 +64,7 @@ def makeslice(start, stop, step): | |||
| @pytorch_funcify.register(AdvancedSubtensor) | ||||
| def pytorch_funcify_AdvSubtensor(op, node, **kwargs): | ||||
| def advsubtensor(x, *indices): | ||||
| indices = indices_from_subtensor(indices, op.idx_list) | ||||
| check_negative_steps(indices) | ||||
| return x[indices] | ||||
|
|
||||
|
|
@@ -102,12 +103,14 @@ def inc_subtensor(x, y, *flattened_indices): | |||
| @pytorch_funcify.register(AdvancedIncSubtensor) | ||||
| @pytorch_funcify.register(AdvancedIncSubtensor1) | ||||
| def pytorch_funcify_AdvancedIncSubtensor(op, node, **kwargs): | ||||
| idx_list = op.idx_list | ||||
| inplace = op.inplace | ||||
| ignore_duplicates = getattr(op, "ignore_duplicates", False) | ||||
|
|
||||
| if op.set_instead_of_inc: | ||||
|
|
||||
| def adv_set_subtensor(x, y, *indices): | ||||
| def adv_set_subtensor(x, y, *flattened_indices): | ||||
| indices = indices_from_subtensor(flattened_indices, idx_list) | ||||
| check_negative_steps(indices) | ||||
| if isinstance(op, AdvancedIncSubtensor1): | ||||
| op._check_runtime_broadcasting(node, x, y, indices) | ||||
|
|
@@ -120,7 +123,8 @@ def adv_set_subtensor(x, y, *indices): | |||
|
|
||||
| elif ignore_duplicates: | ||||
|
|
||||
| def adv_inc_subtensor_no_duplicates(x, y, *indices): | ||||
| def adv_inc_subtensor_no_duplicates(x, y, *flattened_indices): | ||||
| indices = indices_from_subtensor(flattened_indices, idx_list) | ||||
| check_negative_steps(indices) | ||||
| if isinstance(op, AdvancedIncSubtensor1): | ||||
| op._check_runtime_broadcasting(node, x, y, indices) | ||||
|
|
@@ -132,13 +136,18 @@ def adv_inc_subtensor_no_duplicates(x, y, *indices): | |||
| return adv_inc_subtensor_no_duplicates | ||||
|
|
||||
| else: | ||||
| if any(isinstance(idx.type, SliceType) for idx in node.inputs[2:]): | ||||
| # Check if we have slice indexing in idx_list | ||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||
| has_slice_indexing = ( | ||||
| any(isinstance(entry, slice) for entry in idx_list) if idx_list else False | ||||
| ) | ||||
| if has_slice_indexing: | ||||
| raise NotImplementedError( | ||||
| "IncSubtensor with potential duplicates indexes and slice indexing not implemented in PyTorch" | ||||
| ) | ||||
|
|
||||
| def adv_inc_subtensor(x, y, *indices): | ||||
| # Not needed because slices aren't supported | ||||
| def adv_inc_subtensor(x, y, *flattened_indices): | ||||
| indices = indices_from_subtensor(flattened_indices, idx_list) | ||||
| # Not needed because slices aren't supported in this path | ||||
| # check_negative_steps(indices) | ||||
| if not inplace: | ||||
| x = x.clone() | ||||
|
|
||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -237,20 +237,22 @@ def is_nd_advanced_idx(idx, dtype) -> bool: | |
| return False | ||
|
|
||
| # Parse indices | ||
| if isinstance(subtensor_op, Subtensor): | ||
| if isinstance(subtensor_op, Subtensor | AdvancedSubtensor): | ||
| indices = indices_from_subtensor(node.inputs[1:], subtensor_op.idx_list) | ||
| else: | ||
| indices = node.inputs[1:] | ||
| # The rewrite doesn't apply if advanced indexing could broadcast the samples (leading to duplicates) | ||
| # Note: For simplicity this also excludes subtensor-related expand_dims (np.newaxis). | ||
| # If we wanted to support that we could rewrite it as subtensor + dimshuffle | ||
| # and make use of the dimshuffle lift rewrite | ||
| # TODO: This rewrite is aborting with dummy indexing dimensions which aren't a problem | ||
| if any( | ||
| is_nd_advanced_idx(idx, integer_dtypes) or isinstance(idx.type, NoneTypeT) | ||
| for idx in indices | ||
| ): | ||
| return False | ||
|
|
||
| # The rewrite doesn't apply if advanced indexing could broadcast the samples (leading to duplicates) | ||
| # Note: For simplicity this also excludes subtensor-related expand_dims (np.newaxis). | ||
| # If we wanted to support that we could rewrite it as subtensor + dimshuffle | ||
| # and make use of the dimshuffle lift rewrite | ||
| # TODO: This rewrite is aborting with dummy indexing dimensions which aren't a problem | ||
| if any( | ||
| is_nd_advanced_idx(idx, integer_dtypes) | ||
| or isinstance(getattr(idx, "type", None), NoneTypeT) | ||
| for idx in indices | ||
| ): | ||
|
Comment on lines
+250
to
+254
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. No longer a case now that we took the |
||
| return False | ||
|
|
||
| # Check that indexing does not act on support dims | ||
| batch_ndims = rv_op.batch_ndim(rv_node) | ||
|
|
@@ -269,8 +271,11 @@ def is_nd_advanced_idx(idx, dtype) -> bool: | |
| ) | ||
| for idx in supp_indices: | ||
| if not ( | ||
| isinstance(idx.type, SliceType) | ||
| and all(isinstance(i.type, NoneTypeT) for i in idx.owner.inputs) | ||
| (isinstance(idx, slice) and idx == slice(None)) | ||
| or ( | ||
| isinstance(getattr(idx, "type", None), SliceType) | ||
| and all(isinstance(i.type, NoneTypeT) for i in idx.owner.inputs) | ||
|
Comment on lines
+275
to
+277
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We no longer have SliceType or NoneTypeT in the indices |
||
| ) | ||
| ): | ||
| return False | ||
| n_discarded_idxs = len(supp_indices) | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.