206 changes: 206 additions & 0 deletions py/torch_tensorrt/dynamo/lowering/passes/_FakeTensorUpdater.py
@@ -0,0 +1,206 @@
import contextlib
import operator
from collections import defaultdict
from typing import Any, Optional

import sympy
import torch
import torch.fx
import torch._inductor.fx_passes.reinplace  # makes _generalized_scatter resolvable in should_process_node
from torch._dispatch.python import enable_python_dispatcher
from torch._inductor.fx_utils import get_fake_args_kwargs, get_node_storage, get_storage
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import (
compute_unbacked_bindings,
rebind_unbacked,
statically_known_true,
sym_eq,
)
from torch.utils._ordered_set import OrderedSet


# Adapted from torch._inductor.fx_utils.FakeTensorUpdater
class FakeTensorUpdater:
"""
The main idea here is that it's difficult to maintain accurate fake
tensors (our primary form of metadata) for each node in our graph as we
transform it.
The most reliable way to obtain this information is by rerunning
faketensor propagation. However, in general, faketensor propagation is
fairly expensive. So, instead we'd like to only rerun faketensor
propagation on nodes that have changed.
    In order to detect which nodes have changed, we first hash each node
    together with its target and argument lists (which are immutable in FX).
Then, whenever we call incremental_update, we check which FX nodes have a
new hash, and recompute the faketensor metadata for that node. Then, we
continue to recursively compute the faketensors for all users until the
fake tensors stop changing.
"""

def __init__(self, graph: torch.fx.Graph) -> None:
self.processed_hashes = OrderedSet[Any]()
self.graph = graph

for node in self.graph.nodes:
self.processed_hashes.add(self.hash_node(node))

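    # FX exposes node.args/node.kwargs as immutable containers and swaps them out
    # wholesale when a node is rewritten, so their id() changes whenever the
    # arguments do.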
def hash_node(self, node: torch.fx.Node) -> tuple[torch.fx.Node, Any, Any, Any]:
return (node, node.target, id(node.args), id(node.kwargs))

def incremental_update(self, fake_mode: FakeTensorMode) -> None:
"""Update FakeTensors on self.graph. We will try to do the minimum amount of work."""
existing_storages: defaultdict[Optional[int], int] = defaultdict(int)
for node in self.graph.nodes:
existing_storages[get_node_storage(node)] += 1

def is_intlist_same(new: Any, old: Any) -> Any:
return statically_known_true(sym_eq(new, old))

def is_fake_tensor_same(new: Any, old: Any, *, node: torch.fx.Node) -> Any:
if type(new) is not type(old):
return False
if isinstance(new, (list, tuple)):
if len(new) != len(old):
return False
return all(
is_fake_tensor_same(new_i, old_i, node=node)
for new_i, old_i in zip(new, old)
)
if new is None:
return old is None
if not isinstance(new, torch.Tensor):
assert isinstance(
new, (torch.SymInt, torch.SymBool, torch.SymFloat)
), f"Unknown type {type(new)} in {self.graph}"
return (
new.node.shape_env._maybe_evaluate_static(
sympy.Eq(new.node.expr, old.node.expr)
)
== sympy.true
)
if not is_intlist_same(new.shape, old.shape) or new.layout != old.layout:
return False
if new.layout == torch.strided and (
not is_intlist_same(new.stride(), old.stride())
or not statically_known_true(
new.storage_offset() == old.storage_offset()
)
):
return False

if new.device != old.device:
return False

if get_storage(new) == get_storage(old):
return True

def any_user_may_alias(node: torch.fx.Node) -> bool:
if not isinstance(node.meta["val"], torch.Tensor):
# analysis too complicated on lists, can support in the future
return True
for user in node.users:
if not (
isinstance(
user.target,
(torch._ops.OpOverload, torch._ops.HigherOrderOperator),
)
):
return True
if isinstance(user.target, torch._ops.HigherOrderOperator):
# HOPs that survive until inductor are all non-aliasing HOPs.
# We will likely never support HOPs that are aliasing.
continue
# Strategy: do a FakeTensor prop, see if the storage aliases.
# If Inductor ever gets tighter invariants on OpOverloads
# (that is, we ban things like torch.ops.aten.reshape calls in the graph),
# Then this could just be a fast schema lookup.
is_valid, args, kwargs = get_fake_args_kwargs(user)
if not is_valid:
return True
with (
fake_mode,
enable_python_dispatcher(),
contextlib.ExitStack() as stack,
):
# Ignore unbacked symbols (if they exist): we're making
# this FakeTensor and then throwing it away.
if fake_mode.shape_env is not None:
stack.enter_context(
fake_mode.shape_env.ignore_fresh_unbacked_symbols()
)
new_fake_tensor = user.target(*args, **kwargs)
if not isinstance(new_fake_tensor, torch.Tensor):
# analysis too complicated on lists, can support in the future
return True
if get_storage(new_fake_tensor) == get_storage(node.meta["val"]):
return True
return False

# This is the case where it returns a completely fresh storage that's used nowhere else.
# If the FakeTensor's storage is fresh and none of the node's users can alias it, then
# we don't need to update this node.
if (
existing_storages[get_storage(old)] == 1
and get_storage(new) not in existing_storages
and not any_user_may_alias(node)
):
return True

return False

def should_process_node(node: torch.fx.Node) -> bool:
            # node.target for nodes that return True from this function is
            # called under fake mode, which does not work for Inductor
            # lowerings. We therefore only process nodes whose target is an
            # ATen OpOverload, operator.getitem (used when an op returns
            # multiple tensors), or Inductor's _generalized_scatter.
return node.op == "call_function" and (
isinstance(node.target, torch._ops.OpOverload)
or node.target is operator.getitem
or node.target
is torch._inductor.fx_passes.reinplace._generalized_scatter
)

to_process = OrderedSet[int]()
for node in self.graph.nodes:
# NB: Be very careful about skipping nodes (via continues) here
# and ask for a careful review when changing this code. The
# consequence for incorrect FakeTensor metadata is difficult-to-debug
# silent incorrectness.
if (
self.hash_node(node) in self.processed_hashes
and id(node) not in to_process
):
continue

if not should_process_node(node):
continue

is_valid, args, kwargs = get_fake_args_kwargs(node)
if not is_valid:
continue
with fake_mode, enable_python_dispatcher():
new_fake_tensor = node.target(*args, **kwargs)

if "val" in node.meta and is_fake_tensor_same(
new_fake_tensor, node.meta["val"], node=node
):
continue

rebind_unbacked(fake_mode.shape_env, node, new_fake_tensor)

node.meta["val"] = new_fake_tensor
if (shape_env := fake_mode.shape_env) and (
symbol_to_path := compute_unbacked_bindings(shape_env, new_fake_tensor)
):
# Refresh the bindings to the new symbols

node.meta["unbacked_bindings"] = symbol_to_path

existing_storages[get_node_storage(node)] += 1

to_process.update([id(user) for user in node.users])

self.processed_hashes.add(self.hash_node(node))
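A minimal usage sketch (illustrative only, not part of the diff) of how this updater is meant to be driven, mirroring the post_lowering change further down; the wrapper name run_passes_with_metadata_refresh is made up here, and _detect_fake_mode_from_gm is the same private torch._export utility the PR itself calls:

import torch
import torch.fx
from torch._export.utils import _detect_fake_mode_from_gm
from torch_tensorrt.dynamo.lowering.passes._FakeTensorUpdater import FakeTensorUpdater


def run_passes_with_metadata_refresh(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    # Hash every node before any mutation so changed nodes can be detected later.
    updater = FakeTensorUpdater(gm.graph)

    # ... apply graph-rewriting lowering passes on gm here ...

    # Re-run fake-tensor propagation only for nodes whose hash changed.
    if (fake_mode := _detect_fake_mode_from_gm(gm)) is not None:
        updater.incremental_update(fake_mode)
    return gm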
@@ -5,6 +5,7 @@
import torch
from torch_tensorrt._utils import is_tegra_platform
from torch_tensorrt.dynamo._settings import CompilationSettings
from torch_tensorrt.dynamo.lowering.passes._FakeTensorUpdater import FakeTensorUpdater
from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
trace_intermediate_node_outputs,
)
@@ -18,6 +19,7 @@
from .remove_input_alias_fixing_clones import remove_input_alias_fixing_clones
from .remove_num_users_is_0_nodes import remove_num_users_is_0_nodes
from .repair_input_as_output import repair_input_as_output
from .replace_fused_rms_norm import replace_fused_rms_norm
from .replace_max_pool_with_indices import replace_max_pool_with_indices
from .rule_based_autocast import rule_based_autocast

@@ -28,6 +30,7 @@
]

post_lowering_pass_list = [
replace_fused_rms_norm,
remove_input_alias_fixing_clones,
constant_fold,
repair_input_as_output,
@@ -129,7 +132,10 @@ def post_lowering(
logging.debug(
f"Invoking DynamoPassManager and applying lowering passes: {ATEN_POST_LOWERING_PASSES}"
)
fake_tensor_updater = FakeTensorUpdater(gm.graph)
gm = ATEN_POST_LOWERING_PASSES(gm, settings)
if (fake_mode := torch._export.utils._detect_fake_mode_from_gm(gm)) is not None:
fake_tensor_updater.incremental_update(fake_mode)

return gm

92 changes: 92 additions & 0 deletions py/torch_tensorrt/dynamo/lowering/passes/replace_fused_rms_norm.py
@@ -0,0 +1,92 @@
import logging
import operator

import torch
from torch_tensorrt.dynamo._settings import CompilationSettings
from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
clean_up_graph_after_modifications,
)

logger = logging.getLogger(__name__)


def replace_fused_rms_norm(
gm: torch.fx.GraphModule, settings: CompilationSettings
) -> torch.fx.GraphModule:
"""Replace fused rms norm ops in the graph"""
count = 0
for node in gm.graph.nodes:
if node.target == torch.ops.aten._fused_rms_norm.default:
x_normalized, rsqrt = process_fused_rms_norm_node(node, gm)
count += 1

logger.debug(f"Replaced {count} fused rms norm nodes:\n{gm.graph}")

gm = clean_up_graph_after_modifications(gm)

return gm


def process_fused_rms_norm_node(
node: torch.fx.Node, gm: torch.fx.GraphModule
) -> tuple[torch.fx.Node, torch.fx.Node]:

x, shape, weight, eps = node.args[0], node.args[1], node.args[2], node.args[3]
if eps is None:
eps = 1e-5
# Calculate dimensions to normalize over (similar to layer_norm)
# normalized_shape specifies the last N dimensions
x_dim = len(node.meta["val"][0].shape)
dims_to_reduce = []
for i in range(len(shape)):
dims_to_reduce.append(x_dim - i - 1)
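    # The fused op is decomposed below as y = x * rsqrt(mean(x^2, dims_to_reduce) + eps) [* weight];
    # e.g. a rank-4 input with a length-2 normalized_shape reduces over dims [3, 2].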

with gm.graph.inserting_before(node):
# Replace fused rms norm with standard rms norm
x_squared = gm.graph.call_function(
torch.ops.aten.mul.Tensor,
args=(x, x),
)

x_squared_sum = gm.graph.call_function(
torch.ops.aten.mean.dim,
args=(x_squared, dims_to_reduce, True),
)

x_squared_sum_eps = gm.graph.call_function(
torch.ops.aten.add.Tensor,
args=(x_squared_sum, eps),
)

x_squared_sum_eps_rsqrt = gm.graph.call_function(
torch.ops.aten.rsqrt.default,
args=(x_squared_sum_eps,),
)

x_normalized = gm.graph.call_function(
torch.ops.aten.mul.Tensor,
args=(x, x_squared_sum_eps_rsqrt),
)

if weight is not None:
x_normalized = gm.graph.call_function(
torch.ops.aten.mul.Tensor,
args=(x_normalized, weight),
)

    for user in list(node.users):
        if user.op == "call_function" and user.target == operator.getitem:
            # _fused_rms_norm returns (output, rsqrt); dispatch on the getitem
            # index rather than on user iteration order, which is not guaranteed.
            output_index = user.args[1]
            replacement = x_normalized if output_index == 0 else x_squared_sum_eps_rsqrt
            user.replace_all_uses_with(replacement)

            logger.debug(
                f"Replaced getitem({output_index}) user of fused_rms_norm node [{user}] "
                f"with lowered rms_norm output [{replacement}]"
            )
            gm.graph.erase_node(user)

gm.graph.erase_node(node)

return x_normalized, x_squared_sum_eps_rsqrt
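As an eager-mode sanity check of the decomposition above (illustrative, not part of the diff), the same op sequence can be compared against torch.nn.functional.rms_norm, assuming a PyTorch build where that function is available; the two should agree up to floating-point tolerance:

import torch
import torch.nn.functional as F

x = torch.randn(2, 4, 8)
weight = torch.randn(8)
eps = 1e-5

# Same sequence the pass emits: mul -> mean -> add eps -> rsqrt -> mul (-> weight mul).
decomposed = x * torch.rsqrt((x * x).mean(dim=[-1], keepdim=True) + eps) * weight

# Built-in reference for a single normalized dimension.
reference = F.rms_norm(x, (8,), weight=weight, eps=eps)

torch.testing.assert_close(decomposed, reference)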
44 changes: 44 additions & 0 deletions tests/py/dynamo/lowering/conftest.py
@@ -0,0 +1,44 @@
import pytest
from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
DYNAMO_ATEN_CONVERTERS,
DYNAMO_CONVERTERS,
)


@pytest.fixture(autouse=True)
def reset_torch_tensorrt_state():
"""
Ensure test isolation by restoring converter registry state and clearing caches.
This prevents earlier tests from mutating global state (e.g., disallowed targets)
which can cause different partitioning outcomes when running multiple tests.
"""
# Snapshot current global state
original_registry = {k: list(v) for k, v in DYNAMO_ATEN_CONVERTERS.items()}
original_disallowed = set(getattr(DYNAMO_CONVERTERS, "disallowed_targets", set()))
original_settings = getattr(DYNAMO_CONVERTERS, "compilation_settings", None)

try:
yield
finally:
# Restore converter registry
DYNAMO_ATEN_CONVERTERS.clear()
DYNAMO_ATEN_CONVERTERS.update(
{k: list(v) for k, v in original_registry.items()}
)

# Restore disallowed targets and compilation settings
try:
DYNAMO_CONVERTERS.set_disallowed_targets(original_disallowed)
except Exception:
pass
if original_settings is not None:
try:
DYNAMO_CONVERTERS.set_compilation_settings(original_settings)
except Exception:
pass

        # Clear the graph-tracing cache, if one exists in this scope, so stale
        # state does not carry forward (guarded: trace_atomic_graph is not
        # imported in this module, so the call may simply be a no-op).
        try:
            trace_atomic_graph.cache_clear()
        except Exception:
            pass
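For context, a hypothetical test of the kind this fixture protects against (the operator passed to set_disallowed_targets is arbitrary and only illustrates the global mutation being undone):

import torch
from torch_tensorrt.dynamo.conversion._ConverterRegistry import DYNAMO_CONVERTERS


def test_partitioning_with_disallowed_target():
    # Globally force this op to fall back to PyTorch; without the autouse fixture
    # above, this mutation would leak into every test that runs afterwards.
    DYNAMO_CONVERTERS.set_disallowed_targets({torch.ops.aten.max_pool2d.default})
    # ... compile a module and assert on the resulting partitioning ...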