Conversation

@apbose (Collaborator) commented Nov 6, 2025

This PR addresses the handling of empty tensors in Torch-TensorRT, based on https://docs.nvidia.com/deeplearning/tensorrt/latest/inference-library/advanced.html#empty-tensors, and also covers an edge case in the concat operation involving empty tensors.

TODO: The concat case may need to be split out into a separate PR. Concatenating torch.tensor([]) with a higher-rank tensor is valid in PyTorch but not in TensorRT.

This addresses #3865. A corresponding HF transformers issue has been raised (huggingface/transformers#42027), since the empty tensor should not arise in the first place.
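For reference, a minimal repro sketch of the concat edge case (the shapes are assumptions for illustration; the empty-tensor skip is long-standing PyTorch behavior):

import torch

x = torch.randn(3, 4, device="cuda")
y = torch.tensor([], device="cuda")  # shape (0,), rank 1

# Valid in eager PyTorch: the rank-1 empty tensor is skipped during concat.
# TensorRT has no equivalent; its concat layer requires all inputs to
# share the same rank.
out = torch.cat([x, y], dim=0)
print(out.shape)  # torch.Size([3, 4])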

@apbose apbose self-assigned this Nov 6, 2025
@meta-cla meta-cla bot added the cla signed label Nov 6, 2025
@apbose apbose marked this pull request as draft November 6, 2025 21:51
@github-actions github-actions bot added the component: tests, component: conversion, component: core, component: api [Python], component: runtime, and component: dynamo labels Nov 6, 2025
@github-actions github-actions bot requested a review from peri044 November 6, 2025 21:51
@apbose apbose force-pushed the abose/torchTRT_empty_tensor_handling branch 2 times, most recently from 87ebaf5 to 547022d Compare November 21, 2025 01:06
@apbose apbose marked this pull request as ready for review November 21, 2025 01:06
@apbose apbose force-pushed the abose/torchTRT_empty_tensor_handling branch from 547022d to 5d9d5fc Compare November 25, 2025 05:07
@apbose apbose changed the title from Empty tensor handling to [WIP] Empty tensor handling Nov 26, 2025
@narendasan (Collaborator) left a comment

@apbose this may be a case where we would want TRT to properly handle this rather than us doing something hacky. Let's raise it with Yuan Yuo. Dummy inputs do not feel like the right solution.

auto shape = core::util::toVec(dims);
LOG_DEBUG("Input Name: " << name << " Shape: " << dims);

void* tensor_addr = nullptr;
A collaborator commented on this diff:

I strongly want to avoid having nullptr basically anywhere; we should be looking for some sane default.

@apbose apbose force-pushed the abose/torchTRT_empty_tensor_handling branch from 47be81b to f411fd1 Compare January 23, 2026 01:04
@narendasan (Collaborator) left a comment

It's looking good; please add a test case, then this should be good to merge.

@github-actions github-actions bot added the component: converters Issues re: Specific op converters label Jan 27, 2026
# FX converters (legacy, stored as single function)
# Skip FX if dynamo converter exists but failed validation
# This ensures validator decisions are respected
if dynamo_converter_failed_validation:
@narendasan (Collaborator) commented Jan 27, 2026

Why is this needed for empty tensor handling? Also, are we near the point where we can just remove the FX converters?

@apbose (Collaborator, Author) replied:

This is specifically for the case where one of the cat inputs is torch.tensor([]) and the other is a higher-rank tensor, e.g. torch.randn(3, 4); this is where the validator comes into play.

# makes use of validator
class ConcatEmptyModelEmptyConstantMisMatchDim(nn.Module):
    def __init__(self, dim=0):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        y = torch.tensor([], device="cuda")
        return torch.cat([x, y], dim=self.dim)

This error occurs:

"/code/torchTRT/TensorRT/py/torch_tensorrt/fx/converters/aten_ops_converters.py", line 465, in aten_ops_cat                                                                                                                                                                                                                                  
    return acc_ops_converters.acc_ops_cat(network, target, None, kwargs_new, name) 

The acc_ops converter is called because the validator fails, and the lookup then falls back to the FX converter implementation, which is next in precedence. Yes, we should remove the FX converters IMO, but until that happens this guard would be required, I suppose.
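For illustration, a sketch of the kind of capability validator in play here (the function name and the use of node metadata are assumptions, not the exact torch_tensorrt API):

from torch.fx.node import Node

def cat_empty_tensor_validator(node: Node) -> bool:
    # Hypothetical sketch: reject aten.cat when inputs have mismatched
    # ranks (e.g., a rank-1 empty constant mixed with a rank-2 tensor),
    # since TensorRT concat requires all inputs to share the same rank.
    ranks = set()
    for inp in node.args[0]:
        meta = inp.meta.get("tensor_meta") if isinstance(inp, Node) else None
        if meta is not None:
            ranks.add(len(meta.shape))
    return len(ranks) <= 1

Returning False marks the dynamo converter as unsupported for that node, which (before this PR's guard) let the lookup fall through to the legacy FX converter.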

@apbose (Collaborator, Author) commented:

Also, on the test case: class TestEmptyTensorMemoryLeak(TestCase) should cover that there is no memory leak or memory blow-up across repeated model calls. Let me know if you think any other test case is required.
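For context, a minimal sketch of what the repeated-call check might look like (the helper name, iteration count, and slack threshold are assumptions):

import torch

def assert_no_gpu_leak(compiled_model, inp, iters=100):
    torch.cuda.synchronize()
    baseline = torch.cuda.memory_allocated()
    for _ in range(iters):
        compiled_model(inp)
    torch.cuda.synchronize()
    # Allocated memory should stay flat across calls; allow ~1 MiB slack
    # rather than demanding an exact match.
    assert torch.cuda.memory_allocated() - baseline < (1 << 20)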

@apbose apbose force-pushed the abose/torchTRT_empty_tensor_handling branch 2 times, most recently from a61974f to 2090079 Compare January 27, 2026 20:51
@apbose apbose force-pushed the abose/torchTRT_empty_tensor_handling branch from 2090079 to 149514a Compare January 28, 2026 03:22
