Skip to content

Fix nested Compose map_items in forward and inverse paths#8787

Open
aymuos15 wants to merge 5 commits intoProject-MONAI:devfrom
aymuos15:worktree-fix-nested-compose-map-items
Open

Fix nested Compose map_items in forward and inverse paths#8787
aymuos15 wants to merge 5 commits intoProject-MONAI:devfrom
aymuos15:worktree-fix-nested-compose-map-items

Conversation

@aymuos15
Copy link
Contributor

@aymuos15 aymuos15 commented Mar 20, 2026

Summary

Fixes #7932, #7565

When a child Compose has a different map_items setting than its parent, the parent's apply_transform would expand list/tuple data before the child ever sees it — silently overriding the child's map_items.

This PR makes three coordinated changes so the child's map_items is respected:

  • Forward path (apply_transform): Skip list expansion when the transform is a Compose instance, letting it handle expansion via its own map_items in execute_compose.
  • Inverse path (_inverse_one helper): Delegate directly to Compose.inverse() for nested Compose objects (including RandomOrder and SomeOf) instead of routing through apply_transform(t.inverse, ...).
  • flatten(): Only inline nested Compose objects that share the same map_items as the parent. Children with a different map_items are preserved as-is.

Test plan

  • test_child_map_items_false_receives_list — parent map_items=True, child map_items=False: child receives list as-is
  • test_inverse_respects_child_map_items — inverse roundtrip with nested Compose
  • test_parent_no_map_child_map — parent map_items=False, child map_items=True: child maps over items
  • test_flatten_preserves_different_map_itemsflatten() does not merge children with different map_items

…t-MONAI#7932, Project-MONAI#7565)

When a child Compose has a different map_items setting than its parent,
the parent now delegates to the child instead of expanding list/tuple
data itself. This applies to forward execution (apply_transform),
flatten(), and the inverse path in Compose, RandomOrder, and SomeOf.

Signed-off-by: Soumya Snigdha Kundu <soumya_snigdha.kundu@kcl.ac.uk>
@coderabbitai
Copy link
Contributor

coderabbitai bot commented Mar 20, 2026

📝 Walkthrough

Walkthrough

Introduced a module helper _inverse_one(...) to centralize per-transform inversion. Compose.inverse, RandomOrder.inverse, and SomeOf.inverse now use _inverse_one for inverting contained transforms. Compose.flatten() was changed to inline nested Compose instances only when type(t) is Compose and t.map_items == self.map_items, preserving nested Compose with differing map_items. apply_transform() was adjusted so a Compose transform can manage list/tuple expansion itself rather than having the caller expand items. New tests cover nested Compose map_items behavior, inversion, and flattening.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 50.00% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Title check ✅ Passed Title clearly and specifically describes the main change: fixing nested Compose map_items behavior in both forward and inverse paths.
Description check ✅ Passed PR description covers all key aspects: issues fixed, explanation of three coordinated changes, and test plan with checkmarks. Template sections mostly complete.
Linked Issues check ✅ Passed PR directly addresses issue #7932 by respecting child Compose map_items settings in forward, inverse, and flatten paths—the core feature request.
Out of Scope Changes check ✅ Passed All changes (apply_transform, _inverse_one helper, flatten logic, and tests) are tightly scoped to fixing nested Compose map_items handling.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests
📝 Coding Plan
  • Generate coding plan for human review comments

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🧹 Nitpick comments (1)
monai/transforms/compose.py (1)

40-51: Add Google-style docstring with Args/Returns.

Per coding guidelines, docstrings should describe parameters and return values.

📝 Proposed docstring
 def _inverse_one(
     t: InvertibleTransform,
     data: Any,
     map_items: bool | int,
     unpack_items: bool,
     log_stats: bool | str,
 ) -> Any:
