Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 73 additions & 7 deletions src/xtc/runtimes/accelerator/gpu/GPUDevice.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,16 @@ def __init_once__(self):
f"{get_mlir_prefix()}/lib/{cuda_runtime_lib}"
)
self.loaded_kernels: dict[Module, LibLoader] = {}
create_stream_func_name = "mgpuStreamCreate"
create_stream_func = getattr(
self._mlir_runtime_lib.lib, create_stream_func_name
)
assert create_stream_func is not None, (
f"Cannot find symbol {create_stream_func_name} in lib {self._mlir_runtime_lib.lib}"
)
create_stream_func.argtypes = []
create_stream_func.restype = ctypes.c_voidp
self._custream = create_stream_func()

def __get_runtime_func(self, name: str) -> Callable:
if name in runtime_funcs:
Expand Down Expand Up @@ -120,33 +130,89 @@ def unload_module(self, module: Module) -> None:

@override
def memory_allocate(self, size_bytes: int) -> Any:
    """Allocate ``size_bytes`` bytes of device memory.

    Calls the MLIR CUDA runtime wrapper ``mgpuMemAlloc`` on the stream
    created in ``__init_once__``.

    Args:
        size_bytes: Number of bytes to allocate.

    Returns:
        An opaque device pointer (integer address) for the allocation,
        suitable for ``memory_free`` / ``memory_copy_*``.
    """
    func_name = "mgpuMemAlloc"
    # Pass a None default: getattr on a ctypes CDLL raises AttributeError
    # for missing symbols, which would make the assert below dead code.
    func = getattr(self._mlir_runtime_lib.lib, func_name, None)
    assert func is not None, (
        f"Cannot find symbol {func_name} in lib {self._mlir_runtime_lib.lib}"
    )
    # mgpuMemAlloc(sizeBytes, stream, isHostShared) -> void*.
    # Use ctypes.c_void_p; c_voidp is only an undocumented legacy alias.
    func.argtypes = [ctypes.c_uint64, ctypes.c_void_p, ctypes.c_bool]
    func.restype = ctypes.c_void_p
    return func(size_bytes, self._custream, True)

@override
def memory_free(self, handle: Any) -> None:
    """Free a device allocation returned by :meth:`memory_allocate`.

    Calls the MLIR CUDA runtime wrapper ``mgpuMemFree`` on the device
    stream; returns nothing.
    """
    func_name = "mgpuMemFree"
    # Pass a None default: getattr on a ctypes CDLL raises AttributeError
    # for missing symbols, which would make the assert below dead code.
    func = getattr(self._mlir_runtime_lib.lib, func_name, None)
    assert func is not None, (
        f"Cannot find symbol {func_name} in lib {self._mlir_runtime_lib.lib}"
    )
    # mgpuMemFree(ptr, stream). Use ctypes.c_void_p; c_voidp is only an
    # undocumented legacy alias.
    func.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
    func.restype = None
    func(handle, self._custream)

@override
def memory_copy_to(
    self, acc_handle: Any, src: ctypes.c_void_p, size_bytes: int
) -> None:
    """Copy ``size_bytes`` bytes from host ``src`` to device ``acc_handle``.

    Issues an asynchronous ``mgpuMemcpy`` on the device stream, then
    blocks on ``mgpuStreamSynchronize`` so the host buffer may be safely
    reused as soon as this method returns.
    """
    copy_func_name = "mgpuMemcpy"
    # Pass a None default: getattr on a ctypes CDLL raises AttributeError
    # for missing symbols, which would make the assert below dead code.
    copy_func = getattr(self._mlir_runtime_lib.lib, copy_func_name, None)
    assert copy_func is not None, (
        f"Cannot find symbol {copy_func_name} in lib {self._mlir_runtime_lib.lib}"
    )
    # mgpuMemcpy(dst, src, sizeBytes, stream). Use ctypes.c_void_p;
    # c_voidp is only an undocumented legacy alias.
    copy_func.argtypes = [
        ctypes.c_void_p,
        ctypes.c_void_p,
        ctypes.c_uint64,
        ctypes.c_void_p,
    ]
    copy_func.restype = None
    copy_func(acc_handle, src, size_bytes, self._custream)
    # Synchronize the stream so the copy has completed before returning.
    sync_func_name = "mgpuStreamSynchronize"
    sync_func = getattr(self._mlir_runtime_lib.lib, sync_func_name, None)
    assert sync_func is not None, (
        f"Cannot find symbol {sync_func_name} in lib {self._mlir_runtime_lib.lib}"
    )
    sync_func.argtypes = [ctypes.c_void_p]
    sync_func.restype = None
    sync_func(self._custream)

