-
Notifications
You must be signed in to change notification settings - Fork 1
Include WGAN-GP Model and Losses #23
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
MattsonCam
wants to merge
8
commits into
main
Choose a base branch
from
include_wgan_gp
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
ea521c5
docs(wgan_gp): improve GlobalDiscriminator documentation and docstrings
2c39ad9
refactor(wgan_gp): improve GlobalDiscriminator docs and configurability
0efd09d
Renamed the folder of the global discriminator
7b45875
Add the cross zamirski generator loss
6010d53
Add the wgan-gp loss for the discrimintator
6a697e1
Update file pathing to include the wasserstein loss components
6b69d04
docs(losses): align Wasserstein and L1 loss docstrings with metric style
26245fe
Moved unet to it's own model folder
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,57 @@ | ||
| import torch | ||
| from torch import nn | ||
|
|
||
|
|
||
| class WassersteinGeneratorCrossZamirskiLoss(nn.Module): | ||
| """Generator loss combining L1 reconstruction and Wasserstein term.""" | ||
|
|
||
| def __init__(self, reconstruction_importance: float = 100.0) -> None: | ||
| """Configure weighting for the reconstruction component. | ||
|
|
||
| Args: | ||
| reconstruction_importance: Multiplier applied to mean L1 reconstruction loss. | ||
| """ | ||
|
|
||
| super().__init__() | ||
| self.reconstruction_importance = reconstruction_importance | ||
|
|
||
| def forward( | ||
| self, | ||
| fake_classification_outputs: torch.Tensor, | ||
| generated_predictions: torch.Tensor, | ||
| targets: torch.Tensor, | ||
| epoch: int = 0, | ||
| **kwargs, | ||
| ) -> torch.Tensor: | ||
| """Compute Zamirski-style generator objective for one batch. | ||
|
|
||
| Args: | ||
| fake_classification_outputs: Critic outputs for generated samples. | ||
| generated_predictions: Generator predictions. | ||
| targets: Ground-truth targets with matching shape. | ||
| epoch: Zero-based epoch index used to down-weight adversarial term over time. | ||
| **kwargs: Additional unused loss arguments. | ||
|
|
||
| Returns: | ||
| Scalar generator loss equal to | ||
| ``reconstruction_importance * L1(generated_predictions, targets) - mean(fake_classification_outputs) / (epoch + 1)``. | ||
|
|
||
| Raises: | ||
| ValueError: If prediction and target shapes differ. | ||
| ValueError: If critic output batch size does not match predictions batch size. | ||
| """ | ||
|
|
||
| if generated_predictions.shape != targets.shape: | ||
| raise ValueError("generated_predictions and targets must have the same shape.") | ||
|
|
||
| batch_size = generated_predictions.size(0) | ||
| if fake_classification_outputs.size(0) != batch_size: | ||
| raise ValueError( | ||
| "fake_classification_outputs batch size must match generated_predictions." | ||
| ) | ||
|
|
||
| reconstruction_loss = torch.nn.functional.l1_loss( | ||
| generated_predictions, targets, reduction="mean" | ||
| ) | ||
| adversarial_term = torch.mean(fake_classification_outputs) / (epoch + 1) | ||
| return self.reconstruction_importance * reconstruction_loss - adversarial_term |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,54 @@ | ||
| import torch | ||
| from torch import nn | ||
|
|
||
|
|
||
| class WassersteinGradientPenaltyLoss(nn.Module): | ||
| """WGAN-GP loss wrapper with trainer-compatible call signature.""" | ||
|
|
||
| def __init__(self, gradient_penalty_importance: float = 10.0) -> None: | ||
| """Configure weighting for the gradient penalty term. | ||
|
|
||
| Args: | ||
| gradient_penalty_importance: Multiplier applied to gradient penalty. | ||
| """ | ||
|
|
||
| super().__init__() | ||
| self.gradient_penalty_importance = gradient_penalty_importance | ||
|
|
||
| def forward( | ||
| self, | ||
| gradients: torch.Tensor, | ||
| real_classification_outputs: torch.Tensor, | ||
| fake_classification_outputs: torch.Tensor, | ||
| **kwargs, | ||
| ) -> torch.Tensor: | ||
| """Compute critic loss with Wasserstein distance and gradient penalty. | ||
|
|
||
| Args: | ||
| gradients: Gradients of critic outputs w.r.t. interpolated inputs. | ||
| real_classification_outputs: Critic outputs for real samples. | ||
| fake_classification_outputs: Critic outputs for generated samples. | ||
| **kwargs: Additional unused loss arguments. | ||
|
|
||
| Returns: | ||
| Scalar critic loss equal to | ||
| ``mean(fake_classification_outputs) - mean(real_classification_outputs) + gradient_penalty_importance * penalty``. | ||
|
|
||
| Raises: | ||
| ValueError: If real critic output batch size does not match gradients batch size. | ||
| ValueError: If fake critic output batch size does not match gradients batch size. | ||
| """ | ||
|
|
||
| batch_size = gradients.size(0) | ||
| if real_classification_outputs.size(0) != batch_size: | ||
| raise ValueError("real_classification_outputs batch size must match gradients.") | ||
| if fake_classification_outputs.size(0) != batch_size: | ||
| raise ValueError("fake_classification_outputs batch size must match gradients.") | ||
|
|
||
| gradients = gradients.view(batch_size, -1) | ||
| gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() | ||
| return ( | ||
| torch.mean(fake_classification_outputs) | ||
| - torch.mean(real_classification_outputs) | ||
| + gradient_penalty * self.gradient_penalty_importance | ||
| ) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,3 @@ | ||
| from .L1Loss import L1Loss | ||
| from .WassersteinGeneratorCrossZamirskiLoss import WassersteinGeneratorCrossZamirskiLoss | ||
| from .WassersteinGradientPenaltyLoss import WassersteinGradientPenaltyLoss |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,99 @@ | ||
| """Global image critic used for adversarial training. | ||
|
|
||
| This module defines a compact convolutional discriminator (critic) suitable for | ||
| Wasserstein-style GAN training (for example WGAN-GP). The network downsamples | ||
| the full input image through strided convolutions, aggregates global context | ||
| with adaptive average pooling, and produces one unconstrained scalar score per | ||
| sample. | ||
|
|
||
| Notes: | ||
| - The output is a critic score, not a probability. | ||
| - No sigmoid activation is applied at the end. | ||
| - Inputs are expected in ``(N, C, H, W)`` format. | ||
| """ | ||
|
|
||
| import torch | ||
| from torch import nn | ||
|
|
||
| class GlobalDiscriminator(nn.Module): | ||
| """Convolutional global critic for image-level real/fake scoring. | ||
|
|
||
| Architecture: | ||
| 1. Four ``Conv2d + LeakyReLU`` blocks with stride 2 for progressive | ||
| spatial downsampling. | ||
| 2. ``AdaptiveAvgPool2d(1)`` to collect global features independent of | ||
| input spatial size. | ||
| 3. Linear projection to a single scalar critic score per image. | ||
|
|
||
| Args: | ||
| in_channels: Number of channels in the input image tensor. | ||
| base_channels: Number of feature channels in the first convolutional | ||
| block. Later blocks scale this as ``x2``, ``x4``, and ``x8``. | ||
| num_blocks: Number of strided convolutional blocks used for | ||
| downsampling. | ||
| max_channels: Optional upper bound for feature channel width in deeper | ||
| blocks. If ``None``, channels are uncapped. | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| in_channels: int = 1, | ||
| base_channels: int = 64, | ||
| num_blocks: int = 4, | ||
| max_channels: int | None = None, | ||
| ): | ||
| """Initialize the global discriminator network. | ||
|
|
||
| Args: | ||
| in_channels: Number of channels in each input image. | ||
| base_channels: Channel width used by the first convolutional block. | ||
| num_blocks: Number of stride-2 convolutional feature blocks. | ||
| max_channels: Optional cap on block output channels. | ||
| """ | ||
| super().__init__() | ||
|
|
||
| if num_blocks < 1: | ||
| raise ValueError(f"Expected num_blocks >= 1, got {num_blocks}.") | ||
| if base_channels < 1: | ||
| raise ValueError(f"Expected base_channels >= 1, got {base_channels}.") | ||
| if max_channels is not None and max_channels < 1: | ||
| raise ValueError(f"Expected max_channels >= 1, got {max_channels}.") | ||
|
|
||
| feature_blocks: list[nn.Module] = [] | ||
| in_ch = in_channels | ||
| out_ch = base_channels | ||
|
|
||
| for block_idx in range(num_blocks): | ||
| if block_idx > 0: | ||
| out_ch = base_channels * (2 ** block_idx) | ||
| if max_channels is not None: | ||
| out_ch = min(out_ch, max_channels) | ||
|
|
||
| feature_blocks.extend( | ||
| [ | ||
| nn.Conv2d(in_ch, out_ch, 4, 2, 1), | ||
| nn.LeakyReLU(0.2), | ||
| ] | ||
| ) | ||
| in_ch = out_ch | ||
|
|
||
| self.features = nn.Sequential(*feature_blocks) | ||
|
|
||
| self.classifier = nn.Sequential( | ||
| nn.AdaptiveAvgPool2d(1), | ||
| nn.Flatten(), | ||
| nn.Linear(in_ch, 1), | ||
| ) | ||
|
|
||
| def forward(self, x): | ||
| """Compute critic scores for a batch of images. | ||
|
|
||
| Args: | ||
| x: Input tensor of shape ``(batch, in_channels, height, width)``. | ||
|
|
||
| Returns: | ||
| Tensor of shape ``(batch,)`` with unconstrained critic scores. | ||
| Larger values indicate samples judged as more real by the critic. | ||
| """ | ||
|
|
||
| return self.classifier(self.features(x)).squeeze(-1) |
File renamed without changes.
File renamed without changes.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just to double check if your intention here is to have the interpolation step in training orchestration, outside of loss module.