Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 4 additions & 18 deletions torchTextClassifiers/model/components/classification_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ def __init__(
"""
super().__init__()
if net is not None:
self.net = net

# --- Custom net should either be a Sequential or a Linear ---
if not (isinstance(net, nn.Sequential) or isinstance(net, nn.Linear)):
raise ValueError("net must be an nn.Sequential when provided.")
Expand All @@ -43,7 +45,6 @@ def __init__(
# --- Extract features ---
self.input_dim = first.in_features
self.num_classes = last.out_features
self.net = net
else: # if not Sequential, it is a Linear
self.input_dim = net.in_features
self.num_classes = net.out_features
Expand All @@ -53,23 +54,8 @@ def __init__(
input_dim is not None and num_classes is not None
), "Either net or both input_dim and num_classes must be provided."
self.net = nn.Linear(input_dim, num_classes)
self.input_dim, self.num_classes = self._get_linear_input_output_dims(self.net)
self.input_dim = input_dim
self.num_classes = num_classes

def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Run the classification head on a batch of input features.

    Args:
        x: Input feature tensor.

    Returns:
        The logits produced by the underlying network ``self.net``.
    """
    head = self.net
    return head(x)

@staticmethod
def _get_linear_input_output_dims(module: nn.Module):
"""
Returns (input_dim, output_dim) for any module containing Linear layers.
Works for Linear, Sequential, or nested models.
"""
# Collect all Linear layers recursively
linears = [m for m in module.modules() if isinstance(m, nn.Linear)]

if not linears:
raise ValueError("No Linear layers found in the given module.")

input_dim = linears[0].in_features
output_dim = linears[-1].out_features
return input_dim, output_dim
Loading