33 changes: 9 additions & 24 deletions compressai/entropy_models/entropy_models.py
@@ -54,9 +54,7 @@ def __init__(self, method):

         if method not in available_entropy_coders():
             methods = ", ".join(available_entropy_coders())
-            raise ValueError(
-                f'Unknown entropy coder "{method}"' f" (available: {methods})"
-            )
+            raise ValueError(f'Unknown entropy coder "{method}" (available: {methods})')

         if method == "ans":
             from compressai import ans
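Note: the removed raise spread the message across two adjacent f-string literals, which Python concatenates at compile time; the single f-string on the added line produces the exact same message. A minimal standalone check, with hypothetical coder names used purely for illustration:

    # Hypothetical values; only meant to show the two formattings are equivalent.
    method = "foo"
    methods = "ans, rangecoder"

    old_msg = f'Unknown entropy coder "{method}"' f" (available: {methods})"
    new_msg = f'Unknown entropy coder "{method}" (available: {methods})'

    assert old_msg == new_msg  # adjacent string literals concatenate into one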
@@ -474,28 +472,18 @@ def forward(
         if training is None:
             training = self.training

-        if not torch.jit.is_scripting():
-            # x from B x C x ... to C x B x ...
-            perm = torch.cat(
-                (
-                    torch.tensor([1, 0], dtype=torch.long, device=x.device),
-                    torch.arange(2, x.ndim, dtype=torch.long, device=x.device),
-                )
-            )
-            inv_perm = perm
-        else:
-            raise NotImplementedError()
-            # TorchScript in 2D for static inference
-            # Convert to (channels, ... , batch) format
-            # perm = (1, 2, 3, 0)
-            # inv_perm = (3, 0, 1, 2)
+        D = x.dim()
+        # B C ... -> C B ...
+        perm = [1, 0] + list(range(2, D))
+        inv_perm = [0] * D
+        for i, p in enumerate(perm):
+            inv_perm[p] = i

         x = x.permute(*perm).contiguous()
         shape = x.size()
         values = x.reshape(x.size(0), 1, -1)

-        # Add noise or quantize

         outputs = self.quantize(
             values, "noise" if training else "dequantize", self._get_medians()
         )
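For context on the unchanged quantize call: during training it adds uniform noise as a differentiable relaxation of rounding, and at inference it rounds around the learned medians. A rough standalone sketch of that behaviour, using a simplified helper rather than the library's exact implementation:

    import torch

    def quantize_sketch(values, mode, medians):
        # "noise": additive U(-0.5, 0.5) noise, a differentiable stand-in for
        # rounding while training.
        if mode == "noise":
            noise = torch.empty_like(values).uniform_(-0.5, 0.5)
            return values + noise
        # "dequantize": round the median-centred values, then shift back.
        if mode == "dequantize":
            return torch.round(values - medians) + medians
        raise ValueError(f'invalid mode: "{mode}"')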
@@ -510,11 +498,8 @@
             # likelihood = torch.zeros_like(outputs)

         # Convert back to input tensor shape
-        outputs = outputs.reshape(shape)
-        outputs = outputs.permute(*inv_perm).contiguous()
-
-        likelihood = likelihood.reshape(shape)
-        likelihood = likelihood.permute(*inv_perm).contiguous()
+        outputs = outputs.reshape(shape).permute(*inv_perm).contiguous()
+        likelihood = likelihood.reshape(shape).permute(*inv_perm).contiguous()

         return outputs, likelihood
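The added pure-Python permutation avoids building index tensors and the TorchScript special case: perm moves the channel dimension first, and the loop computes its inverse so the last hunk can restore the original layout with one chained reshape(...).permute(...). A minimal standalone sketch of the round trip, with a hypothetical 4-D input shape chosen only for illustration:

    import torch

    x = torch.randn(2, 3, 4, 5)  # B x C x H x W

    D = x.dim()
    perm = [1, 0] + list(range(2, D))   # B C H W -> C B H W, i.e. [1, 0, 2, 3]
    inv_perm = [0] * D
    for i, p in enumerate(perm):
        inv_perm[p] = i                 # invert the permutation

    # Swapping dims 0 and 1 is its own inverse, which is why the removed code
    # could set inv_perm = perm; the loop handles the general case explicitly.
    assert inv_perm == perm

    y = x.permute(*perm).contiguous()   # C x B x H x W
    assert y.shape == (3, 2, 4, 5)

    # The forward pass then flattens everything except channels, one row per
    # channel, before quantization.
    values = y.reshape(y.size(0), 1, -1)
    assert values.shape == (3, 1, 2 * 4 * 5)

    # Round trip back to the input layout, as in the last hunk.
    out = values.reshape(y.shape).permute(*inv_perm).contiguous()
    assert torch.equal(out, x)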
