1 change: 1 addition & 0 deletions modelopt/torch/distill/__init__.py
@@ -19,6 +19,7 @@
from .config import *
from .distillation import *
from .distillation_model import *
from .layerwise_distillation_model import *
from .loss_balancers import *
from .losses import *
from .registry import *
21 changes: 20 additions & 1 deletion modelopt/torch/distill/config.py
@@ -26,7 +26,7 @@

from .loss_balancers import DistillationLossBalancer

__all__ = ["KDLossConfig"]
__all__ = ["ExportStudentConfig", "KDLossConfig", "LayerwiseKDConfig"]

Criterion = Union[Loss, dict[tuple[str, str], Loss]] # noqa: UP007

@@ -120,6 +120,25 @@ def _strict_validate(self) -> None:
)


class LayerwiseKDConfig(KDLossConfig):
"""Configuration for the Layerwise Knowledge-Distillation mode.
Knowledge is distilled between explicitly paired intermediate layers of the teacher and student, rather than only between the models' final outputs.
"""

@pydantic.field_validator("criterion")
@classmethod
def format_criterion(cls, criterion: Criterion | None) -> dict[tuple[str, str], Loss]:
"""Ensure criterion is a mapping from layer names to loss (potentially entire module)."""
if not isinstance(criterion, dict):
raise ValueError("Layerwise Distillation mode requires explicit criterion pairs.")
if any(key == ("", "") for key in criterion):
raise ValueError(
"Layerwise Distillation mode does not support output-only distillation."
)
return criterion


class ExportStudentConfig(ModeloptBaseConfig):
"""Configuration for the export_student mode.
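For context, a minimal sketch of a criterion mapping that the new `format_criterion` validator accepts; the layer names below are illustrative placeholders, not from this PR:

```python
import torch.nn as nn

# Hypothetical layer names; use the module names from your own student/teacher models.
criterion = {
    ("backbone.block0", "backbone.block0"): nn.MSELoss(),
    ("backbone.block1", "backbone.block1"): nn.MSELoss(),
}
# A bare loss, or an output-only entry such as {("", ""): nn.KLDivLoss()},
# would be rejected by LayerwiseKDConfig.format_criterion.
```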
62 changes: 33 additions & 29 deletions modelopt/torch/distill/distillation_model.py
@@ -13,8 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

"""Meta-model wrapper to support knowledge-distillation learning."""

import inspect
@@ -45,6 +43,7 @@ def _setup(self):
self._register_temp_attribute("_loss_modules", nn.ModuleList())
self._register_temp_attribute("_only_teacher_fwd", False)
self._register_temp_attribute("_only_student_fwd", False)
self._register_temp_attribute("_hook_handles", set())

# HACK: set model's forward signature to match student class' original.
# Needed for HF `transformers.utils.find_labels` which relies on inspecting class signature.
@@ -57,23 +56,22 @@

def modify(
self,
teacher_model: nn.Module, # To be frozen.
teacher_model: nn.Module,
criterion: dict[
tuple[
str, # Student model layer whose output to capture.
str, # Teacher model layer whose output to capture.
str, # Student model layer whose output to capture
str, # Teacher model layer whose output to capture
],
Loss, # Loss fn.
Loss, # Loss function
],
loss_balancer: DistillationLossBalancer | None = None,
expose_minimal_state_dict: bool = True,
):
"""Constructor.

Args:
teacher_model: A teacher model which this class would encapsulate.
criterion: A dictionary mapping the tuple of student and teacher
model layer names to the loss function to apply to that layer pair.
teacher_model: The teacher model (will be frozen).
criterion: Dictionary mapping (student_layer_name, teacher_layer_name) to loss functions.
loss_balancer: Instance of
:class:`DistillationLossBalancer <modelopt.torch.distill.DistillationLossBalancer>`
which reduces distillation and non-distillation losses into a single value using some weighing scheme.
@@ -106,22 +104,30 @@ def modify(
{m for m in self._layers_to_loss.values() if len(list(m.parameters())) > 0}
)

# Disable grad for teacher
# Disable grad for teacher.
self._teacher_model.requires_grad_(False)

# Register hooks for intermediate outputs from teacher models and the student model.
# HACK: For inexplicable reasons, sometimes a model will have hooks remain after
# `ato.restore()` so we check if they are present accidentally first.
# Use hooks to capture relevant activation tensors for loss computation.
self._register_hooks()

def _register_hooks(self):
"""Register hooks for intermediate tensors from teacher models and the student model."""
for student_layer, teacher_layer in self._layers_to_loss:
setattr(student_layer, "_intermediate_output", None)
if student_output_capture_fwd_hook not in student_layer._forward_hooks.values():
student_layer.register_forward_hook(student_output_capture_fwd_hook)
handle_s = student_layer.register_forward_hook(student_output_capture_fwd_hook)
setattr(teacher_layer, "_intermediate_output", None)
if teacher_output_capture_fwd_hook not in teacher_layer._forward_hooks.values():
teacher_layer.register_forward_hook(teacher_output_capture_fwd_hook)
handle_t = teacher_layer.register_forward_hook(teacher_output_capture_fwd_hook)
self._hook_handles.update([handle_s, handle_t])

def export(self):
"""Export the distillation model."""
for handle in self._hook_handles:
handle.remove()
self._hook_handles.clear()
return super().export()

@property
def teacher_model(self) -> nn.ModuleList:
def teacher_model(self) -> nn.Module:
"""Fetch the teacher model."""
return self._teacher_model

@@ -148,7 +154,7 @@ def hide_teacher_model(self, enable=True):

@contextmanager
def hide_loss_modules(self, enable=True):
"""Context manager to temporarily hide teacher model from the model."""
"""Context manager to temporarily hide loss modules from the model."""
loss_modules = self._loss_modules
if enable:
self._loss_modules = nn.ModuleList()
@@ -169,7 +175,7 @@ def only_teacher_forward(self, enable=True):

@contextmanager
def only_student_forward(self, enable=True):
"""Context manager to temporarily disable forward passes on the student model."""
"""Context manager to temporarily run forward passes only on the student model."""
if enable:
self._only_student_fwd = True
try:
@@ -245,15 +251,13 @@ def compute_kd_loss(

Args:
student_loss: Original loss computed from the student's output.
loss_reduction_fn: Callable to be called on each loss tensor prior to balancing. Useful for
loss-masking situations where the callable changes arguments each iteration.
loss_reduction_fn: Callable to be called on each loss tensor prior to balancing.
Useful for loss-masking situations where the callable changes arguments each iteration.
skip_balancer: Whether or not to use loss balancer to reduce the loss dict into a scalar.
**loss_fn_kwargs: Additional keyword arguments to be passed to the loss function, if needed.
This facilitates losses that require extras, such as labels for ``mtd.MFTLoss``.

Returns:
If reduce is True, the scalar total loss weighted between ``student_loss`` and the distillation losses.
If reduce is False, a dict of student model output loss and layer-wise distillation losses.
A dict of losses if skip_balancer is True, else the scalar total loss.
"""
if self._loss_balancer is None:
assert student_loss is None, "Cannot pass in student loss without using Loss Balancer."
@@ -288,9 +292,9 @@ def compute_kd_loss(
return loss_total


def student_output_capture_fwd_hook(module: nn.Module, input: Any, output: Any): # pylint: disable=redefined-builtin
def student_output_capture_fwd_hook(module: nn.Module, input: Any, output: Any):
"""A hook to capture layer output."""
# NOTE: Defined externally to allow pickling.
# NOTE: Defined externally to allow pickling during DDP initialization.

if getattr(module, "_only_teacher_fwd", False):
return # Might be hooked on entire model fwd
@@ -303,9 +307,9 @@ def student_output_capture_fwd_hook(module: nn.Module, input: Any, output: Any):
module._intermediate_output = output


def teacher_output_capture_fwd_hook(module: nn.Module, input: Any, output: Any): # pylint: disable=redefined-builtin
def teacher_output_capture_fwd_hook(module: nn.Module, input: Any, output: Any):
"""A hook to capture layer output."""
# NOTE: Defined externally to allow pickling.
# NOTE: Defined externally to allow pickling during DDP initialization.

if module._intermediate_output is not None:
# NOTE: cannot tell if train or eval since teacher is always eval
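The bookkeeping change above boils down to the standard PyTorch handle pattern: keep every `RemovableHandle` returned by `register_forward_hook` so the hooks can be detached at export time. A minimal, self-contained sketch with a toy layer (not from this PR):

```python
import torch
import torch.nn as nn

layer = nn.Linear(4, 4)
captured = {}

def capture_hook(module, args, output):
    captured["out"] = output  # stands in for _intermediate_output

handle = layer.register_forward_hook(capture_hook)  # returns a RemovableHandle
layer(torch.randn(2, 4))
assert "out" in captured

handle.remove()  # analogous to what export() now does for every stored handle
captured.clear()
layer(torch.randn(2, 4))
assert "out" not in captured  # hook no longer fires
```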
88 changes: 88 additions & 0 deletions modelopt/torch/distill/layerwise_distillation_model.py
@@ -0,0 +1,88 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Meta-model wrapper to support layerwise-enabled knowledge-distillation learning."""

import warnings
from typing import Any

import torch.nn as nn

from .distillation_model import DistillationModel, student_output_capture_fwd_hook

__all__ = ["LayerwiseDistillationModel"]


class LayerwiseDistillationModel(DistillationModel):
"""Meta-model wrapper to support layerwise-enabled knowledge-distillation learning.

The LayerwiseDistillationModel is a subclass of the DistillationModel that injects teacher inputs
into the corresponding student layers. This accommodates the case where the student model is the
teacher with specific submodules replaced, which now need to be trained to mimic the original
submodule in the teacher.
"""

def _register_hooks(self):
"""Register hooks for intermediate tensors from teacher models and the student model."""
for student_layer, teacher_layer in self._layers_to_loss:
setattr(student_layer, "_teacher_layer", [teacher_layer])
handle_s1 = student_layer.register_forward_pre_hook(student_input_bypass_fwd_hook)
setattr(student_layer, "_intermediate_output", None)
handle_s2 = student_layer.register_forward_hook(student_output_capture_fwd_hook)
setattr(teacher_layer, "_intermediate_input", None)
setattr(teacher_layer, "_intermediate_output", None)
handle_t = teacher_layer.register_forward_hook(teacher_input_output_capture_fwd_hook)
self._hook_handles.update([handle_s1, handle_s2, handle_t])

def export(self):
"""Export the distillation model."""
for student_layer, _ in self._layers_to_loss:
delattr(student_layer, "_teacher_layer")
return super().export()


def student_input_bypass_fwd_hook(module: nn.Module, input: Any):
"""A hook to inject teacher input into corresponding student layer."""
# NOTE: Defined externally to allow pickling during DDP initialization.

if getattr(module, "_only_teacher_fwd", False):
return input # Might be hooked on entire model fwd

teacher_layer = module._teacher_layer[0]
teacher_input = teacher_layer._intermediate_input
if teacher_input is None:
warnings.warn(
f"Teacher's Module `{type(teacher_layer).__name__}` has no intermediate input stored."
" This is expected when the `only_student_forward` context manager is in use."
)
return input

teacher_layer._intermediate_input = None # reset
return teacher_input


def teacher_input_output_capture_fwd_hook(module: nn.Module, input: Any, output: Any):
"""A hook to capture layer input and output."""
# NOTE: Defined externally to allow pickling during DDP initialization.

if module._intermediate_output is not None:
# NOTE: cannot tell if train or eval since teacher is always eval
warnings.warn(
f"Teacher's Module `{type(module).__name__}` already has an intermediate output stored."
" This is expected when `DistillationModel.compute_kd_loss` is not called in eval mode."
)

module._intermediate_input = input
module._intermediate_output = output
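A minimal sketch of the injection mechanism these hooks rely on, using plain PyTorch and toy layers (all names here are placeholders): a forward pre-hook that returns a value replaces the layer's input, which is how the teacher's captured input gets fed to the student layer.

```python
import torch
import torch.nn as nn

teacher_layer = nn.Linear(4, 4)
student_layer = nn.Linear(4, 4)
stash = {}

def teacher_capture(module, args, output):
    stash["teacher_input"] = args  # stands in for _intermediate_input

def student_bypass(module, args):
    # Returning a value from a pre-hook overrides the layer's original input.
    return stash.pop("teacher_input", args)

teacher_layer.register_forward_hook(teacher_capture)
student_layer.register_forward_pre_hook(student_bypass)

x = torch.randn(2, 4)
teacher_layer(x)                        # stashes the teacher layer's input
out = student_layer(torch.zeros(2, 4))  # actually runs on the teacher's input
expected = nn.functional.linear(x, student_layer.weight, student_layer.bias)
assert torch.allclose(out, expected)
```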
48 changes: 32 additions & 16 deletions modelopt/torch/distill/mode.py
@@ -21,23 +21,21 @@
import warnings

import torch.nn as nn
from torch.nn.modules.loss import _Loss as Loss

from modelopt.torch.opt.config import ModeloptBaseConfig
from modelopt.torch.opt.conversion import ModeloptStateManager
from modelopt.torch.opt.mode import (
ConvertEntrypoint,
ConvertReturnType,
MetadataDict,
ModeDescriptor,
RestoreEntrypoint,
UpdateEntrypoint,
_ModeRegistryCls,
)
from modelopt.torch.utils import init_model_from_model_like, unwrap_model

from .config import ExportStudentConfig, KDLossConfig
from .config import ExportStudentConfig, KDLossConfig, LayerwiseKDConfig
from .distillation_model import DistillationModel
from .layerwise_distillation_model import LayerwiseDistillationModel
from .registry import DistillationDMRegistry

DistillModeRegistry = _ModeRegistryCls("distill")
@@ -75,17 +73,35 @@ def restore(self) -> RestoreEntrypoint:
"""The mode's entrypoint for restoring a model."""
raise NotImplementedError(f"{self.name} mode does not support restore.")

@property
def update_for_new_mode(self) -> UpdateEntrypoint:
"""The mode's entrypoint for updating the models state for adding new mode."""
return _reset_kd_state_config

@property
def save_mode_in_state(self) -> bool:
"""Whether the mode should be saved into the modelopt state."""
return False


@DistillModeRegistry.register_mode
class LayerwiseKDModeDescriptor(KnowledgeDistillationModeDescriptor):
"""Class to describe the Layerwise Knowledge-Distillation mode.

The properties of this mode can be inspected via the source code.
"""

@property
def name(self) -> str:
"""Returns the value (str representation) of the mode."""
return "layerwise_kd"

@property
def config_class(self) -> type[ModeloptBaseConfig]:
"""Specifies the config class for the mode."""
return LayerwiseKDConfig

@property
def convert(self) -> ConvertEntrypoint:
"""The mode's entrypoint for converting a model."""
return _convert_for_layerwise


@DistillModeRegistry.register_mode
class ExportStudentModeDescriptor(ModeDescriptor):
"""Class to describe the specific Export mode to be used with Knowledge Distillation.
@@ -124,7 +140,9 @@ def save_mode_in_state(self) -> bool:
return False


def _convert_for_kd(model: nn.Module, config: KDLossConfig) -> ConvertReturnType:
def _convert_for_kd(
model: nn.Module, config: KDLossConfig, model_cls: type[nn.Module] = DistillationModel
) -> ConvertReturnType:
"""Function for converting a model to a distillation meta-model.

This is the only utility needed to use the ``modelopt.torch.distill`` API directly.
Expand Down Expand Up @@ -159,7 +177,7 @@ def _convert_for_kd(model: nn.Module, config: KDLossConfig) -> ConvertReturnType
# initialize distillation model
original_cls = type(student)
if original_cls not in DistillationDMRegistry:
DistillationDMRegistry.register({original_cls: "student_class"})(DistillationModel)
DistillationDMRegistry.register({original_cls: "student_class"})(model_cls)
# TODO (lucasl): look into ways to avoid registering every class manually
# (e.g. by just registering nn.Module and disable the "forward" check for the inherited class check

Expand All @@ -174,11 +192,9 @@ def _convert_for_kd(model: nn.Module, config: KDLossConfig) -> ConvertReturnType
return distillation_model, metadata


def _reset_kd_state_config(model: nn.Module, config: KDLossConfig, metadata: MetadataDict):
"""Function for resetting the state's config."""
config.teacher_model = nn.Module
config.criterion = Loss()
config.loss_balancer = None
def _convert_for_layerwise(model: nn.Module, config: LayerwiseKDConfig) -> ConvertReturnType:
"""Function for converting a model to a layerwise distillation meta-model."""
return _convert_for_kd(model, config, model_cls=LayerwiseDistillationModel)


def _export_student(model: nn.Module, config: ExportStudentConfig) -> ConvertReturnType:
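End to end, the new mode is meant to be selected like the existing `kd_loss` mode. A hypothetical usage sketch, assuming the standard `mtd.convert` entrypoint and using toy models and layer names that do not appear in this PR:

```python
import torch.nn as nn
import modelopt.torch.distill as mtd

teacher = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))
student = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))  # e.g. block "1" swapped out

config = {
    "teacher_model": teacher,
    "criterion": {("1", "1"): nn.MSELoss()},  # (student layer name, teacher layer name)
    "loss_balancer": None,
}
model = mtd.convert(student, mode=[("layerwise_kd", config)])
# Training then runs model(...) and model.compute_kd_loss(), as with the kd_loss mode.
```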