|
25 | 25 | _is_reshape = partial(is_op_node, target_op=torch.ops.aten.reshape) |
26 | 26 | _is_zeros_like = partial(is_op_node, target_op=torch.ops.aten.zeros_like) |
27 | 27 |
|
| 28 | +_CONV_TARGETS = { |
| 29 | + torch.ops.aten.conv1d.default, |
| 30 | + torch.ops.aten.conv1d.padding, |
| 31 | + torch.ops.aten.conv2d.default, |
| 32 | + torch.ops.aten.conv2d.padding, |
| 33 | + torch.ops.aten.conv_transpose1d.default, |
| 34 | + torch.ops.aten.conv_transpose2d.input, |
| 35 | +} |
| 36 | + |
| 37 | + |
| 38 | +def _feeds_into_linear(node: Node) -> bool: |
| 39 | + """ |
| 40 | + BFS from node to check if it eventually feeds into a linear op (not conv). |
| 41 | + This is required because: |
| 42 | + - Linear-BN fusion (added by AddSimulatedLinearBatchNormFusionQATPass, NXP-specific) |
| 43 | + - Conv-BN QAT fusion (added by TorchAO's _fuse_conv_bn_qat inside prepare_qat_pt2e) |
| 44 | + are structurally identical. Without this check, we would incorrectly remove |
| 45 | + Conv-BN scale factor chains, breaking Conv-BN QAT fusion when TorchAO's _fold_conv_bn_qat |
| 46 | + is called during convert_pt2e. |
| 47 | + """ |
| 48 | + visited = set() |
| 49 | + queue = list(node.users.keys()) |
| 50 | + while queue: |
| 51 | + n = queue.pop(0) |
| 52 | + if n in visited: |
| 53 | + continue |
| 54 | + visited.add(n) |
| 55 | + if n.op == "call_function": |
| 56 | + if n.target == torch.ops.aten.linear.default: |
| 57 | + return True |
| 58 | + if n.target in _CONV_TARGETS: |
| 59 | + return False |
| 60 | + queue.extend(n.users.keys()) |
| 61 | + return True |
| 62 | + |
28 | 63 |
|
29 | 64 | def _is_denorm_pattern(node: Node) -> bool: |
30 | 65 | if not _is_div(node): |
@@ -56,6 +91,10 @@ def _remove_pattern_from_graph(graph_module: GraphModule, pattern: GraphModule): |
56 | 91 | for match in matches: |
57 | 92 | last_pattern_node = match.anchors[0] |
58 | 93 | last_matched_subgraph_node = match.nodes_map[last_pattern_node] |
| 94 | + |
| 95 | + if not _feeds_into_linear(last_matched_subgraph_node): |
| 96 | + continue |
| 97 | + |
59 | 98 | weight = match.placeholder_nodes[0] |
60 | 99 |
|
61 | 100 | last_matched_subgraph_node.replace_all_uses_with(weight) |
|
0 commit comments