Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 35 additions & 11 deletions monai/transforms/compose.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,31 @@
__all__ = ["Compose", "OneOf", "RandomOrder", "SomeOf", "execute_compose"]


def _inverse_one(
t: InvertibleTransform, data: Any, map_items: bool | int, unpack_items: bool, log_stats: bool | str
) -> Any:
"""Invert a single transform, delegating directly to nested ``Compose`` objects.

When ``t`` is a ``Compose`` instance its own ``inverse()`` is called so that
the child's ``map_items`` setting is respected. For all other invertible
transforms, ``apply_transform`` is used with ``lazy=False``.

Args:
t: The invertible transform to invert.
data: Data to be inverted.
map_items: Whether to map over list/tuple items (forwarded to
``apply_transform`` for non-``Compose`` transforms).
unpack_items: Whether to unpack data as parameters.
log_stats: Logger name or boolean for logging.

Returns:
The inverted data.
"""
if isinstance(t, Compose):
return t.inverse(data)
return apply_transform(t.inverse, data, map_items, unpack_items, lazy=False, log_stats=log_stats)


def execute_compose(
data: NdarrayOrTensor | Sequence[NdarrayOrTensor] | Mapping[Any, NdarrayOrTensor],
transforms: Sequence[Any],
Expand Down Expand Up @@ -315,15 +340,20 @@ def get_index_of_first(self, predicate):
return None

def flatten(self):
"""Return a Composition with a simple list of transforms, as opposed to any nested Compositions.
"""Return a Composition with a flattened list of transforms.

Nested ``Compose`` objects that share the same ``map_items`` setting as
the parent are inlined. Nested ``Compose`` objects with a *different*
``map_items`` value are kept as-is so their item-mapping behaviour is
preserved at runtime and during inversion.

e.g., `t1 = Compose([x, x, x, x, Compose([Compose([x, x]), x, x])]).flatten()`
will result in the equivalent of `t1 = Compose([x, x, x, x, x, x, x, x])`.

"""
new_transforms = []
for t in self.transforms:
if type(t) is Compose: # nopep8
if type(t) is Compose and t.map_items == self.map_items:
new_transforms += t.flatten().transforms
else:
new_transforms.append(t)
Expand Down Expand Up @@ -365,9 +395,7 @@ def inverse(self, data):
)
# loop backwards over transforms
for t in reversed(invertible_transforms):
data = apply_transform(
t.inverse, data, self.map_items, self.unpack_items, lazy=False, log_stats=self.log_stats
)
data = _inverse_one(t, data, self.map_items, self.unpack_items, self.log_stats)
return data

@staticmethod
Expand Down Expand Up @@ -622,9 +650,7 @@ def inverse(self, data):
# loop backwards over transforms
for o in reversed(applied_order):
if isinstance(self.transforms[o], InvertibleTransform):
data = apply_transform(
self.transforms[o].inverse, data, self.map_items, self.unpack_items, log_stats=self.log_stats
)
data = _inverse_one(self.transforms[o], data, self.map_items, self.unpack_items, self.log_stats)
return data


Expand Down Expand Up @@ -789,8 +815,6 @@ def inverse(self, data):
# loop backwards over transforms
for o in reversed(applied_order):
if isinstance(self.transforms[o], InvertibleTransform):
data = apply_transform(
self.transforms[o].inverse, data, self.map_items, self.unpack_items, log_stats=self.log_stats
)
data = _inverse_one(self.transforms[o], data, self.map_items, self.unpack_items, self.log_stats)

return data
11 changes: 7 additions & 4 deletions monai/transforms/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,10 +143,13 @@ def apply_transform(
try:
map_items_ = int(map_items) if isinstance(map_items, bool) else map_items
if isinstance(data, (list, tuple)) and map_items_ > 0 and not isinstance(transform, ReduceTrait):
return [
apply_transform(transform, item, map_items_ - 1, unpack_items, log_stats, lazy, overrides)
for item in data
]
# If the transform is a Compose with its own map_items, let it handle list/tuple
# expansion internally so that nested Compose map_items settings are respected.
if not isinstance(transform, transforms.compose.Compose):
return [
apply_transform(transform, item, map_items_ - 1, unpack_items, log_stats, lazy, overrides)
for item in data
]
return _apply_transform(transform, data, unpack_items, lazy, overrides, log_stats)
except Exception as e:
# if in debug mode, don't swallow exception so that the breakpoint
Expand Down
50 changes: 50 additions & 0 deletions tests/transforms/compose/test_compose.py
Original file line number Diff line number Diff line change
Expand Up @@ -775,6 +775,56 @@ def test_compose_execute_equivalence_with_flags(self, flags, data, pipeline):
self.assertEqual(expected, actual)


class TestNestedComposeMapItems(unittest.TestCase):
"""Tests for nested Compose respecting child map_items (issues #7932, #7565)."""

def test_child_map_items_false_receives_list(self):
"""Parent map_items=True, child map_items=False: child receives list as-is."""

def split(x):
return [x + 1, x + 2]

def sum_list(items):
return sum(items)

# The child Compose(map_items=False) should receive the list from split()
# and pass it as-is to sum_list, rather than the parent expanding the list.
pipeline = mt.Compose([split, mt.Compose([sum_list], map_items=False)])
result = pipeline(10)
self.assertEqual(result, 23) # (10+1) + (10+2) = 23

def test_inverse_respects_child_map_items(self):
"""Inverse path should delegate to child Compose.inverse directly."""
pipeline = mt.Compose([mt.Flip(0), mt.Compose([mt.Flip(1)], map_items=False)])
data = torch.randn(1, 4, 4)
result = pipeline(data)
restored = pipeline.inverse(result)
torch.testing.assert_close(data, restored)

def test_parent_no_map_child_map(self):
"""Parent map_items=False, child map_items=True: child maps over items."""

def double(x):
return x * 2

# Parent treats the list as a single value; child maps double() over each item.
pipeline = mt.Compose([mt.Compose([double], map_items=True)], map_items=False)
result = pipeline([1, 2, 3])
self.assertEqual(result, [2, 4, 6])

def test_flatten_preserves_different_map_items(self):
"""flatten() should not merge a child Compose with different map_items."""

def noop(x):
return x

parent = mt.Compose([noop, mt.Compose([noop, noop], map_items=False), noop])
flat = parent.flatten()
# The inner Compose(map_items=False) should NOT be flattened
self.assertEqual(len(flat.transforms), 3)
self.assertIsInstance(flat.transforms[1], mt.Compose)


class TestComposeCallableInput(unittest.TestCase):

def test_value_error_when_not_sequence(self):
Expand Down
Loading