Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion losses/L1Loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,19 @@ def forward(
targets: torch.Tensor,
**kwargs,
) -> torch.Tensor:
"""Return mean L1 loss for backpropagation."""
"""Compute mean L1 training loss for one batch.

Args:
generated_predictions: Model predictions.
targets: Ground-truth targets with matching shape.
**kwargs: Additional unused loss arguments.

Returns:
Scalar mean absolute error used for optimization.

Raises:
ValueError: If prediction and target shapes differ.
"""

if generated_predictions.shape != targets.shape:
raise ValueError("The generated predictions and targets must be the same shape.")
Expand Down
57 changes: 57 additions & 0 deletions losses/WassersteinGeneratorCrossZamirskiLoss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import torch
from torch import nn


class WassersteinGeneratorCrossZamirskiLoss(nn.Module):
"""Generator loss combining L1 reconstruction and Wasserstein term."""

def __init__(self, reconstruction_importance: float = 100.0) -> None:
"""Configure weighting for the reconstruction component.

Args:
reconstruction_importance: Multiplier applied to mean L1 reconstruction loss.
"""

super().__init__()
self.reconstruction_importance = reconstruction_importance

def forward(
self,
fake_classification_outputs: torch.Tensor,
generated_predictions: torch.Tensor,
targets: torch.Tensor,
epoch: int = 0,
**kwargs,
) -> torch.Tensor:
"""Compute Zamirski-style generator objective for one batch.

Args:
fake_classification_outputs: Critic outputs for generated samples.
generated_predictions: Generator predictions.
targets: Ground-truth targets with matching shape.
epoch: Zero-based epoch index used to down-weight adversarial term over time.
**kwargs: Additional unused loss arguments.

Returns:
Scalar generator loss equal to
``reconstruction_importance * L1(generated_predictions, targets) - mean(fake_classification_outputs) / (epoch + 1)``.

Raises:
ValueError: If prediction and target shapes differ.
ValueError: If critic output batch size does not match predictions batch size.
"""

if generated_predictions.shape != targets.shape:
raise ValueError("generated_predictions and targets must have the same shape.")

batch_size = generated_predictions.size(0)
if fake_classification_outputs.size(0) != batch_size:
raise ValueError(
"fake_classification_outputs batch size must match generated_predictions."
)

reconstruction_loss = torch.nn.functional.l1_loss(
generated_predictions, targets, reduction="mean"
)
adversarial_term = torch.mean(fake_classification_outputs) / (epoch + 1)
return self.reconstruction_importance * reconstruction_loss - adversarial_term
54 changes: 54 additions & 0 deletions losses/WassersteinGradientPenaltyLoss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import torch
from torch import nn


class WassersteinGradientPenaltyLoss(nn.Module):
"""WGAN-GP loss wrapper with trainer-compatible call signature."""

def __init__(self, gradient_penalty_importance: float = 10.0) -> None:
"""Configure weighting for the gradient penalty term.

Args:
gradient_penalty_importance: Multiplier applied to gradient penalty.
"""

super().__init__()
self.gradient_penalty_importance = gradient_penalty_importance

def forward(
self,
gradients: torch.Tensor,
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to double check if your intention here is to have the interpolation step in training orchestration, outside of loss module.

real_classification_outputs: torch.Tensor,
fake_classification_outputs: torch.Tensor,
**kwargs,
) -> torch.Tensor:
"""Compute critic loss with Wasserstein distance and gradient penalty.

Args:
gradients: Gradients of critic outputs w.r.t. interpolated inputs.
real_classification_outputs: Critic outputs for real samples.
fake_classification_outputs: Critic outputs for generated samples.
**kwargs: Additional unused loss arguments.

Returns:
Scalar critic loss equal to
``mean(fake_classification_outputs) - mean(real_classification_outputs) + gradient_penalty_importance * penalty``.

Raises:
ValueError: If real critic output batch size does not match gradients batch size.
ValueError: If fake critic output batch size does not match gradients batch size.
"""

