Support PyTorch compile #206

@tizianocitro

Description

Is your feature request related to a problem? Please describe.

Models currently run eagerly by default and do not expose a first-class way to opt into PyTorch compilation. Users who want to experiment with torch.compile for faster model execution need to wrap modules manually.

Describe the solution you would like

Add opt-in torch.compile support via a compile_config constructor argument.

The built-in support should compile the model submodules (e.g., encoder and decoder) rather than the entire LightningModule. This keeps Lightning lifecycle methods, logging, metrics, and training-step logic eager while still enabling compilation for the pure PyTorch model paths.
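A minimal sketch of the submodule approach follows. It assumes the owning module exposes its submodules as attributes such as `encoder` and `decoder`. The helper name `compile_submodules` and its `compile_fn` parameter are hypothetical and not part of any existing API. `compile_fn` defaults to `torch.compile` and is only there so the logic can be exercised with a stub.

```python
from typing import Any, Callable, Iterable, Mapping, Optional


def compile_submodules(
    owner: Any,
    names: Iterable[str],
    compile_kwargs: Optional[Mapping[str, Any]],
    compile_fn: Optional[Callable[..., Any]] = None,
) -> None:
    """Replace each named submodule of ``owner`` with a compiled version.

    Hypothetical helper. ``owner`` itself (e.g. the LightningModule) is never
    compiled, so training_step, self.log() calls and metrics keep running eagerly.
    """
    if compile_kwargs is None:
        return  # compilation not requested: leave every submodule eager
    if compile_fn is None:
        import torch  # imported lazily; compile_fn can be overridden, e.g. with a stub in tests

        compile_fn = torch.compile
    for name in names:
        module = getattr(owner, name)
        # On an nn.Module, assigning a Module attribute re-registers it under the same name.
        setattr(owner, name, compile_fn(module, **dict(compile_kwargs)))
```

Inside a hypothetical constructor this would run once, after the submodules are built, e.g. `compile_submodules(self, ("encoder", "decoder"), compile_kwargs)`. One side effect worth checking: the wrapper returned by `torch.compile` keeps the original module as `_orig_mod`, so `state_dict()` keys gain an `_orig_mod.` prefix. Compiling in place with `nn.Module.compile()`, available in recent PyTorch releases, avoids that prefix.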

Example:

model = MLPHlpModule(
    encoder_config={"in_channels": 16, "out_channels": 32},
    compile_config=True,
)

Or with explicit compile options:

model = MLPHlpModule(
    encoder_config={"in_channels": 16, "out_channels": 32},
    compile_config={"backend": "inductor", "dynamic": True},
)

Default behavior should remain unchanged when compile_config is omitted.
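One way to keep the default eager while accepting both `True` and an options dict is to normalize `compile_config` before anything is compiled. The sketch below assumes the three accepted forms shown in the examples above (omitted/`None`, a bool, or a mapping of `torch.compile` keyword arguments); the function name `resolve_compile_kwargs` is hypothetical.

```python
from typing import Any, Dict, Mapping, Optional, Union

# Accepted forms of the proposed compile_config argument (assumption based on the examples).
CompileConfig = Union[None, bool, Mapping[str, Any]]


def resolve_compile_kwargs(compile_config: CompileConfig) -> Optional[Dict[str, Any]]:
    """Turn compile_config into keyword arguments for torch.compile.

    Hypothetical helper. Returns None when compilation is not requested,
    so callers can keep the current eager behaviour untouched.
    """
    if compile_config is None or compile_config is False:
        return None  # omitted or explicitly disabled: stay eager
    if compile_config is True:
        return {}  # opt in with torch.compile's own defaults
    if isinstance(compile_config, Mapping):
        # Copy so later mutation of the caller's dict has no effect.
        return dict(compile_config)
    raise TypeError(
        "compile_config must be None, a bool, or a mapping of torch.compile options, "
        f"got {type(compile_config).__name__}"
    )
```

Checking `is True` / `is False` instead of truthiness means a stray value such as `1` or a string is rejected rather than silently enabling compilation. An empty dict is treated as opting in with default options.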

Describe alternatives you've considered

One alternative is to compile the full LightningModule with torch.compile(model) before passing it to Trainer.fit(). Lightning supports this, but it can be fragile for modules that use self.log() and other Lightning runtime behavior. Our modules rely on these heavily, so compiling only the encoder and decoder is a safer default.

Additional context

Lightning's compile documentation says that self.log() can cause compile issues and recommends compiling submodules as a workaround: https://lightning.ai/docs/pytorch/stable/advanced/compile.html.