-    """Invert a single transform, delegating directly to nested ``Compose`` objects."""
+    """Invert a single transform, delegating directly to nested ``Compose`` objects.
+
+    Args:
+        t: The invertible transform to invert.
+        data: Data to be inverted.
+        map_items: Whether to map over list/tuple items.
+        unpack_items: Whether to unpack data as parameters.
+        log_stats: Logger name or boolean for logging.
+
+    Returns:
+        The inverted data.
+    """
     if isinstance(t, Compose):
         return t.inverse(data)
     return apply_transform(t.inverse, data, map_items, unpack_items, lazy=False, log_stats=log_stats)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@monai/transforms/compose.py` around lines 40 - 51, Add a Google-style
docstring to the _inverse_one function describing its purpose, parameters and
return value: state that it inverts a single transform (delegating to
Compose.inverse when t is a Compose), document args t (InvertibleTransform),
data (Any), map_items (bool|int), unpack_items (bool), log_stats (bool|str), and
explain the return value (inverted data of type Any); also mention that
apply_transform is called for non-Compose transforms with lazy=False. Include
short notes on side effects (calls to Compose.inverse/apply_transform) if
relevant.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Nitpick comments:
In `@monai/transforms/compose.py`:
- Around line 40-51: Add a Google-style docstring to the _inverse_one function
describing its purpose, parameters and return value: state that it inverts a
single transform (delegating to Compose.inverse when t is a Compose), document
args t (InvertibleTransform), data (Any), map_items (bool|int), unpack_items
(bool), log_stats (bool|str), and explain the return value (inverted data of
type Any); also mention that apply_transform is called for non-Compose
transforms with lazy=False. Include short notes on side effects (calls to
Compose.inverse/apply_transform) if relevant.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 13d05d32-401c-4be3-a001-687e58886c54

📥 Commits

Reviewing files that changed from the base of the PR and between daaedaa and 26c051f.

📒 Files selected for processing (3)
  • monai/transforms/compose.py
  • monai/transforms/transform.py
  • tests/transforms/compose/test_compose.py

Address CodeRabbit review comment: document Args/Returns for the
_inverse_one helper function per project coding guidelines.

Signed-off-by: Soumya Snigdha Kundu <soumya_snigdha.kundu@kcl.ac.uk>
Signed-off-by: Soumya Snigdha Kundu <soumya_snigdha.kundu@kcl.ac.uk>
Signed-off-by: Soumya Snigdha Kundu <soumya_snigdha.kundu@kcl.ac.uk>
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🧹 Nitpick comments (1)
monai/transforms/compose.py (1)

354-361: flatten() doesn't preserve parent attributes (unlike OneOf.flatten()).

Line 361 returns Compose(new_transforms) with defaults. If someone calls flatten() expecting an equivalent executable pipeline, map_items, unpack_items, log_stats, lazy, and overrides are lost.

Current internal usage only reads .transforms, so no bug today.

Proposed fix
-        return Compose(new_transforms)
+        return Compose(new_transforms, self.map_items, self.unpack_items, self.log_stats, self.lazy, self.overrides)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@monai/transforms/compose.py` around lines 354 - 361, Compose.flatten()
currently builds and returns a new Compose instance with only new_transforms,
dropping parent attributes; update Compose.flatten to preserve and pass through
the parent's configuration (map_items, unpack_items, log_stats, lazy, overrides)
when constructing the returned Compose so the flattened pipeline is equivalent
to the original (mirroring OneOf.flatten behavior). Locate the Compose.flatten
implementation and change the return from Compose(new_transforms) to
Compose(new_transforms, map_items=self.map_items,
unpack_items=self.unpack_items, log_stats=self.log_stats, lazy=self.lazy,
overrides=self.overrides) or otherwise propagate those attributes from self into
the new Compose instance.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Nitpick comments:
In `@monai/transforms/compose.py`:
- Around line 354-361: Compose.flatten() currently builds and returns a new
Compose instance with only new_transforms, dropping parent attributes; update
Compose.flatten to preserve and pass through the parent's configuration
(map_items, unpack_items, log_stats, lazy, overrides) when constructing the
returned Compose so the flattened pipeline is equivalent to the original
(mirroring OneOf.flatten behavior). Locate the Compose.flatten implementation
and change the return from Compose(new_transforms) to Compose(new_transforms,
map_items=self.map_items, unpack_items=self.unpack_items,
log_stats=self.log_stats, lazy=self.lazy, overrides=self.overrides) or otherwise
propagate those attributes from self into the new Compose instance.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 4cdba609-bf44-46ad-a65c-afd42771ee1d

📥 Commits

Reviewing files that changed from the base of the PR and between 4595a45 and f732583.

📒 Files selected for processing (2)
  • monai/transforms/compose.py
  • tests/transforms/compose/test_compose.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • tests/transforms/compose/test_compose.py

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Option to Disable Compose list/tuple Expansion

1 participant