@override
def memory_copy_from(
    self, acc_handle: Any, dst: ctypes.c_void_p, size_bytes: int
) -> None:
    """Copy ``size_bytes`` bytes from device ``acc_handle`` to host ``dst``.

    Issues an asynchronous ``mgpuMemcpy`` on the device stream, then
    blocks on ``mgpuStreamSynchronize`` so the host buffer holds the data
    when this method returns.
    """
    copy_func_name = "mgpuMemcpy"
    # Pass a None default: getattr on a ctypes CDLL raises AttributeError
    # for missing symbols, which would make the assert below dead code.
    copy_func = getattr(self._mlir_runtime_lib.lib, copy_func_name, None)
    assert copy_func is not None, (
        f"Cannot find symbol {copy_func_name} in lib {self._mlir_runtime_lib.lib}"
    )
    # mgpuMemcpy(dst, src, sizeBytes, stream). Use ctypes.c_void_p;
    # c_voidp is only an undocumented legacy alias.
    copy_func.argtypes = [
        ctypes.c_void_p,
        ctypes.c_void_p,
        ctypes.c_uint64,
        ctypes.c_void_p,
    ]
    copy_func.restype = None
    copy_func(dst, acc_handle, size_bytes, self._custream)
    # Synchronize the stream so the copy has completed before returning.
    sync_func_name = "mgpuStreamSynchronize"
    sync_func = getattr(self._mlir_runtime_lib.lib, sync_func_name, None)
    assert sync_func is not None, (
        f"Cannot find symbol {sync_func_name} in lib {self._mlir_runtime_lib.lib}"
    )
    sync_func.argtypes = [ctypes.c_void_p]
    sync_func.restype = None
    sync_func(self._custream)

@override
def memory_fill_zero(self, acc_handle: Any, size_bytes: int) -> None:
    """Zero-fill ``size_bytes`` bytes of the device allocation ``acc_handle``.

    Deliberately unsupported on this device for now.
    # NOTE(review): presumably implementable via the MLIR runtime memset
    # wrappers (e.g. mgpuMemset32) when needed -- confirm before relying on it.
    """
    raise NotImplementedError("memory_fill_zero is not implemented for GPU device")

@override
def memory_data_pointer(self, acc_handle: Any) -> ctypes.c_void_p:
raise NotImplementedError(
"memory_data_pointer is not implemented for GPU device"
)
return ctypes.cast(acc_handle, ctypes.c_void_p)

@override
def evaluate(
Expand Down
7 changes: 7 additions & 0 deletions src/xtc/runtimes/accelerator/gpu/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
#
# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2024-2026 The XTC Project Authors
#
from .GPUDevice import GPUDevice

__all__ = ["GPUDevice"]
22 changes: 16 additions & 6 deletions src/xtc/targets/accelerator/gpu/GPUEvaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,16 @@ def evaluate(self) -> tuple[list[float], int, str]:

# Map the buffers
# TODO Replace memory mapping of buffers by explicit transfers
for buffer in parameters[0] + parameters[1]:
self._device._register_buffer(
buffer.data, buffer.size * buffer.dtype.itemsize
)
for i, buffer in enumerate(parameters[0]):
if self._np_inputs_spec()[i]["device"] is None:
self._device._register_buffer(
Comment thread
guillon marked this conversation as resolved.
buffer.data, buffer.size * buffer.dtype.itemsize
)
for i, buffer in enumerate(parameters[1]):
if self._np_outputs_spec()[i]["device"] is None:
self._device._register_buffer(
buffer.data, buffer.size * buffer.dtype.itemsize
)

# Check the correctness of the outputs
if self._validate:
Expand All @@ -89,8 +95,12 @@ def evaluate(self) -> tuple[list[float], int, str]:
)

