import os
import math
import torch
+ from typing import Optional
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from torch._dynamo.utils import dynamo_timed
from torch._inductor.codegen import cpp, wrapper, common, memory_planning
+ from torch._inductor.ir import GraphPartitionSignature
from torch._inductor.virtualized import V, _ops as ops
from torch._inductor.codecache import write_atomic, write
from torch._inductor.utils import (
@@ -75,10 +77,25 @@ def reduction_combine_vec(reduction_type, vector_value, init_value, axis, shape,
        return f"vector.multi_reduction <and>, %{vector_value}, %{init_value} [{axis}] : {shape} to {reduced_shape}"
    raise AssertionError(reduction_type)

- class ExtensionWrapperCodegen(wrapper.WrapperCodeGen):
+ class ExtensionWrapperCodegen(wrapper.PythonWrapperCodegen):
    def __init__(self):
        super().__init__()

+     @classmethod
+     def create(
+         cls,
+         is_subgraph: bool,
+         subgraph_name: Optional[str],
+         parent_wrapper: Optional[wrapper.PythonWrapperCodegen],
+         partition_signatures: Optional[GraphPartitionSignature] = None,
+     ):
+         if is_subgraph:
+             assert subgraph_name is not None and parent_wrapper is not None
+             return wrapper.SubgraphPythonWrapperCodegen(
+                 subgraph_name, parent_wrapper, partition_signatures
+             )
+         return cls()
+
    def write_header(self):
        self.header.splice(
            f"""
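Note: the switch to `wrapper.PythonWrapperCodegen` plus the new `create` hook matches how recent Inductor versions instantiate wrapper codegens, so the class can still be registered for the backend's device in the usual out-of-tree way. A minimal sketch, where the device key `"extension_device"` and the scheduling class `ExtensionScheduling` are placeholders, not names from this commit:

```python
from torch._inductor.codegen.common import register_backend_for_device

# Illustrative registration only; the device key and scheduling class are assumed.
register_backend_for_device(
    "extension_device",       # device this out-of-tree backend compiles for
    ExtensionScheduling,      # backend scheduling class (assumed to exist)
    ExtensionWrapperCodegen,  # the wrapper codegen updated above
)
```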
@@ -107,6 +124,7 @@ def write_header(self):
            reinterpret_tensor = torch.ops.aten._reinterpret_tensor
            custom_async_compile = CustomAsyncCompile()
            os.environ["TORCHSIM_LAST_COMPILED_MODULE"] = __file__
+             print(f\'Wrapper Codegen Path = {{__file__}}\')
            """
        )
        self.header.splice(
@@ -154,7 +172,7 @@ def call(args):
        self.prefix.writeline(f"{lhs} = args")
        self.prefix.writeline("args.clear()")

-         self.codegen_inputs(self.prefix, V.graph.graph_inputs)
+         self.codegen_inputs()
        self.codegen_input_size_asserts()
        self.codegen_sram_plan_prefix()
@@ -174,10 +192,27 @@ def codegen_sram_plan_postfix(self, outputs):
                continue
            self.wrapper_call.writeline(f"sram_plan_postfix('{name}', {name})")

-     @dynamo_timed
+     def _generate_kernel_call_helper(
+         self,
+         kernel_name: str,
+         call_args,
+         *,
+         device=None,
+         triton=True,
+         arg_types=None,
+         raw_keys=None,
+         raw_args=None,
+         triton_meta=None,
+         graph_name="",
+         original_fxnode_name=None,
+     ):
+         device = device or V.graph.get_current_device_or_throw()
+         self.writeline(self.wrap_kernel_call(kernel_name, call_args))
+         return
+
    def generate(self, is_inference):
        result = IndentedBuffer()
-         result.splice(self.header)
+         # result.splice(self.header)

        with contextlib.ExitStack() as stack:
            stack.enter_context(self.wrapper_call.indent())
@@ -192,8 +227,13 @@ def generate(self, is_inference):

                if isinstance(line, wrapper.MemoryPlanningLine):
                    line.codegen(self.wrapper_call)
+                 elif isinstance(line, wrapper.KernelCallLine):
+                     self.wrapper_call.writeline(self.wrap_kernel_call(line.kernel_name, line.call_args))
                else:
-                     self.wrapper_call.writeline(line)
+                     if isinstance(line, wrapper.WrapperLine):
+                         line.codegen(self.wrapper_call)
+                     else:
+                         self.wrapper_call.writeline(line)
                # Add buffer plan hook for alloc
                if isinstance(line, memory_planning.AllocFromPoolLine) or isinstance(line, wrapper.AllocateLine):
                    self.wrapper_call.writeline(f"sram_plan_prefix('{line.node.get_name()}', {line.node.get_name()})")
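Both the `_generate_kernel_call_helper` override and the new `KernelCallLine` branch route kernel invocations through `wrap_kernel_call`, which in the base wrapper codegen renders a plain `name(arg, ...)` Python call. A standalone illustration of that output shape (kernel and buffer names are made up):

```python
# Mimics the "name(arg, ...)" line shape produced by wrap_kernel_call in the
# base PythonWrapperCodegen; illustration only, not code from this commit.
def render_kernel_call(kernel_name: str, call_args: list[str]) -> str:
    return f"{kernel_name}({', '.join(call_args)})"

print(render_kernel_call("extension_kernel_0", ["buf0", "arg0_1"]))
# extension_kernel_0(buf0, arg0_1)
```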
@@ -202,7 +242,9 @@ def generate(self, is_inference):
            self.mark_output_type()
            self.generate_return(output_refs)

-         self.append_precomputed_sizes_to_prefix()
+         # self.append_precomputed_sizes_to_prefix()  # FIXME: Need to replace append_precomputed_sizes_to_prefix()
+         result.splice(self.header)
+
        self.finalize_prefix()
        result.splice(self.prefix)

@@ -211,7 +253,10 @@ def generate(self, is_inference):

        self.generate_end(result)
        self.add_benchmark_harness(result)
-         return result.getvaluewithlinemap()
+         return (
+             result.getvaluewithlinemap(),
+             self.kernel_declarations.getvaluewithlinemap(),
+         )

    def memory_plan(self):
        self.lines = memory_planning.MemoryPlanner(self).plan(self.lines)
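`generate` now follows the newer Inductor contract of returning the wrapper code and the kernel declarations separately, so callers must unpack a pair instead of a single buffer. A hedged sketch of a consumer, assuming each element behaves like recent Inductor's `ValueWithLineMap` and exposes a `.value` string:

```python
def assemble_module_source(codegen: ExtensionWrapperCodegen, is_inference: bool = True) -> str:
    # Hypothetical helper: stitch the two-part return of generate() into one
    # module source; assumes .value holds the generated code string.
    wrapper_code, kernel_code = codegen.generate(is_inference)
    return kernel_code.value + "\n" + wrapper_code.value
```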
@@ -1494,16 +1539,16 @@ def get_cycle(choice):
        return optimal_src_code

    def codegen_nodes(self, nodes, kernel_name):
-         src_code = super().codegen_nodes(nodes, kernel_name)
+         src_code, meta_code = super().codegen_nodes(nodes, kernel_name)
        self._prepare_simulator_headers(src_code)
        if not extension_config.CONFIG_AUTOTUNE or extension_config.CONFIG_BACKENDSIM_SPIKE_ONLY:
-             return src_code
+             return src_code, meta_code
        else:
            optimal_src_code = self.autotune(nodes, kernel_name)
            if optimal_src_code:
-                 return optimal_src_code
+                 return optimal_src_code, meta_code
            else:
-                 return src_code
+                 return src_code, meta_code

    def _prepare_simulator_headers(self, src_code):
        write_path = extension_codecache.get_write_path(src_code)
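Since `codegen_nodes` now returns a `(src_code, meta_code)` pair, every call site has to unpack both values even if only the source is cached. A hedged sketch of such a caller, reusing the `get_write_path`/`write_atomic` utilities already used in this file (the `kernel_codegen` argument is illustrative):

```python
def cache_generated_kernel(kernel_codegen, nodes, kernel_name: str):
    # Unpack the new (src_code, meta_code) pair; only src_code is written to
    # the on-disk cache here, meta_code is handed back to the caller.
    src_code, meta_code = kernel_codegen.codegen_nodes(nodes, kernel_name)
    write_path = extension_codecache.get_write_path(src_code)
    write_atomic(write_path, src_code)
    return write_path, meta_code
```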