15 changes: 13 additions & 2 deletions py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py
@@ -241,8 +241,19 @@ def expand(
             "Cannot expand to shape with rank smaller than original tensor."
         )
 
-    # After the above padding, the shape and tensor rank must be equal
-    assert len(input_t.shape) == shape_rank
+    # After the above padding, the shape and tensor rank must be equal.
+    # Safely check rank (len(...) may fail on symbolic shapes).
+    try:
+        current_rank = len(input_t.shape)
+    except Exception:
+        current_rank = initial_tensor_rank
+
+    if current_rank != shape_rank:
+        raise RuntimeError(
+            f"expand lowering: expected input rank {shape_rank} after padding, but got {current_rank}. "
+            "This may indicate symbolic or dynamic dimensions causing a rank mismatch."
+        )
+
 
     # Configure the start, strides and output shape tensors
     start = tuple([0] * shape_rank)
47 changes: 47 additions & 0 deletions tests/dynamo/test_repeat_expand_repro.py
@apbose (Collaborator) commented on Jan 23, 2026:

Thanks for the PR. I don't think the above fix addresses the issue, since dynamic shapes are already handled in prepend_ones. Also, current_rank should now be the rank of the shape it is expanded to:

try:
    current_rank = len(input_t.shape)
except Exception:
    current_rank = shape_rank

In the test case below there will be 0 computational nodes that depend on runtime input, since the shape values are constant. You could make them dynamic to invoke the converter (see the sketch below).
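
A minimal sketch of one way to do that, assuming torch.export's Dim API; the dimension names, the ranges, and the reuse of model / hidden_states from the test below are illustrative, not something verified in this PR:

from torch.export import Dim

# Hypothetical dynamic dims for the repro; ranges chosen only for illustration.
height = Dim("height", min=16, max=240)
width = Dim("width", min=16, max=240)

# Mark dims 3 and 4 of hidden_states (H and W) as dynamic so pe_size is no
# longer constant-folded and the repeat/expand lowering is actually exercised.
ep = torch.export.export(
    model,
    args=(hidden_states,),
    dynamic_shapes={"hidden_states": {3: height, 4: width}},
    strict=False,
)
# torch_tensorrt.dynamo.compile would then also need matching dynamic Input
# ranges (torch_tensorrt.Input(min_shape=..., opt_shape=..., max_shape=...)).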

I went through the above issue and it looks like the root issue is the number of dims being 10 here, which is not permitted in TRT. It can handle at most 8 dims: https://docs.nvidia.com/deeplearning/tensorrt/latest/_static/c-api/classnvinfer1_1_1_dims64.html

PyTorch decomposes repeat into unsqueeze -> expand -> permute -> reshape.

But for a 5D tensor, layer.reshape_dims = new_shape fails here, since Dims cannot hold 10 dimensions, and it fails in the dynamic case too in layer.set_input(1, reshape_dim_layer.get_output(0)). Hence input_tensor.shape comes out invalid; a sketch of the rank blow-up is below.
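
A minimal sketch of why the rank doubles, using an equivalent reshape/expand formulation of repeat (shapes here are illustrative, not the ones from the repro):

import torch

# repeat interleaves one repeat factor per original dim, so the intermediate
# "expanded" view of a 5-D tensor has rank 2 * 5 = 10, more than the 8 dims
# nvinfer1::Dims can represent.
x = torch.randn(2, 3, 4, 5, 6)
repeats = (2, 1, 2, 2, 1)

expanded = x.reshape(1, 2, 1, 3, 1, 4, 1, 5, 1, 6).expand(
    2, 2, 1, 3, 2, 4, 2, 5, 1, 6
)
print(expanded.dim())  # 10

out = expanded.reshape(4, 3, 8, 10, 6)
assert torch.equal(out, x.repeat(*repeats))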

The original example would work with only 4 dimensions, instead of the 5 dims in emb_t = self.pos_emb_t[: pe_size[0]][None, :, None, None, :].repeat(batch_size, 1, pe_size[1], pe_size[2], 1).

A WAR would be to replace repeat with expand without broadcasting the dimension, or to use a tile operation; a sketch of the expand variant is below. I need to look into this more.
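
A hedged sketch of the expand variant: in the repro below every repeated axis has size 1 after the None-indexing, so expand gives the same result as repeat without the rank-doubling intermediate (shapes here are small and illustrative):

import torch

pos_emb_t = torch.zeros(4, 8)        # stand-in for self.pos_emb_t
batch_size, pe_size = 2, (4, 3, 5)   # stand-in for the runtime pe_size

via_repeat = pos_emb_t[: pe_size[0]][None, :, None, None, :].repeat(
    batch_size, 1, pe_size[1], pe_size[2], 1
)
via_expand = pos_emb_t[: pe_size[0]][None, :, None, None, :].expand(
    batch_size, -1, pe_size[1], pe_size[2], -1
)
assert via_expand.shape == via_repeat.shape
assert torch.equal(via_expand, via_repeat)

In the model below this would mean replacing each .repeat(...) call in forward with the corresponding .expand(...), e.g. .expand(batch_size, -1, pe_size[1], pe_size[2], -1) for emb_t.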

Ideally, if this is a converter fix, we would want the test case below to live under https://github.com/pytorch/TensorRT/tree/main/tests/py/dynamo/conversion, so the location below would not work.

The PR author replied:

Thanks for reviewing this. It looks like I misunderstood what was actually failing and ended up fixing the wrong thing; I see the issue much more clearly now. I'll take another pass at it with the correct context and follow up if I find something useful.

@@ -0,0 +1,47 @@
import pytest
import torch
import torch.nn as nn

try:
    import torch_tensorrt
except Exception:
    torch_tensorrt = None

REQUIRES_TRT = torch.cuda.is_available() and (torch_tensorrt is not None)

pytestmark = pytest.mark.skipif(not REQUIRES_TRT, reason="requires CUDA + Torch-TensorRT runtime")

class CosmosLearnablePositionalEmbed(nn.Module):
    def __init__(self, hidden_size, max_size, patch_size):
        super().__init__()
        self.patch_size = patch_size
        self.pos_emb_t = nn.Parameter(torch.zeros(max_size[0] // patch_size[0], hidden_size))
        self.pos_emb_h = nn.Parameter(torch.zeros(max_size[1] // patch_size[1], hidden_size))
        self.pos_emb_w = nn.Parameter(torch.zeros(max_size[2] // patch_size[2], hidden_size))

    def forward(self, hidden_states):
        batch_size, _, num_frames, height, width = hidden_states.shape
        pe_size = [num_frames // self.patch_size[0], height // self.patch_size[1], width // self.patch_size[2]]
        emb_t = self.pos_emb_t[:pe_size[0]][None, :, None, None, :].repeat(batch_size, 1, pe_size[1], pe_size[2], 1)
        emb_h = self.pos_emb_h[:pe_size[1]][None, None, :, None, :].repeat(batch_size, pe_size[0], 1, pe_size[2], 1)
        emb_w = self.pos_emb_w[:pe_size[2]][None, None, None, :, :].repeat(batch_size, pe_size[0], pe_size[1], 1, 1)
        emb = emb_t + emb_h + emb_w
        emb = emb.flatten(1, 3)
        return emb

def test_repeat_expand_lowering_repro():
    device = torch.device("cuda")
    hidden_size = 4096
    model = CosmosLearnablePositionalEmbed(hidden_size=hidden_size, max_size=(128,240,240), patch_size=(1,2,2)).to(device).eval()
    hidden_states = torch.randn(1, 17, 16, 88, 160, dtype=torch.bfloat16, device=device)

    with torch.no_grad():
        pyt_out = model(hidden_states)

    ep = torch.export.export(model, args=(hidden_states,), strict=False)
    trt_mod = torch_tensorrt.dynamo.compile(ep, inputs=[hidden_states], enabled_precisions={torch.bfloat16}, use_python_runtime=True)
    trt_out = trt_mod(hidden_states)

    assert pyt_out.shape == trt_out.shape
    maxdiff = (pyt_out.float() - trt_out.float()).abs().max().item()
    assert maxdiff < 1e-2