# Unmap the buffers
for buffer in parameters[0] + parameters[1]:
self._device._unregister_buffer(buffer.data)
for i, buffer in enumerate(parameters[0]):
if self._np_inputs_spec()[i]["device"] is None:
self._device._unregister_buffer(buffer.data)
for i, buffer in enumerate(parameters[1]):
if self._np_outputs_spec()[i]["device"] is None:
self._device._unregister_buffer(buffer.data)

# Unload the module
self._device.unload_module(self._module)
Expand Down
161 changes: 161 additions & 0 deletions tests/filecheck/backends/target_gpu/test_matmul_mlir_offload_tensor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
# RUN: python %s 2>&1 | filecheck %s
# REQUIRES: mlir-target=nvgpu

# Filecheck test: build a small matmul graph where one input and the output
# live on a GPU device, schedule it, compile for the GPU target, and execute
# with validation. The printed IR and graph dump are pinned by the CHECK
# lines below, so the code here must not change what gets printed.
import xtc.graphs.xtc.op as O
from xtc.backends.mlir.MlirGraphBackend import MlirGraphBackend as Backend

from xtc.runtimes.accelerator.gpu import GPUDevice

# Create device
gpu = GPUDevice()

I, J, K, dtype = 4, 32, 512, "float32"
a = O.tensor((I, K), dtype, name="A") # A lives on the host
b = O.tensor((K, J), dtype, name="B", device=gpu) # B lives on the accelerator

with O.graph(name="matmul") as gb:
O.matmul(a, b, name="C", device=gpu) # C must live on the accelerator

graph = gb.graph
print(graph)

impl = Backend(graph)

# Build a schedule: tile i by 2 and j by 16, unroll the inner i1 loop,
# and parallelize the outer i loop.
sch = impl.get_scheduler()
sch.tile("i", {"i1": 2})
sch.tile("j", {"j1": 16})
sch.unroll({"i1": 2})
sch.parallelize(["i"])
sched = sch.schedule()

