Migrate DiT components to nn#1530

Merged
pzharrington merged 5 commits into NVIDIA:main from pzharrington:dit-nn-migration
Mar 25, 2026

@pzharrington
Collaborator

PhysicsNeMo Pull Request

Description

Migrates the building-block components (attention modules, tokenizers/detokenizers, etc.) for the DiT architecture into nn for reuse across DiT-like model architectures. The unit tests and their corresponding data are also migrated so they are easier to find next to the code they cover.

Review Process

All PRs are reviewed by the PhysicsNeMo team before merging.

Depending on which files are changed, GitHub may automatically assign a maintainer for review.

We are also testing AI-based code review tools (e.g., Greptile), which may add automated comments with a confidence score.
This score reflects the AI’s assessment of merge readiness and is not a qualitative judgment of your work, nor is
it an indication that the PR will be accepted / rejected.

AI-generated feedback should be reviewed critically for usefulness.
You are not required to respond to every AI comment, but they are intended to help both authors and reviewers.
Please react to Greptile comments with 👍 or 👎 to provide feedback on their accuracy.

@pzharrington pzharrington self-assigned this Mar 23, 2026
@pzharrington
Collaborator Author

@greptileai

@greptile-apps
Contributor

greptile-apps Bot commented Mar 23, 2026

Greptile Summary

This PR migrates DiT building-block components (DiTBlock, attention modules, tokenizers/detokenizers, conditioning embedders) from physicsnemo/models/dit/ into physicsnemo/nn/module/ so they can be reused across DiT-like model architectures, while keeping physicsnemo.models.dit.DiT in place. Backward compatibility is preserved via PEP 562 __getattr__ lazy shims in physicsnemo/models/dit/__init__.py and module-level shim files for the sub-modules (layers.py, conditioning_embedders.py), all emitting LegacyFeatureWarning. The full physicsnemo.experimental.models.dit package (which was itself already a shim) is removed without replacement.
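As a rough illustration of the PEP 562 pattern described above, the sketch below wires a module-level `__getattr__` onto a stand-in "legacy" module so that accessing a moved name emits a warning and forwards to the new location. Everything here is hypothetical: the module names `demo_nn` and `demo_models_dit`, the `_MOVED` table, and the local `LegacyFeatureWarning` class are placeholders, not the actual PhysicsNeMo code.

```python
import types
import warnings


class LegacyFeatureWarning(Warning):
    """Stand-in for the library's warning class (assumed, not the real one)."""


class DiTBlock:
    """Placeholder for a component that has moved to a new module."""


# Hypothetical new home for the component.
new_home = types.ModuleType("demo_nn")
new_home.DiTBlock = DiTBlock

# Hypothetical legacy module that only keeps a lazy shim.
legacy = types.ModuleType("demo_models_dit")
_MOVED = {"DiTBlock": new_home}


def _legacy_getattr(name):
    # PEP 562: called only when normal attribute lookup on the module fails.
    if name in _MOVED:
        warnings.warn(
            f"{legacy.__name__}.{name} has moved to {_MOVED[name].__name__}",
            LegacyFeatureWarning,
            stacklevel=2,
        )
        return getattr(_MOVED[name], name)
    raise AttributeError(f"module {legacy.__name__!r} has no attribute {name!r}")


# In a real package this function would simply be defined as __getattr__
# at the top level of the legacy module's __init__.py.
legacy.__getattr__ = _legacy_getattr

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    resolved = legacy.DiTBlock

caught_categories = [w.category for w in caught]
```

Because the lookup is lazy, importing the legacy module stays silent; the warning fires only when a moved name is actually accessed.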

Key changes:

  • physicsnemo/nn/module/dit_layers.py and physicsnemo/nn/module/conditioning_embedders.py are the new canonical locations for all DiT layer/embedder components.
  • All symbols are re-exported from the top-level physicsnemo.nn namespace.
  • DiTBlock.forward gains a minor API improvement: attn_kwargs default changed from the mutable {} to None (handled internally via attn_kwargs or {}).
  • Unit tests for DiTBlock and conditioning embedders are migrated to test/nn/module/, importing from the new locations.
  • Two style-level suggestions: add __all__ = ["DiT"] to physicsnemo/models/dit/__init__.py to prevent import * from leaking shim-internal names, and add test coverage for the new deprecation warning paths in physicsnemo.models.dit.
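The `attn_kwargs` change in the list above addresses a general Python hazard: a mutable default is created once, at function definition time, and shared across calls. The toy functions below are not `DiTBlock.forward`; they are hypothetical stand-ins that mutate the dict only to make the sharing visible.

```python
def forward_shared_default(x, attn_kwargs={}):
    # The same dict object is reused by every call that omits attn_kwargs,
    # so state written here leaks into later calls.
    attn_kwargs["calls"] = attn_kwargs.get("calls", 0) + 1
    return attn_kwargs["calls"]


def forward_none_default(x, attn_kwargs=None):
    # A fresh dict is created per call when the caller passes nothing.
    attn_kwargs = attn_kwargs or {}
    attn_kwargs["calls"] = attn_kwargs.get("calls", 0) + 1
    return attn_kwargs["calls"]


shared_results = [forward_shared_default(0), forward_shared_default(0)]
fresh_results = [forward_none_default(0), forward_none_default(0)]
```

One side effect of the `attn_kwargs or {}` idiom is that an empty dict passed by the caller is also replaced by a new one, which is harmless when the dict is only read.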

Important Files Changed

| Filename | Overview |
| --- | --- |
| physicsnemo/models/dit/__init__.py | Adds PEP 562 __getattr__ shim for lazy backward-compatible re-export of deprecated names with LegacyFeatureWarning. Missing __all__ = ["DiT"] means import * leaks warnings and LegacyFeatureWarning into calling namespaces. |
| physicsnemo/nn/module/dit_layers.py | New home for all DiT layer components (DiTBlock, attention modules, tokenizers, detokenizers). Functionally identical to the removed physicsnemo/models/dit/layers.py; imports updated to use fully qualified physicsnemo.nn.module.* paths. DiTBlock.forward signature improved: mutable-default attn_kwargs={} replaced with None. |
| physicsnemo/nn/module/conditioning_embedders.py | New home for conditioning embedder classes (DiTConditionEmbedder, EDMConditionEmbedder, ZeroConditioningEmbedder) and factory. Direct port from the old physicsnemo/models/dit/conditioning_embedders.py with updated cross-reference paths in docstrings. |
| physicsnemo/models/dit/conditioning_embedders.py | Replaced with a thin backward-compatibility shim that emits a module-level LegacyFeatureWarning on import and re-exports all names from the new physicsnemo.nn.module.conditioning_embedders. |
| physicsnemo/models/dit/layers.py | Replaced with a thin backward-compatibility shim that emits a module-level LegacyFeatureWarning on import and re-exports all names from the new physicsnemo.nn.module.dit_layers. |
| physicsnemo/models/dit/dit.py | Updated to consolidate all layer and embedder imports into a single from physicsnemo.nn import (...) block; docstring cross-references updated to new physicsnemo.nn.* paths. No logic changes. |
| physicsnemo/nn/__init__.py | Adds exports for all new DiT layer and conditioning embedder symbols so they are accessible via the top-level physicsnemo.nn namespace. |
| test/nn/module/test_dit_layers.py | New test file covering DiTBlock accuracy (timm, natten, transformer_engine backends), exception paths, and per-sample dropout behavior. Tests import directly from the new physicsnemo.nn.module.dit_layers location. |
| test/nn/module/test_conditioning_embedders.py | New test file covering DiTConditionEmbedder, EDMConditionEmbedder, and ZeroConditioningEmbedder forward accuracy and constructor variants. |
| test/models/dit/test_dit.py | DiTBlock-specific tests removed and migrated to test/nn/module/test_dit_layers.py. Imports updated to use physicsnemo.nn.module.dit_layers. No coverage added for new backward-compatibility deprecation shims in physicsnemo.models.dit. |

Comments Outside Diff (2)

  1. physicsnemo/models/dit/__init__.py, lines 37-76

    Missing __all__ exposes internal imports via star-import

    Without an explicit __all__, a from physicsnemo.models.dit import * will export warnings and LegacyFeatureWarning alongside DiT, since those names don't start with _. Adding __all__ = ["DiT"] keeps the public surface clean and clearly signals the intended API of this shim module.
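The star-import behavior described in this comment can be checked in isolation. The sketch below builds two throwaway modules from the same source, one without `__all__` and one with it, and records which names `from ... import *` brings in. The module names `demo_shim_a` and `demo_shim_b` and the helper `star_import_names` are made up for the demonstration, and the source string only imitates the shape of the shim.

```python
import sys
import types

SHIM_SOURCE = '''
import warnings


class LegacyFeatureWarning(Warning):
    pass


class DiT:
    pass
'''


def star_import_names(module_name, source):
    """Return the set of names that `from module_name import *` would bind."""
    module = types.ModuleType(module_name)
    exec(source, module.__dict__)
    sys.modules[module_name] = module
    try:
        namespace = {}
        exec(f"from {module_name} import *", namespace)
    finally:
        # Clean up so the throwaway module does not linger in sys.modules.
        del sys.modules[module_name]
    namespace.pop("__builtins__", None)
    return set(namespace)


without_all = star_import_names("demo_shim_a", SHIM_SOURCE)
with_all = star_import_names("demo_shim_b", SHIM_SOURCE + '__all__ = ["DiT"]\n')
```

Without `__all__`, every module-level name that does not start with an underscore is exported, including the imported `warnings` module itself.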

  2. test/models/dit/test_dit.py, line 30

    No tests covering the new deprecation shims

    The PR removes test_experimental_dit_import_warns (which verified the old experimental → stable deprecation path) but does not add equivalent tests for the three new deprecation shims:

    • physicsnemo.models.dit.__getattr__ — lazy LegacyFeatureWarning on attribute access (e.g. from physicsnemo.models.dit import DiTBlock)
    • physicsnemo.models.dit.conditioning_embedders — module-level LegacyFeatureWarning on import
    • physicsnemo.models.dit.layers — module-level LegacyFeatureWarning on import

    Without tests, a future refactor could silently drop or mis-configure these warnings and no CI gate would catch it. Consider adding a small test (likely in test/models/dit/test_dit.py or a new test_deprecations.py) similar to the one that was removed:

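    import sys
    import pytest

    # Assumption: LegacyFeatureWarning is imported from wherever PhysicsNeMo defines it.
    def test_models_dit_getattr_warns():
        # Drop the cached module so the shim is imported fresh.
        sys.modules.pop("physicsnemo.models.dit", None)
        import physicsnemo.models.dit as m
        # Accessing a moved name should trigger the PEP 562 __getattr__ shim.
        with pytest.warns(LegacyFeatureWarning):
            _ = m.DiTBlock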
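The two module-level shims need a slightly different test shape, because a warning raised at import time fires only when the module body actually executes. The following self-contained sketch shows that effect with a temporary module written to disk; `demo_legacy_layers` is a hypothetical module name, and `DeprecationWarning` stands in for `LegacyFeatureWarning`, whose definition is not shown in this thread.

```python
import importlib
import sys
import tempfile
import warnings
from pathlib import Path

MODULE_NAME = "demo_legacy_layers"  # hypothetical stand-in for a legacy shim
SHIM_SOURCE = (
    "import warnings\n"
    "warnings.warn('demo_legacy_layers has moved', DeprecationWarning, stacklevel=2)\n"
)


def import_and_collect(module_name):
    """Import a module and return the deprecation warning categories it raised."""
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        importlib.import_module(module_name)
    return [w.category for w in caught if issubclass(w.category, DeprecationWarning)]


with tempfile.TemporaryDirectory() as tmp:
    Path(tmp, f"{MODULE_NAME}.py").write_text(SHIM_SOURCE)
    sys.path.insert(0, tmp)
    try:
        first = import_and_collect(MODULE_NAME)   # module body runs, warning fires
        second = import_and_collect(MODULE_NAME)  # served from sys.modules, silent
        sys.modules.pop(MODULE_NAME, None)        # force the body to run again
        third = import_and_collect(MODULE_NAME)
    finally:
        sys.path.remove(tmp)
        sys.modules.pop(MODULE_NAME, None)
```

This is why a test for an import-time warning typically evicts the module from `sys.modules` first: if an earlier test already imported the shim, the warning would otherwise never be observed.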

Reviews (4): Last reviewed commit: "Use legacy feature warning"

Comment thread physicsnemo/models/dit/__init__.py
@pzharrington
Collaborator Author

@greptileai

@pzharrington
Collaborator Author

@greptileai

@pzharrington
Collaborator Author

/blossom-ci

@pzharrington pzharrington marked this pull request as ready for review March 24, 2026 17:27
Collaborator

@loliverhennigh loliverhennigh left a comment

Got around to looking through today. Not super familiar with these layers but the structure looks good and consistent now.

@pzharrington pzharrington added this pull request to the merge queue Mar 25, 2026
@github-merge-queue github-merge-queue Bot removed this pull request from the merge queue due to failed status checks Mar 25, 2026
@pzharrington
Collaborator Author

/blossom-ci

@pzharrington pzharrington enabled auto-merge March 25, 2026 19:07
@pzharrington pzharrington added this pull request to the merge queue Mar 25, 2026
Merged via the queue into NVIDIA:main with commit c667944 Mar 25, 2026
4 checks passed

3 participants