
Conversation


@studyingeugene studyingeugene commented Oct 22, 2025

I treated this as a small bug fix rather than a feature addition, so I submitted a PR directly without opening an issue.
Apologies if that's against the usual workflow — I'll be glad to open an issue if preferred.

What's changed

  • Replace the tensor-based permutation construction (torch.tensor, torch.arange, torch.cat) with a pure-Python list version
  • Add an explicit inverse permutation for correctness (inv_perm[p] = i)
  • Remove only the TorchScript-specific branch tied to the permutation logic (the is_scripting() guard used for perm/perm_inv construction)

Why

The previous implementation created small tensors on the device on every forward call, e.g.:

# Before
perm = torch.cat((torch.tensor([1, 0], device=x.device),
                  torch.arange(2, x.ndim, device=x.device)))
inv_perm = perm  # relies on this particular perm being self-inverse

# After
D = x.dim()
perm = [1, 0] + list(range(2, D))
inv_perm = [0] * D
for i, p in enumerate(perm):
    inv_perm[p] = i  # explicit inverse: position p maps back to axis i
    

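As a sanity check on the inversion loop, a minimal snippet (plain torch, with a hypothetical input shape) confirms that inv_perm undoes perm:

import torch

x = torch.randn(8, 3, 256, 256)
D = x.dim()
perm = [1, 0] + list(range(2, D))

inv_perm = [0] * D
for i, p in enumerate(perm):
    inv_perm[p] = i  # axis i moved to position p, so map p back to i

# Permuting and then un-permuting must restore the original tensor exactly.
assert torch.equal(x.permute(*perm).permute(*inv_perm), x)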
The old implementation causes:

  • Graph breaks in torch.compile() due to dynamic tensor creation
  • Extra CUDA allocations on each call

The new implementation improves:

  • Compile stability and graph caching under torch.compile (illustrated with a toy example below)
  • Runtime efficiency (no per-call tensor construction)
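
As an illustration — a toy function under torch.compile, not the library's actual EntropyBottleneck code — the list-based construction is plain Python control flow and traces into a single graph:

import torch

def toy_forward(x: torch.Tensor) -> torch.Tensor:
    # Pure-Python permutation: resolved at trace time, so no
    # small tensors are materialized on the device per call.
    D = x.dim()
    perm = [1, 0] + list(range(2, D))
    return x.permute(*perm).contiguous()

compiled = torch.compile(toy_forward)
y = compiled(torch.randn(8, 3, 16, 16))  # traces without graph breaks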

Evaluation

Please see the attached test script:
test_script.zip

The script:

  1. Loads two identical mbt2018_mean models (model_old, model_new)
  2. Patches only the EntropyBottleneck.forward() of model_new
  3. Runs numerical-equivalence checks (torch.allclose) and compilation checks (torch._dynamo.explain()); a minimal skeleton follows below
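
For reference, a minimal skeleton of that procedure might look as follows (mbt2018_mean from compressai.zoo with a hypothetical quality setting; the monkey-patching of EntropyBottleneck.forward is done by the attached script and is elided here):

import torch
from torch._dynamo import explain
from compressai.zoo import mbt2018_mean

x = torch.randn(8, 3, 256, 256)

model_old = mbt2018_mean(quality=3, pretrained=True).eval()
model_new = mbt2018_mean(quality=3, pretrained=True).eval()
# ... patch EntropyBottleneck.forward on model_new here (see test_script.zip)

with torch.no_grad():
    out_old = model_old(x)
    out_new = model_new(x)

# Numerical equivalence of the reconstruction.
print(torch.allclose(out_old["x_hat"], out_new["x_hat"]))

# Graph statistics for the patched model.
report = explain(model_new)(x)
print(report.graph_count, report.graph_break_count, report.op_count)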

Results:

  • Outputs (x_hat, likelihood_y, likelihood_z) show no measurable difference
  • torch.compile() executes successfully with no graph breaks or guard inflation

Below are excerpts from the output logs:

In old_dynamo_log.txt:

Graph Count: 6
Graph Break Count: 5
Op Count: 167

In new_dynamo_log.txt:

Graph Count: 1
Graph Break Count: 0
Op Count: 164

These results confirm that the refactored version compiles cleanly into a single graph, with no graph breaks.

In compatibility_test.txt:

==== EntropyBottleneck Compatibility Test ====
Device: cuda
Input shape: torch.Size([8, 3, 256, 256])
----------------------------------------------
[x_hat] allclose=True, max_diff=0.000000e+00
[likelihood_y] allclose=True, max_diff=0.000000e+00
[likelihood_z] allclose=True, max_diff=0.000000e+00

All tests completed.

These results confirm that there is no functional difference between the original and refactored implementations.

Additional note

Since only the forward() method was modified, all existing parameters and buffers remain valid and can be reused without any reinitialization.
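
For instance, continuing the sketch above (with model_old and model_new as loaded there), one quick way to verify this:

# A checkpoint trained with the old forward() loads directly into the
# patched model; strict=True passes because the keys are unchanged.
model_new.load_state_dict(model_old.state_dict(), strict=True)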

Thanks for reading

I appreciate your time reviewing this change.

@fracape fracape merged commit 4fbc02f into InterDigitalInc:master Oct 23, 2025
8 checks passed
@studyingeugene
Contributor Author

@fracape Thank you for reviewing and accepting my pull request!