# Compile for the GPU target as a shared library, dumping the IR before
# and after the transform sequence (checked by the CHECK lines below).
comp = impl.get_compiler(
target=gpu,
shared_lib=True,
dump_file="gpu_matmul_mlir_offload_tensor",
print_source_ir=True,
print_transformed_ir=True,
)
module = comp.compile(sched)
# Execute with validation enabled; a CODE of 0 means outputs matched.
executor = module.get_executor(validate=True)
res = executor.execute()
print(f"CODE: {res}")
# CHECK: // -----// IR Dump Before transform //----- //
# CHECK-NEXT: module attributes {transform.with_named_sequence} {
# CHECK-NEXT: func.func @matmul(%arg0: memref<4x512xf32> {llvm.noalias}, %arg1: memref<512x32xf32> {llvm.noalias, memref.on_device}, %arg2: memref<4x32xf32> {llvm.noalias, memref.on_device}) {
# CHECK-NEXT: %cst = arith.constant 0.000000e+00 : f32
# CHECK-NEXT: linalg.fill {__xtc_id_C_0_} ins(%cst : f32) outs(%arg2 : memref<4x32xf32>)
# CHECK-NEXT: linalg.matmul {__xtc_id_C_} ins(%arg0, %arg1 : memref<4x512xf32>, memref<512x32xf32>) outs(%arg2 : memref<4x32xf32>)
# CHECK-NEXT: return
# CHECK-NEXT: }
# CHECK-NEXT: transform.named_sequence @_vecto(%arg0: !transform.any_op {transform.consumed}) {
# CHECK-NEXT: transform.structured.vectorize %arg0 : !transform.any_op
# CHECK-NEXT: transform.yield
# CHECK-NEXT: }
# CHECK-NEXT: transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
# CHECK-NEXT: %0 = transform.structured.match attributes {__xtc_id_C_0_} in %arg0 : (!transform.any_op) -> !transform.any_op
# CHECK-NEXT: %tiled_linalg_op, %loops = transform.structured.tile_using_for %0 tile_sizes [1, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
# CHECK-NEXT: transform.annotate %loops "./i" : !transform.any_op
# CHECK-NEXT: %tiled_linalg_op_0, %loops_1 = transform.structured.tile_using_for %tiled_linalg_op tile_sizes [0, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
# CHECK-NEXT: transform.annotate %loops_1 "./j" : !transform.any_op
# CHECK-NEXT: %1 = transform.structured.match attributes {__xtc_id_C_} in %arg0 : (!transform.any_op) -> !transform.any_op
# CHECK-NEXT: %tiled_op, %forall_op = transform.structured.tile_using_forall %1 tile_sizes [2, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
# CHECK-NEXT: transform.annotate %forall_op "./i" : !transform.any_op
# CHECK-NEXT: %tiled_linalg_op_2, %loops_3 = transform.structured.tile_using_for %tiled_op tile_sizes [0, 16, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
# CHECK-NEXT: transform.annotate %loops_3 "./j" : !transform.any_op
# CHECK-NEXT: %tiled_linalg_op_4, %loops_5 = transform.structured.tile_using_for %tiled_linalg_op_2 tile_sizes [0, 0, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
# CHECK-NEXT: transform.annotate %loops_5 "./k" : !transform.any_op
# CHECK-NEXT: %tiled_linalg_op_6, %loops_7 = transform.structured.tile_using_for %tiled_linalg_op_4 tile_sizes [1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
# CHECK-NEXT: transform.annotate %loops_7 "./i1" : !transform.any_op
# CHECK-NEXT: %tiled_linalg_op_8, %loops_9 = transform.structured.tile_using_for %tiled_linalg_op_6 tile_sizes [0, 1, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
# CHECK-NEXT: transform.annotate %loops_9 "./j1" : !transform.any_op
# CHECK-NEXT: transform.loop.unroll %loops_7 {factor = 2 : i64} : !transform.any_op
# CHECK-NEXT: transform.yield
# CHECK-NEXT: }
# CHECK-NEXT: }
# CHECK-NEXT:
# CHECK-NEXT: // -----// IR Dump After transform //----- //
# CHECK-NEXT: #map = affine_map<(d0) -> (d0 * 2)>
# CHECK-NEXT: module attributes {transform.with_named_sequence} {
# CHECK-NEXT: func.func @matmul(%arg0: memref<4x512xf32> {llvm.noalias}, %arg1: memref<512x32xf32> {llvm.noalias, memref.on_device}, %arg2: memref<4x32xf32> {llvm.noalias, memref.on_device}) {
# CHECK-NEXT: %cst = arith.constant 0.000000e+00 : f32
# CHECK-NEXT: %c0 = arith.constant 0 : index
# CHECK-NEXT: %c4 = arith.constant 4 : index
# CHECK-NEXT: %c1 = arith.constant 1 : index
# CHECK-NEXT: scf.for %arg3 = %c0 to %c4 step %c1 {
# CHECK-NEXT: %subview = memref.subview %arg2[%arg3, 0] [1, 32] [1, 1] : memref<4x32xf32> to memref<1x32xf32, strided<[32, 1], offset: ?>>
# CHECK-NEXT: %c0_0 = arith.constant 0 : index
# CHECK-NEXT: %c32 = arith.constant 32 : index
# CHECK-NEXT: %c1_1 = arith.constant 1 : index
# CHECK-NEXT: scf.for %arg4 = %c0_0 to %c32 step %c1_1 {
# CHECK-NEXT: %subview_2 = memref.subview %subview[0, %arg4] [1, 1] [1, 1] : memref<1x32xf32, strided<[32, 1], offset: ?>> to memref<1x1xf32, strided<[32, 1], offset: ?>>
# CHECK-NEXT: linalg.fill {__xtc_id_C_0_} ins(%cst : f32) outs(%subview_2 : memref<1x1xf32, strided<[32, 1], offset: ?>>)
# CHECK-NEXT: } {"./j"}
# CHECK-NEXT: } {"./i"}
# CHECK-NEXT: scf.forall (%arg3) in (2) {
# CHECK-NEXT: %0 = affine.apply #map(%arg3)
# CHECK-NEXT: %subview = memref.subview %arg0[%0, 0] [2, 512] [1, 1] : memref<4x512xf32> to memref<2x512xf32, strided<[512, 1], offset: ?>>
# CHECK-NEXT: %subview_0 = memref.subview %arg1[0, 0] [512, 32] [1, 1] : memref<512x32xf32> to memref<512x32xf32, strided<[32, 1]>>
# CHECK-NEXT: %subview_1 = memref.subview %arg2[%0, 0] [2, 32] [1, 1] : memref<4x32xf32> to memref<2x32xf32, strided<[32, 1], offset: ?>>
# CHECK-NEXT: %c0_2 = arith.constant 0 : index
# CHECK-NEXT: %c32 = arith.constant 32 : index
# CHECK-NEXT: %c16 = arith.constant 16 : index
# CHECK-NEXT: scf.for %arg4 = %c0_2 to %c32 step %c16 {
# CHECK-NEXT: %subview_3 = memref.subview %subview[0, 0] [2, 512] [1, 1] : memref<2x512xf32, strided<[512, 1], offset: ?>> to memref<2x512xf32, strided<[512, 1], offset: ?>>
# CHECK-NEXT: %subview_4 = memref.subview %subview_0[0, %arg4] [512, 16] [1, 1] : memref<512x32xf32, strided<[32, 1]>> to memref<512x16xf32, strided<[32, 1], offset: ?>>
# CHECK-NEXT: %subview_5 = memref.subview %subview_1[0, %arg4] [2, 16] [1, 1] : memref<2x32xf32, strided<[32, 1], offset: ?>> to memref<2x16xf32, strided<[32, 1], offset: ?>>
# CHECK-NEXT: %c0_6 = arith.constant 0 : index
# CHECK-NEXT: %c512 = arith.constant 512 : index
# CHECK-NEXT: %c1_7 = arith.constant 1 : index
# CHECK-NEXT: scf.for %arg5 = %c0_6 to %c512 step %c1_7 {
# CHECK-NEXT: %subview_8 = memref.subview %subview_3[0, %arg5] [2, 1] [1, 1] : memref<2x512xf32, strided<[512, 1], offset: ?>> to memref<2x1xf32, strided<[512, 1], offset: ?>>
# CHECK-NEXT: %subview_9 = memref.subview %subview_4[%arg5, 0] [1, 16] [1, 1] : memref<512x16xf32, strided<[32, 1], offset: ?>> to memref<1x16xf32, strided<[32, 1], offset: ?>>
# CHECK-NEXT: %subview_10 = memref.subview %subview_5[0, 0] [2, 16] [1, 1] : memref<2x16xf32, strided<[32, 1], offset: ?>> to memref<2x16xf32, strided<[32, 1], offset: ?>>
# CHECK-NEXT: %c0_11 = arith.constant 0 : index
# CHECK-NEXT: %c2 = arith.constant 2 : index
# CHECK-NEXT: %c1_12 = arith.constant 1 : index
# CHECK-NEXT: %c2_13 = arith.constant 2 : index
# CHECK-NEXT: %subview_14 = memref.subview %subview_8[%c0_11, 0] [1, 1] [1, 1] : memref<2x1xf32, strided<[512, 1], offset: ?>> to memref<1x1xf32, strided<[512, 1], offset: ?>>
# CHECK-NEXT: %subview_15 = memref.subview %subview_9[0, 0] [1, 16] [1, 1] : memref<1x16xf32, strided<[32, 1], offset: ?>> to memref<1x16xf32, strided<[32, 1], offset: ?>>
# CHECK-NEXT: %subview_16 = memref.subview %subview_10[%c0_11, 0] [1, 16] [1, 1] : memref<2x16xf32, strided<[32, 1], offset: ?>> to memref<1x16xf32, strided<[32, 1], offset: ?>>
# CHECK-NEXT: %c0_17 = arith.constant 0 : index
# CHECK-NEXT: %c16_18 = arith.constant 16 : index
# CHECK-NEXT: %c1_19 = arith.constant 1 : index
# CHECK-NEXT: scf.for %arg6 = %c0_17 to %c16_18 step %c1_19 {
# CHECK-NEXT: %subview_27 = memref.subview %subview_14[0, 0] [1, 1] [1, 1] : memref<1x1xf32, strided<[512, 1], offset: ?>> to memref<1x1xf32, strided<[512, 1], offset: ?>>
# CHECK-NEXT: %subview_28 = memref.subview %subview_15[0, %arg6] [1, 1] [1, 1] : memref<1x16xf32, strided<[32, 1], offset: ?>> to memref<1x1xf32, strided<[32, 1], offset: ?>>
# CHECK-NEXT: %subview_29 = memref.subview %subview_16[0, %arg6] [1, 1] [1, 1] : memref<1x16xf32, strided<[32, 1], offset: ?>> to memref<1x1xf32, strided<[32, 1], offset: ?>>
# CHECK-NEXT: linalg.matmul {__xtc_id_C_} ins(%subview_27, %subview_28 : memref<1x1xf32, strided<[512, 1], offset: ?>>, memref<1x1xf32, strided<[32, 1], offset: ?>>) outs(%subview_29 : memref<1x1xf32, strided<[32, 1], offset: ?>>)
# CHECK-NEXT: } {"./j1"}
# CHECK-NEXT: %c1_20 = arith.constant 1 : index
# CHECK-NEXT: %1 = arith.muli %c1_12, %c1_20 : index
# CHECK-NEXT: %2 = arith.addi %c0_11, %1 : index
# CHECK-NEXT: %subview_21 = memref.subview %subview_8[%2, 0] [1, 1] [1, 1] : memref<2x1xf32, strided<[512, 1], offset: ?>> to memref<1x1xf32, strided<[512, 1], offset: ?>>
# CHECK-NEXT: %subview_22 = memref.subview %subview_9[0, 0] [1, 16] [1, 1] : memref<1x16xf32, strided<[32, 1], offset: ?>> to memref<1x16xf32, strided<[32, 1], offset: ?>>
# CHECK-NEXT: %subview_23 = memref.subview %subview_10[%2, 0] [1, 16] [1, 1] : memref<2x16xf32, strided<[32, 1], offset: ?>> to memref<1x16xf32, strided<[32, 1], offset: ?>>
# CHECK-NEXT: %c0_24 = arith.constant 0 : index
# CHECK-NEXT: %c16_25 = arith.constant 16 : index
# CHECK-NEXT: %c1_26 = arith.constant 1 : index
# CHECK-NEXT: scf.for %arg6 = %c0_24 to %c16_25 step %c1_26 {
# CHECK-NEXT: %subview_27 = memref.subview %subview_21[0, 0] [1, 1] [1, 1] : memref<1x1xf32, strided<[512, 1], offset: ?>> to memref<1x1xf32, strided<[512, 1], offset: ?>>
# CHECK-NEXT: %subview_28 = memref.subview %subview_22[0, %arg6] [1, 1] [1, 1] : memref<1x16xf32, strided<[32, 1], offset: ?>> to memref<1x1xf32, strided<[32, 1], offset: ?>>
# CHECK-NEXT: %subview_29 = memref.subview %subview_23[0, %arg6] [1, 1] [1, 1] : memref<1x16xf32, strided<[32, 1], offset: ?>> to memref<1x1xf32, strided<[32, 1], offset: ?>>
# CHECK-NEXT: linalg.matmul {__xtc_id_C_} ins(%subview_27, %subview_28 : memref<1x1xf32, strided<[512, 1], offset: ?>>, memref<1x1xf32, strided<[32, 1], offset: ?>>) outs(%subview_29 : memref<1x1xf32, strided<[32, 1], offset: ?>>)
# CHECK-NEXT: } {"./j1"}
# CHECK-NEXT: } {"./k"}
# CHECK-NEXT: } {"./j"}
# CHECK-NEXT: } {"./i"}
# CHECK-NEXT: return
# CHECK-NEXT: }
# CHECK-NEXT: }
# CHECK-NEXT:
# CHECK-NEXT: graph:
# CHECK-NEXT: name: matmul
# CHECK-NEXT: inputs:
# CHECK-NEXT: - %0 : 4x512xfloat32
# CHECK-NEXT: - %1 : 512x32xfloat32
# CHECK-NEXT: outputs:
# CHECK-NEXT: - %2 : 4x32xfloat32
# CHECK-NEXT: nodes:
# CHECK-NEXT: - %2: matmul(%0, %1) {name = 'C'} : [4x512xfloat32, 512x32xfloat32] -> [4x32xfloat32]
# CHECK-NEXT:
# CHECK-NEXT: CODE: 0
Loading