
Help with training scripts #4

Description

@BOR54

Thanks for the snippets. They're really helpful. Also, I just checked your proposed timeline, and it seems you might have finished revising the code. Are you able to share the updated training script? If not, could you share the training script as it is, with the older Lightning version?

Here is my current script. I'm currently unable to replicate the paper's findings. I am using your configuration file and am modifying the loss functions to use your snippets.

To provide the additional information you requested, I have uploaded the project files to the latent space inpainting repo on my profile.

To clarify my training conditions, I am training on an optical coherence tomography image dataset (standardized to 256x256 resolution). I just finished acquiring the ImageNet dataset, which I will use to replicate your implementation going forward.

Based on your feedback and a deep dive into the code, I realized the default script parameters had several mismatches. Here is the current setup I am using for the LiteVAE with the U-Net discriminator after the latest corrections:

  • Model Architecture & Latent Space:
      • Global Batch Size: 8 (distributed across 2 H100 GPUs).
      • Learning Rate: $1.0 \times 10^{-4}$ with an Adam optimizer ($\beta_1=0.5$, $\beta_2=0.9$).
      • Latent Dim Alignment: I discovered a mismatch between the Encoder and Decoder. With use_quant: False, the embed_dim is set to 16, so I have now ensured the Encoder outputs 32 channels (for $\mu$ and $\sigma$) and the Decoder z_channels is set to 16 (see the first sketch after this list).
  • Loss Configuration & Weights: I have implemented a custom LiteVAEAuthorGANLoss module to match the paper's requirements (the second sketch after this list shows how the terms are combined):
      • Reconstruction: L1 loss (weight: 1.0).
      • KL Divergence: max weight $1.0 \times 10^{-6}$ with a 10,000-step annealing schedule to prevent posterior collapse.
      • Wavelet Loss: a Haar-based DWT loss (weight: 0.1) using the Charbonnier distance on the high-frequency sub-bands (LH, HL, HH).
      • Adversarial: hinge GAN loss (weight: 0.1) using the Unet_Discriminator. The discriminator starts after 20,000 steps (disc_start) so the reconstruction can stabilize first.
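
For the latent alignment specifically, here is a tiny sanity check of what I mean by the 32-channel encoder output. The shapes are illustrative only (the 32x32 spatial size assumes an f=8 downsampling of a 256x256 input), and the names are placeholders rather than the repo's actual classes:

```python
import torch

embed_dim = 16

# Encoder output: mu and log-variance stacked along the channel axis -> 2 * embed_dim channels.
moments = torch.randn(8, 2 * embed_dim, 32, 32)
mu, logvar = torch.chunk(moments, 2, dim=1)              # each (8, 16, 32, 32)

# Reparameterized latent fed to the decoder; z_channels must therefore equal embed_dim.
z = mu + torch.exp(0.5 * logvar) * torch.randn_like(mu)
assert z.shape[1] == embed_dim
```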
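
And here is a minimal sketch of how I am wiring the loss terms together. The class name matches my LiteVAEAuthorGANLoss module, but the helpers (haar_dwt, charbonnier) and the forward signature are simplified placeholders, not the code I actually run:

```python
import torch
import torch.nn.functional as F


def haar_dwt(x):
    """Single-level Haar DWT; returns (LL, LH, HL, HH) sub-bands at half resolution."""
    a = x[:, :, 0::2, 0::2]
    b = x[:, :, 0::2, 1::2]
    c = x[:, :, 1::2, 0::2]
    d = x[:, :, 1::2, 1::2]
    return (a + b + c + d) / 2, (a + b - c - d) / 2, (a - b + c - d) / 2, (a - b - c + d) / 2


def charbonnier(x, y, eps=1e-6):
    """Charbonnier (smooth L1) distance."""
    return torch.sqrt((x - y) ** 2 + eps ** 2).mean()


class LiteVAEAuthorGANLoss(torch.nn.Module):
    def __init__(self, kl_max=1e-6, kl_anneal_steps=10_000,
                 wavelet_weight=0.1, disc_weight=0.1, disc_start=20_000):
        super().__init__()
        self.kl_max = kl_max
        self.kl_anneal_steps = kl_anneal_steps
        self.wavelet_weight = wavelet_weight
        self.disc_weight = disc_weight
        self.disc_start = disc_start

    def kl_weight(self, global_step):
        # Linear anneal from 0 up to kl_max over the first kl_anneal_steps steps.
        return self.kl_max * min(1.0, global_step / self.kl_anneal_steps)

    def forward(self, inputs, reconstructions, posterior_kl, logits_fake, global_step):
        # Reconstruction: L1, weight 1.0.
        rec_loss = F.l1_loss(reconstructions, inputs)

        # KL divergence with the annealed weight.
        kl_loss = self.kl_weight(global_step) * posterior_kl.mean()

        # Charbonnier distance on the high-frequency Haar sub-bands only (LH, HL, HH).
        _, lh_i, hl_i, hh_i = haar_dwt(inputs)
        _, lh_r, hl_r, hh_r = haar_dwt(reconstructions)
        wav_loss = self.wavelet_weight * (charbonnier(lh_r, lh_i)
                                          + charbonnier(hl_r, hl_i)
                                          + charbonnier(hh_r, hh_i))

        # Generator hinge loss, gated off until disc_start so reconstruction stabilizes first.
        disc_factor = self.disc_weight if global_step >= self.disc_start else 0.0
        g_loss = disc_factor * (-logits_fake.mean())

        return rec_loss + kl_loss + wav_loss + g_loss
```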

Implementation Fixes: I encountered and resolved an IndexError in the wavelet computation, where the DWT coefficients were being indexed incorrectly, and I made sure the last_layer gradient adaptive weighting is handled correctly in the forward pass (rough sketch below).
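
For reference, the last_layer adaptive weighting I ended up with follows the usual VQGAN-style gradient balancing at the decoder's final layer. This is a rough sketch of the idea with illustrative default values, not the exact code in my script:

```python
import torch

def calculate_adaptive_weight(nll_loss, g_loss, last_layer, disc_weight=0.1):
    """Scale the GAN term so its gradient norm at the decoder's last layer
    matches the reconstruction loss gradient norm."""
    nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
    g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
    d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
    d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
    return d_weight * disc_weight
```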
