Commit ad13519

PyTorch version upgrade: tested on single-operator tests
1 parent 4385e5a commit ad13519

12 files changed: +250 -74 lines changed

PyTorchSimFrontend/extension_codecache.py

Lines changed: 2 additions & 1 deletion
@@ -3,7 +3,8 @@
 import shlex
 import subprocess
 
-from torch._inductor.codecache import AsyncCompile, get_lock_dir, get_hash, write
+from torch._inductor.codecache import get_lock_dir, get_hash, write
+from torch._inductor.async_compile import AsyncCompile
 from AsmParser.tog_generator import tog_generator
 from PyTorchSimFrontend.mlir.mlir_caller_codegen import MLIRKernelCallerCodeGen
 from PyTorchSimFrontend import extension_config
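
The AsyncCompile import above moved because newer PyTorch splits async compilation out of torch._inductor.codecache into torch._inductor.async_compile. A minimal version-tolerant import, shown only as a sketch and not part of this commit:

# Sketch (not from this commit): tolerate both the PyTorch 2.2 and 2.8 module layouts.
try:
    from torch._inductor.async_compile import AsyncCompile  # newer layout (e.g. 2.8)
except ImportError:
    from torch._inductor.codecache import AsyncCompile  # older layout (e.g. 2.2)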

PyTorchSimFrontend/extension_device.cpp

Lines changed: 5 additions & 1 deletion
@@ -112,7 +112,7 @@ at::Tensor custom_to_device(
 // A dummy allocator for our custom device, that secretly uses the CPU
 struct DummyCustomAllocator final : at::Allocator {
   DummyCustomAllocator() = default;
-  at::DataPtr allocate(size_t nbytes) const override {
+  at::DataPtr allocate(size_t nbytes) override {
     void* data = c10::alloc_cpu(nbytes);
     return {data, data, &ReportAndDelete, at::Device(at::DeviceType::PrivateUse1, 0)};
   }
@@ -127,6 +127,10 @@ struct DummyCustomAllocator final : at::Allocator {
   at::DeleterFnPtr raw_deleter() const override {
     return &ReportAndDelete;
   }
+
+  void copy_data(void* dest, const void* src, std::size_t count) const override {
+    std::memcpy(dest, src, count);
+  }
 };
 
 // Register our dummy allocator

Lines changed: 63 additions & 0 deletions
@@ -0,0 +1,63 @@
+import torch
+from torch._dynamo.device_interface import DeviceInterface, caching_worker_current_devices, caching_worker_device_properties
+
+class _ExtensionDeviceProperties:  # FIXME: Dummy property values
+    name: str = "Extension_device"
+    platform_name: str
+    vendor: str
+    driver_version: str
+    version: str
+    max_compute_units: int
+    gpu_eu_count: int
+    max_work_group_size: int
+    max_num_sub_groups: int
+    sub_group_sizes: list[int]
+    has_fp16: bool
+    has_fp64: bool
+    has_atomic64: bool
+    has_bfloat16_conversions: bool
+    has_subgroup_matrix_multiply_accumulate: bool
+    has_subgroup_matrix_multiply_accumulate_tensor_float32: bool
+    has_subgroup_2d_block_io: bool
+    total_memory: int
+    multi_processor_count: int = 128  # gpu_subslice_count, num_sm
+    architecture: int
+    type: str
+
+_ExtensionDeviceProperties = _ExtensionDeviceProperties
+
+class ExtensionDeviceInterface(DeviceInterface):
+    class Worker:
+        @staticmethod
+        def set_device(device: int):
+            caching_worker_current_devices["extension_device"] = device
+
+        @staticmethod
+        def current_device() -> int:
+            if "extension_device" in caching_worker_current_devices:
+                return caching_worker_current_devices["extension_device"]
+            return torch.xpu.current_device()
+
+        @staticmethod
+        def get_device_properties(device: torch.types.Device = None) -> _ExtensionDeviceProperties:
+            if device is not None:
+                if isinstance(device, str):
+                    device = torch.device(device)
+                    assert device.type == "extension_device"
+                if isinstance(device, torch.device):
+                    device = device.index
+            if device is None:
+                device = ExtensionDeviceInterface.Worker.current_device()
+
+            if "extension_device" not in caching_worker_device_properties:
+                device_prop = [
+                    torch.cuda.get_device_properties(i)
+                    for i in range(torch.cuda.device_count())
+                ]
+                caching_worker_device_properties["extension_device"] = device_prop
+
+            return _ExtensionDeviceProperties
+
+    @staticmethod
+    def get_compute_capability(device: torch.types.Device = None):
+        return 36
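
The interface above only takes effect once Dynamo knows about it. A short usage sketch, assuming the register_interface_for_device helper from torch._dynamo.device_interface (the registration call itself is not part of this diff):

# Sketch: register the interface and exercise the dummy values defined above.
from torch._dynamo.device_interface import register_interface_for_device

register_interface_for_device("extension_device", ExtensionDeviceInterface)

ExtensionDeviceInterface.Worker.set_device(0)
print(ExtensionDeviceInterface.Worker.current_device())   # 0, served from the caching dict
print(ExtensionDeviceInterface.get_compute_capability())   # 36, the hard-coded dummy value
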
Lines changed: 25 additions & 0 deletions
@@ -0,0 +1,25 @@
+from __future__ import annotations
+
+from textwrap import dedent
+
+from torch._inductor.codegen.common import DeviceOpOverrides, register_device_op_overrides
+
+class ExtensionDeviceOpOverrides(DeviceOpOverrides):
+    def import_get_raw_stream_as(self, name: str) -> str:
+        return dedent(
+            """
+            def get_raw_stream(_):
+                return 0
+            """
+        )
+
+    def set_device(self, device_idx: int) -> str:
+        return "pass"
+
+    def synchronize(self) -> str:
+        return "pass"
+
+    def device_guard(self, device_idx: int) -> str:
+        return "pass"
+
+register_device_op_overrides("extension_device", ExtensionDeviceOpOverrides())
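
These overrides return source-code strings that Inductor splices into the generated wrapper, so returning "pass" makes device switching, synchronization, and device guards no-ops for this backend. A small sketch of how the registered overrides can be looked up, assuming the get_device_op_overrides helper in torch._inductor.codegen.common:

# Sketch: inspect the strings Inductor would emit for this device.
from torch._inductor.codegen.common import get_device_op_overrides

overrides = get_device_op_overrides("extension_device")
print(overrides.set_device(0))    # "pass" -> no device switch is emitted
print(overrides.synchronize())    # "pass" -> no synchronization is emitted
print(overrides.import_get_raw_stream_as("get_raw_stream"))  # stub stream getter that returns 0
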
PyTorchSimFrontend/extension_utils.py

Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
+import sympy
+import torch
+
+"""
+NOTE: Temporary File
+
+This file contains functions that were removed or changed in newer versions
+of PyTorch. It is kept here only to temporarily enable compatibility while
+upgrading to PyTorch 2.8 from PyTorch 2.2.
+
+These functions will eventually be integrated into the appropriate source files
+or removed once no longer needed.
+
+This file is not intended to be permanent and should be deleted in the future.
+"""
+
+def free_symbol_startswith(index: sympy.Expr, prefix: str):
+    return any(v.name.startswith(prefix) for v in index.free_symbols)
+
+def sympy_symbol(name: str) -> sympy.Symbol:
+    # This should never be used for creating shape/stride symbols, as those
+    # should all be allocated before Inductor.
+    assert name[0] != "s"
+    # NOTE: shape symbols are positive (> 0), but index variables are only
+    # non-negative (>= 0).
+    return sympy.Symbol(name, integer=True, nonnegative=True)
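
free_symbol_startswith reimplements the helper that the llvm_common.py change below stops importing from torch._inductor.utils. A quick sketch of what it checks, using plain sympy symbols:

# Sketch: the helper is True when any free symbol's name has the given prefix.
import sympy
from PyTorchSimFrontend.extension_utils import free_symbol_startswith

index = 4 * sympy.Symbol("x0", integer=True, nonnegative=True) + sympy.Symbol("tmp1")
print(free_symbol_startswith(index, "x"))   # True: "x0" starts with "x"
print(free_symbol_startswith(index, "ks"))  # False: no free symbol starts with "ks"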

PyTorchSimFrontend/llvm/llvm_common.py

Lines changed: 2 additions & 1 deletion
@@ -11,13 +11,14 @@
 from torch.utils._sympy.value_ranges import ValueRanges
 
 from torch._inductor.utils import (
-    free_symbol_startswith,
     get_sympy_Expr_dtype,
     IndentedBuffer,
     sympy_subs,
     unique,
 )
 
+from PyTorchSimFrontend.extension_utils import free_symbol_startswith
+
 schedule_log = torch._logging.getArtifactLogger(__name__, "schedule")
 
 DTYPE_TO_LLVM = {

PyTorchSimFrontend/mlir/mlir_autotune.py

Lines changed: 6 additions & 2 deletions
@@ -41,6 +41,9 @@ def __init__(
         self.extra_args = extra_args
         #self.hash_key, self.source_file = CUDACodeCache.write(self.source_code, "so")
 
+    def __str__(self) -> str:
+        return f"{self.kernel_name=}, {self.source_file=}, {self.hash_key=}"
+
     def make_run_fn(
         self, input_tensors: torch.Tensor, output_tensors: torch.Tensor
     ) -> Callable[[], None]:
@@ -62,5 +65,6 @@ def make_run_fn(
             *args,
         )
 
-    def __str__(self) -> str:
-        return f"{self.kernel_name=}, {self.source_file=}, {self.hash_key=}"
+    def update_workspace_size(self) -> None:
+        # FIXME: Not implemented yet. Checkout torch/_inductor/codegen/rocm/rocm_benchmark_request.py
+        return

PyTorchSimFrontend/mlir/mlir_codegen_backend.py

Lines changed: 56 additions & 11 deletions
@@ -4,10 +4,12 @@
 import os
 import math
 import torch
+from typing import Optional
 from collections import defaultdict
 from concurrent.futures import ThreadPoolExecutor
 from torch._dynamo.utils import dynamo_timed
 from torch._inductor.codegen import cpp, wrapper, common, memory_planning
+from torch._inductor.ir import GraphPartitionSignature
 from torch._inductor.virtualized import V, _ops as ops
 from torch._inductor.codecache import write_atomic, write
 from torch._inductor.utils import (
@@ -75,10 +77,25 @@ def reduction_combine_vec(reduction_type, vector_value, init_value, axis, shape,
         return f"vector.multi_reduction <and>, %{vector_value}, %{init_value} [{axis}] : {shape} to {reduced_shape}"
     raise AssertionError(reduction_type)
 
-class ExtensionWrapperCodegen(wrapper.WrapperCodeGen):
+class ExtensionWrapperCodegen(wrapper.PythonWrapperCodegen):
     def __init__(self):
         super().__init__()
 
+    @classmethod
+    def create(
+        cls,
+        is_subgraph: bool,
+        subgraph_name: Optional[str],
+        parent_wrapper: Optional[wrapper.PythonWrapperCodegen],
+        partition_signatures: Optional[GraphPartitionSignature] = None,
+    ):
+        if is_subgraph:
+            assert subgraph_name is not None and parent_wrapper is not None
+            return wrapper.SubgraphPythonWrapperCodegen(
+                subgraph_name, parent_wrapper, partition_signatures
+            )
+        return cls()
+
     def write_header(self):
         self.header.splice(
             f"""
@@ -107,6 +124,7 @@ def write_header(self):
                 reinterpret_tensor = torch.ops.aten._reinterpret_tensor
                 custom_async_compile = CustomAsyncCompile()
                 os.environ["TORCHSIM_LAST_COMPILED_MODULE"] = __file__
+                print(f\'Wrapper Codegen Path = {{__file__}}\')
             """
         )
         self.header.splice(
@@ -154,7 +172,7 @@ def call(args):
             self.prefix.writeline(f"{lhs} = args")
             self.prefix.writeline("args.clear()")
 
-            self.codegen_inputs(self.prefix, V.graph.graph_inputs)
+            self.codegen_inputs()
             self.codegen_input_size_asserts()
             self.codegen_sram_plan_prefix()
 
@@ -174,10 +192,27 @@ def codegen_sram_plan_postfix(self, outputs):
                 continue
             self.wrapper_call.writeline(f"sram_plan_postfix('{name}', {name})")
 
-    @dynamo_timed
+    def _generate_kernel_call_helper(
+        self,
+        kernel_name: str,
+        call_args,
+        *,
+        device=None,
+        triton=True,
+        arg_types=None,
+        raw_keys=None,
+        raw_args=None,
+        triton_meta=None,
+        graph_name="",
+        original_fxnode_name=None,
+    ):
+        device = device or V.graph.get_current_device_or_throw()
+        self.writeline(self.wrap_kernel_call(kernel_name, call_args))
+        return
+
     def generate(self, is_inference):
         result = IndentedBuffer()
-        result.splice(self.header)
+        # result.splice(self.header)
 
         with contextlib.ExitStack() as stack:
             stack.enter_context(self.wrapper_call.indent())
@@ -192,8 +227,13 @@ def generate(self, is_inference):
 
                 if isinstance(line, wrapper.MemoryPlanningLine):
                     line.codegen(self.wrapper_call)
+                elif isinstance(line, wrapper.KernelCallLine):
+                    self.wrapper_call.writeline(self.wrap_kernel_call(line.kernel_name, line.call_args))
                 else:
-                    self.wrapper_call.writeline(line)
+                    if isinstance(line, wrapper.WrapperLine):
+                        line.codegen(self.wrapper_call)
+                    else:
+                        self.wrapper_call.writeline(line)
                 # Add buffer plan hook for alloc
                 if isinstance(line, memory_planning.AllocFromPoolLine) or isinstance(line, wrapper.AllocateLine):
                     self.wrapper_call.writeline(f"sram_plan_prefix('{line.node.get_name()}', {line.node.get_name()})")
@@ -202,7 +242,9 @@ def generate(self, is_inference):
             self.mark_output_type()
             self.generate_return(output_refs)
 
-        self.append_precomputed_sizes_to_prefix()
+        # self.append_precomputed_sizes_to_prefix() # FIXME: Need to replace append_precomputed_sizes_to_prefix()
+        result.splice(self.header)
+
         self.finalize_prefix()
         result.splice(self.prefix)
 
@@ -211,7 +253,10 @@ def generate(self, is_inference):
 
         self.generate_end(result)
         self.add_benchmark_harness(result)
-        return result.getvaluewithlinemap()
+        return (
+            result.getvaluewithlinemap(),
+            self.kernel_declarations.getvaluewithlinemap(),
+        )
 
     def memory_plan(self):
         self.lines = memory_planning.MemoryPlanner(self).plan(self.lines)
@@ -1494,16 +1539,16 @@ def get_cycle(choice):
         return optimal_src_code
 
     def codegen_nodes(self, nodes, kernel_name):
-        src_code = super().codegen_nodes(nodes, kernel_name)
+        src_code, meta_code = super().codegen_nodes(nodes, kernel_name)
         self._prepare_simulator_headers(src_code)
         if not extension_config.CONFIG_AUTOTUNE or extension_config.CONFIG_BACKENDSIM_SPIKE_ONLY:
-            return src_code
+            return src_code, meta_code
         else:
             optimal_src_code = self.autotune(nodes, kernel_name)
            if optimal_src_code:
-                return optimal_src_code
+                return optimal_src_code, meta_code
            else:
-                return src_code
+                return src_code, meta_code
 
     def _prepare_simulator_headers(self, src_code):
         write_path = extension_codecache.get_write_path(src_code)
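
ExtensionWrapperCodegen now follows the PythonWrapperCodegen API of PyTorch 2.8, where Inductor builds wrappers through the create() classmethod instead of calling the constructor directly. A registration sketch under stated assumptions: register_backend_for_device from torch._inductor.codegen.common, and a hypothetical ExtensionScheduling class standing in for this backend's scheduler (neither call appears in this diff):

# Sketch (illustrative names, not from this commit).
from torch._inductor.codegen.common import register_backend_for_device

register_backend_for_device(
    "extension_device",
    ExtensionScheduling,      # hypothetical: the backend's Scheduling class
    ExtensionWrapperCodegen,  # the wrapper codegen updated in this commit
)

# PyTorch 2.8's Inductor then instantiates the top-level wrapper via the new classmethod:
wrapper_codegen = ExtensionWrapperCodegen.create(
    is_subgraph=False, subgraph_name=None, parent_wrapper=None
)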