batch_size = gradients.size(0)
if real_classification_outputs.size(0) != batch_size:
raise ValueError("real_classification_outputs batch size must match gradients.")
if fake_classification_outputs.size(0) != batch_size:
raise ValueError("fake_classification_outputs batch size must match gradients.")

gradients = gradients.view(batch_size, -1)
gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
return (
torch.mean(fake_classification_outputs)
- torch.mean(real_classification_outputs)
+ gradient_penalty * self.gradient_penalty_importance
)
3 changes: 3 additions & 0 deletions losses/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .L1Loss import L1Loss
from .WassersteinGeneratorCrossZamirskiLoss import WassersteinGeneratorCrossZamirskiLoss
from .WassersteinGradientPenaltyLoss import WassersteinGradientPenaltyLoss
99 changes: 99 additions & 0 deletions models/global_discriminator/global_discriminator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
"""Global image critic used for adversarial training.

This module defines a compact convolutional discriminator (critic) suitable for
Wasserstein-style GAN training (for example WGAN-GP). The network downsamples
the full input image through strided convolutions, aggregates global context
with adaptive average pooling, and produces one unconstrained scalar score per
sample.

Notes:
- The output is a critic score, not a probability.
- No sigmoid activation is applied at the end.
- Inputs are expected in ``(N, C, H, W)`` format.
"""

import torch
from torch import nn

class GlobalDiscriminator(nn.Module):
"""Convolutional global critic for image-level real/fake scoring.

Architecture:
1. Four ``Conv2d + LeakyReLU`` blocks with stride 2 for progressive
spatial downsampling.
2. ``AdaptiveAvgPool2d(1)`` to collect global features independent of
input spatial size.
3. Linear projection to a single scalar critic score per image.

Args:
in_channels: Number of channels in the input image tensor.
base_channels: Number of feature channels in the first convolutional
block. Later blocks scale this as ``x2``, ``x4``, and ``x8``.
num_blocks: Number of strided convolutional blocks used for
downsampling.
max_channels: Optional upper bound for feature channel width in deeper
blocks. If ``None``, channels are uncapped.
"""

def __init__(
self,
in_channels: int = 1,
base_channels: int = 64,
num_blocks: int = 4,
max_channels: int | None = None,
):
"""Initialize the global discriminator network.

Args:
in_channels: Number of channels in each input image.
base_channels: Channel width used by the first convolutional block.
num_blocks: Number of stride-2 convolutional feature blocks.
max_channels: Optional cap on block output channels.
"""
super().__init__()

if num_blocks < 1:
raise ValueError(f"Expected num_blocks >= 1, got {num_blocks}.")
if base_channels < 1:
raise ValueError(f"Expected base_channels >= 1, got {base_channels}.")
if max_channels is not None and max_channels < 1:
raise ValueError(f"Expected max_channels >= 1, got {max_channels}.")

feature_blocks: list[nn.Module] = []
in_ch = in_channels
out_ch = base_channels

for block_idx in range(num_blocks):
if block_idx > 0:
out_ch = base_channels * (2 ** block_idx)
if max_channels is not None:
out_ch = min(out_ch, max_channels)

feature_blocks.extend(
[
nn.Conv2d(in_ch, out_ch, 4, 2, 1),
nn.LeakyReLU(0.2),
]
)
in_ch = out_ch

self.features = nn.Sequential(*feature_blocks)

self.classifier = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Flatten(),
nn.Linear(in_ch, 1),
)

def forward(self, x):
"""Compute critic scores for a batch of images.

Args:
x: Input tensor of shape ``(batch, in_channels, height, width)``.

Returns:
Tensor of shape ``(batch,)`` with unconstrained critic scores.
Larger values indicate samples judged as more real by the critic.
"""

return self.classifier(self.features(x)).squeeze(-1)
File renamed without changes.
File renamed without changes.