How to .update() a GaussianConditional entropy model and properly call .compress() and .decompress()? #337

@johnli25

Hello! I'm trying to train my own custom autoencoder model while integrating EntropyBottleneck and GaussianConditional. Here's a snippet of my class:

import torch.nn as nn
from compressai.entropy_models import EntropyBottleneck, GaussianConditional

class AEWithEntropy(nn.Module):
    def __init__(self, freeze_base=True):
        super().__init__()
        self.base = PNC32()  # original autoencoder

        C = self.base.encoder2.out_channels  # 32
        # 1×1 convs for the hyperprior (no need to modify PNC32)
        self.hyper_encoder = nn.Conv2d(C, C, kernel_size=1)
        self.hyper_decoder = nn.Conv2d(C, C, kernel_size=1)

        self.entropy_z = EntropyBottleneck(C)
        self.gauss_y = GaussianConditional(None)

        if freeze_base:
            for p in self.base.parameters():
                p.requires_grad = False

    def forward(self, x, tail_length=None, quantize_level=0):
        y = self.base.encode(x)  # latent features from the original encoder
        z = self.hyper_encoder(y)  # side information about y (e.g. its scale)
        z_q, z_lh = self.entropy_z(z)  # quantized z and its likelihoods
        sigma = self.hyper_decoder(z_q)  # predicted scale of the Gaussian model for y
        y_q, y_lh = self.gauss_y(y, sigma)  # quantized y (fed to the decoder) and its likelihoods
        recon = self.base.decode(y_q)
        return recon, y_lh, z_lh
    
    def compress(self, x):
        y = self.base.encode(x)
        z = self.hyper_encoder(y)
        z_bytes = self.entropy_z.compress(z)
        z_q = self.entropy_z.decompress(z_bytes)
        sigma = self.hyper_decoder(z_q)
        y_bytes = self.gauss_y.compress(y, sigma)
        return {"z": z_bytes, "y": y_bytes}

    def decompress(self, streams):
        z_q  = self.entropy_z.decompress(streams["z"])
        sigma = self.hyper_decoder(z_q)
        y_q  = self.gauss_y.decompress(streams["y"], sigma)
        return self.base.decode(y_q)
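
A note on the compress() and decompress() methods above: in CompressAI's own hyperprior models, EntropyBottleneck.decompress() is given the spatial size of z in addition to the byte strings, and GaussianConditional.compress()/decompress() take integer indexes built from the predicted scales with build_indexes() rather than the raw sigma tensor. The sketch below rewrites the two methods as free functions under that assumption. The names compress_sketch, decompress_sketch, total_bytes and the "shape" key are hypothetical, and both entropy models are assumed to have been updated beforehand.

```python
def compress_sketch(model, x):
    """Encode x into byte strings; `model` is assumed to be an AEWithEntropy."""
    y = model.base.encode(x)
    z = model.hyper_encoder(y)
    z_bytes = model.entropy_z.compress(z)
    z_shape = z.size()[-2:]  # spatial size of z, needed again when decoding
    z_q = model.entropy_z.decompress(z_bytes, z_shape)
    sigma = model.hyper_decoder(z_q)
    # Map each predicted scale to an entry of the scale table.
    indexes = model.gauss_y.build_indexes(sigma)
    y_bytes = model.gauss_y.compress(y, indexes)
    return {"z": z_bytes, "y": y_bytes, "shape": z_shape}


def decompress_sketch(model, streams):
    """Invert compress_sketch using only the stored strings and shape."""
    z_q = model.entropy_z.decompress(streams["z"], streams["shape"])
    sigma = model.hyper_decoder(z_q)
    indexes = model.gauss_y.build_indexes(sigma)
    y_q = model.gauss_y.decompress(streams["y"], indexes)
    return model.base.decode(y_q)


def total_bytes(streams):
    """Sum the lengths of all byte strings (one string per batch element)."""
    return sum(len(s) for key in ("z", "y") for s in streams[key])
```

total_bytes(compress_sketch(model, x)) would then give the payload size for a batch, not counting the few bytes needed to store the shape.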

The model trains just fine, but during evaluation I want to run the compress() and decompress() methods and print the total number of bytes my model encodes images into. I'm aware that I need to call .update() first, and I do so successfully for the EntropyBottleneck via model.entropy_z.update(force=True), but I can't get the equivalent to work for GaussianConditional. It looks like I need to do something with the CDF/scale tables, but I'm stuck here. Here's the full error/output log:

Traceback (most recent call last):
  File "autoencoder_train.py", line 499, in <module>
    final_test_loss, final_psnr, final_ssim = eval_autoencoder(model=model, dataloader=test_loader, criterion=criterion, device=device, max_tail_length=drops, quantize=args.quantize)
  File "autoencoder_train.py", line 255, in eval_autoencoder
    outputs, _, _ = model(x=inputs, tail_length=max_tail_length, quantize_level=quantize)
  File "/mnt/data/envs/grace-test/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/mnt/data/envs/grace-test/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py", line 171, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "/mnt/data/envs/grace-test/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py", line 181, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/mnt/data/envs/grace-test/lib/python3.8/site-packages/torch/nn/parallel/parallel_apply.py", line 89, in parallel_apply
    output.reraise()
  File "/mnt/data/envs/grace-test/lib/python3.8/site-packages/torch/_utils.py", line 543, in reraise
    raise exception
ValueError: Caught ValueError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/mnt/data/envs/grace-test/lib/python3.8/site-packages/torch/nn/parallel/parallel_apply.py", line 64, in _worker
    output = module(*input, **kwargs)
  File "/mnt/data/envs/grace-test/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "autoencoder_train.py", line 40, in forward
    strings = self.entropy_z.compress(y) # .compress(y) quantizes/compresses the latent feature y, and returns a string of bytes
  File "/mnt/data/envs/grace-test/lib/python3.8/site-packages/compressai/entropy_models/entropy_models.py", line 541, in compress
    return super().compress(x, indexes, medians)
  File "/mnt/data/envs/grace-test/lib/python3.8/site-packages/compressai/entropy_models/entropy_models.py", line 254, in compress
    self._check_cdf_size()
  File "/mnt/data/envs/grace-test/lib/python3.8/site-packages/compressai/entropy_models/entropy_models.py", line 216, in _check_cdf_size
    raise ValueError("Uninitialized CDFs. Run update() first")
ValueError: Uninitialized CDFs. Run update() first

Any help/suggestions? Thank you!
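
One possible way to initialize both entropy models before evaluation, sketched under the assumption that GaussianConditional accepts a table of scales through update_scale_table(scale_table, force=...), as CompressAI's built-in hyperprior models do. The helper names get_scale_table and update_entropy_models are illustrative, and the table bounds (0.11 to 256, 64 levels) are borrowed from those models' defaults rather than tuned for this network.

```python
import math


def get_scale_table(min_scale=0.11, max_scale=256.0, levels=64):
    """Return `levels` log-spaced standard deviations in [min_scale, max_scale]."""
    log_min = math.log(min_scale)
    step = (math.log(max_scale) - log_min) / (levels - 1)
    return [math.exp(log_min + i * step) for i in range(levels)]


def update_entropy_models(model, force=True):
    """Build the CDF tables of both entropy models (assumed attribute names
    entropy_z and gauss_y, as in the AEWithEntropy class above)."""
    # EntropyBottleneck derives its CDFs from its own learned parameters.
    model.entropy_z.update(force=force)
    # GaussianConditional was built with scale_table=None, so it needs a table
    # of candidate scales before it can build one CDF per scale.
    model.gauss_y.update_scale_table(get_scale_table(), force=force)
```

Calling update_entropy_models(model) once after loading the trained weights and before the first compress() should populate the CDF tables. If the model is wrapped in nn.DataParallel, as the traceback suggests, pass model.module instead.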
