2 changes: 1 addition & 1 deletion Dockerfile.base
@@ -23,7 +23,7 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-FROM pytorch/pytorch:2.2.0-cuda12.1-cudnn8-runtime
+FROM pytorch/pytorch:2.8.0-cuda12.6-cudnn9-runtime

# Copied from Gem5 Docker file
ENV DEBIAN_FRONTEND=noninteractive
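As a quick smoke test after rebuilding on the new base image, a short snippet run inside the container should report the versions the tag advertises; the expected values below are inferred from the image name, not verified here:

import torch

print(torch.__version__)               # expected: 2.8.0
print(torch.version.cuda)              # expected: 12.6
print(torch.backends.cudnn.version())  # expected: a cuDNN 9 build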
8 changes: 8 additions & 0 deletions PyTorchSimDevice/ExtensionDeviceGuardImpl.cpp
@@ -0,0 +1,8 @@
#include "ExtensionDeviceGuardImpl.h"
#include <c10/core/impl/DeviceGuardImplRegistry.h>

namespace c10::extension_device::impl {

C10_REGISTER_GUARD_IMPL(extension_device, ExtensionDeviceGuardImpl);

} // namespace c10::extension_device::impl
127 changes: 127 additions & 0 deletions PyTorchSimDevice/ExtensionDeviceGuardImpl.h
@@ -0,0 +1,127 @@
#pragma once

#include <c10/core/DeviceGuard.h>
#include <c10/core/impl/DeviceGuardImplInterface.h>
#include <c10/core/Stream.h>
#include <c10/core/Event.h>
#include <c10/core/DeviceType.h>
#include <c10/util/Optional.h>

namespace c10::extension_device::impl {

struct ExtensionDeviceGuardImpl final : public c10::impl::DeviceGuardImplInterface {
static constexpr DeviceType static_type = DeviceType::PrivateUse1; // this extension backend's device type

ExtensionDeviceGuardImpl() = default;

explicit ExtensionDeviceGuardImpl(DeviceType t) {
TORCH_CHECK(
t == static_type,
"ExtensionDeviceGuardImpl initialized with non-extension_device DeviceType: ",
t);
}

// --------------------------------------------------------------------------
// Basic device guard (behaves like a CPU device)
// --------------------------------------------------------------------------
DeviceType type() const override {
return static_type;
}

Device exchangeDevice(Device d) const override {
TORCH_CHECK(d.type() == static_type, "Expected extension_device but got ", d);
return d; // nothing to exchange, CPU-like
}

Device getDevice() const override {
return Device(static_type, 0);
}

void setDevice(Device d) const override {
TORCH_CHECK(d.type() == static_type, "Expected extension_device but got ", d);
}

void uncheckedSetDevice(Device d) const noexcept override {}

DeviceIndex deviceCount() const noexcept override {
return 1; // pretend single device
}

// --------------------------------------------------------------------------
// Stream handling (execution is synchronous, so only the default stream is used)
// --------------------------------------------------------------------------
Stream getStream(Device d) const override {
return Stream(Stream::DEFAULT, d);
}

Stream getNewStream(Device d, int priority = 0) const override {
return Stream(Stream::DEFAULT, d);
}

Stream getStreamFromGlobalPool(Device d, bool = false) const override {
return Stream(Stream::DEFAULT, d);
}

Stream exchangeStream(Stream s) const override {
return s;
}

bool queryStream(const Stream& stream) const override {
(void)stream;
return true;
}

void synchronizeStream(const Stream& stream) const override {
(void)stream;
}

void synchronizeDevice(DeviceIndex device_index) const override {
(void)device_index;
}

// --------------------------------------------------------------------------
// Event handling (all no-ops)
// --------------------------------------------------------------------------
void destroyEvent(void* event, const DeviceIndex device_index) const noexcept override {
(void)event;
(void)device_index;
}

void record(void** event, const Stream& stream, const DeviceIndex device_index, const EventFlag flag) const override {
(void)event;
(void)stream;
(void)device_index;
(void)flag;
}

void block(void* event, const Stream& stream) const override {
(void)event;
(void)stream;
}

bool queryEvent(void* event) const override {
(void)event;
return true;
}

void synchronizeEvent(void* event) const override {
(void)event;
}

double elapsedTime(void* start_event, void* end_event, const DeviceIndex device_index) const override {
(void)start_event;
(void)end_event;
(void)device_index;
return 0.0;
}

// --------------------------------------------------------------------------
// Misc (allocator integration)
// --------------------------------------------------------------------------
void recordDataPtrOnStream(const c10::DataPtr& data_ptr, const Stream& stream) const override {
(void)data_ptr;
(void)stream;
}
};

} // namespace c10::extension_device::impl
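For orientation, a minimal sketch of how this guard gets exercised from Python once the extension library is loaded. The library path and the renamed backend string are assumptions for illustration:

import torch

torch.ops.load_library("build/libpytorchsim_device.so")  # hypothetical path
torch.utils.rename_privateuse1_backend("extension_device")

# Factory functions enter c10::DeviceGuard, which now dispatches to
# ExtensionDeviceGuardImpl instead of the former NoOp implementation.
x = torch.empty(4, device="extension_device")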
@@ -17,16 +17,12 @@
#include <ATen/NativeFunctions.h>
#include <ATen/native/CPUFallback.h>

#include "ExtensionDeviceGuardImpl.h"

static uint64_t op_counter = 0;
static uint64_t last_saved_value = 0;

// register guard
-namespace at {
-namespace detail {
-
-C10_REGISTER_GUARD_IMPL(PrivateUse1, c10::impl::NoOpDeviceGuardImpl<DeviceType::PrivateUse1>);
-
-}} // namespace at::detail
+C10_REGISTER_GUARD_IMPL(PrivateUse1, c10::extension_device::impl::ExtensionDeviceGuardImpl);

// basic dummy add function
at::Tensor custom_add_Tensor(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) {
@@ -113,7 +109,7 @@ at::Tensor custom_to_device(
// A dummy allocator for our custom device, that secretly uses the CPU
struct DummyCustomAllocator final : at::Allocator {
DummyCustomAllocator() = default;
-at::DataPtr allocate(size_t nbytes) const override {
+at::DataPtr allocate(size_t nbytes) override {
void* data = c10::alloc_cpu(nbytes);
return {data, data, &ReportAndDelete, at::Device(at::DeviceType::PrivateUse1, 0)};
}
@@ -128,6 +124,10 @@ struct DummyCustomAllocator final : at::Allocator {
at::DeleterFnPtr raw_deleter() const override {
return &ReportAndDelete;
}

+void copy_data(void* dest, const void* src, std::size_t count) const override {
+  std::memcpy(dest, src, count);
+}
};

// Register our dummy allocator
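The two allocator edits track API changes in newer PyTorch: Allocator::allocate lost its const qualifier, and subclasses must now implement copy_data. A hedged sketch of how this allocator is reached from Python, assuming the renamed backend from the guard sketch above:

import torch

x = torch.ones(4)             # ordinary CPU tensor
y = x.to("extension_device")  # routes through DummyCustomAllocator::allocate
assert y.device.type == "extension_device"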
63 changes: 63 additions & 0 deletions PyTorchSimDevice/extension_device_interface.py
@@ -0,0 +1,63 @@
import torch
from torch._dynamo.device_interface import DeviceInterface, caching_worker_current_devices, caching_worker_device_properties

class _ExtensionDeviceProperties:  # FIXME: Dummy property values
    name: str = "Extension_device"
    platform_name: str
    vendor: str
    driver_version: str
    version: str
    max_compute_units: int
    gpu_eu_count: int
    max_work_group_size: int
    max_num_sub_groups: int
    sub_group_sizes: list[int]
    has_fp16: bool
    has_fp64: bool
    has_atomic64: bool
    has_bfloat16_conversions: bool
    has_subgroup_matrix_multiply_accumulate: bool
    has_subgroup_matrix_multiply_accumulate_tensor_float32: bool
    has_subgroup_2d_block_io: bool
    total_memory: int
    multi_processor_count: int = 128  # gpu_subslice_count, num_sm
    architecture: int
    type: str

_ExtensionDeviceProperties = _ExtensionDeviceProperties

class ExtensionDeviceInterface(DeviceInterface):
    class Worker:
        @staticmethod
        def set_device(device: int):
            caching_worker_current_devices["extension_device"] = device

        @staticmethod
        def current_device() -> int:
            if "extension_device" in caching_worker_current_devices:
                return caching_worker_current_devices["extension_device"]
            return torch.xpu.current_device()  # FIXME: XPU fallback placeholder

        @staticmethod
        def get_device_properties(device: torch.types.Device = None) -> _ExtensionDeviceProperties:
            if device is not None:
                if isinstance(device, str):
                    device = torch.device(device)
                    assert device.type == "extension_device"
                if isinstance(device, torch.device):
                    device = device.index
            if device is None:
                device = ExtensionDeviceInterface.Worker.current_device()

            if "extension_device" not in caching_worker_device_properties:
                device_prop = [
                    torch.cuda.get_device_properties(i)
                    for i in range(torch.cuda.device_count())
                ]
                caching_worker_device_properties["extension_device"] = device_prop

            return _ExtensionDeviceProperties

    @staticmethod
    def get_compute_capability(device: torch.types.Device = None):
        return 36
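For Dynamo to consult this class, it has to be registered for the device string; a minimal sketch under the assumption that register_interface_for_device from torch._dynamo.device_interface is the intended hook:

from torch._dynamo.device_interface import register_interface_for_device

register_interface_for_device("extension_device", ExtensionDeviceInterface)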
25 changes: 25 additions & 0 deletions PyTorchSimDevice/extension_device_op_overrides.py
@@ -0,0 +1,25 @@
from __future__ import annotations

from textwrap import dedent

from torch._inductor.codegen.common import DeviceOpOverrides, register_device_op_overrides

class ExtensionDeviceOpOverrides(DeviceOpOverrides):
    def import_get_raw_stream_as(self, name: str) -> str:
        return dedent(
            """
            def get_raw_stream(_):
                return 0
            """
        )

    def set_device(self, device_idx: int) -> str:
        return "pass"

    def synchronize(self) -> str:
        return "pass"

    def device_guard(self, device_idx: int) -> str:
        return "pass"

register_device_op_overrides("extension_device", ExtensionDeviceOpOverrides())
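The strings returned by these overrides are spliced into Inductor's generated wrapper code. An illustrative fragment of what the emitted wrapper might reduce to for this backend (not actual Inductor output):

# Illustrative wrapper fragment, not actual generated text.
def get_raw_stream(_):
    return 0

stream0 = get_raw_stream(0)  # always 0: the simulated device is synchronous
pass                         # set_device / device_guard collapse to no-ops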
48 changes: 48 additions & 0 deletions PyTorchSimDevice/extension_hooks.cpp
@@ -0,0 +1,48 @@
#include "extension_hooks.h"

bool ExtensionPU1Hooks::isBuilt() const { return true; }
bool ExtensionPU1Hooks::isAvailable() const { return true; }

const at::Generator& ExtensionPU1Hooks::getDefaultGenerator(c10::DeviceIndex idx) const {
if (idx < 0) idx = 0;
static std::vector<at::Generator> gens;
static std::mutex m;
std::lock_guard<std::mutex> g(m);
if (gens.size() <= (size_t)idx) gens.resize((size_t)idx + 1);
if (!gens[idx].defined()) gens[idx] = at::GetGeneratorForPrivateuse1(idx);
return gens[idx]; // return a reference to a persistent object
}

at::Generator ExtensionPU1Hooks::getNewGenerator(c10::DeviceIndex idx) const {
if (idx < 0) idx = 0;
return at::GetGeneratorForPrivateuse1(idx);
}

at::Device ExtensionPU1Hooks::getDeviceFromPtr(void* data) const {
return at::Device(at::kPrivateUse1, 0); // MVP: assume a single device
}

bool ExtensionPU1Hooks::isPinnedPtr(const void* data) const {
return false;
}

at::Allocator* ExtensionPU1Hooks::getPinnedMemoryAllocator() const {
return at::getHostAllocator(at::kPrivateUse1);
}

bool ExtensionPU1Hooks::hasPrimaryContext(c10::DeviceIndex device_index) const { return true; }

void ExtensionPU1Hooks::resizePrivateUse1Bytes(const c10::Storage&, size_t) const {
TORCH_CHECK(false, "resizePrivateUse1Bytes not implemented");
}

// REGISTER_EXTENSION_HOOKS(ExtensionPU1Hooks);

namespace {
struct AutoRegistrar {
AutoRegistrar() {
at::RegisterPrivateUse1HooksInterface(new ExtensionPU1Hooks());
}
};
static AutoRegistrar _auto_registrar;
}
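A hedged sketch of one Python path that lands in these hooks: constructing a generator for a PrivateUse1 device is expected to route through GetGeneratorForPrivateuse1, which presumes a generator factory has been registered elsewhere in the extension:

import torch

# Assumes the extension .so (including its generator registration) is loaded.
g = torch.Generator(device="privateuseone:0")
g.manual_seed(0)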
30 changes: 30 additions & 0 deletions PyTorchSimDevice/extension_hooks.h
@@ -0,0 +1,30 @@
#pragma once

#include <ATen/core/CachingHostAllocator.h>
#include <ATen/detail/PrivateUse1HooksInterface.h>

#include <ATen/core/Generator.h>
#include <c10/core/Allocator.h>
#include <c10/core/Device.h>
#include <c10/core/Storage.h>
#include <c10/util/Exception.h>

struct ExtensionPU1Hooks final : public at::PrivateUse1HooksInterface {
ExtensionPU1Hooks() {}
bool isBuilt() const;
bool isAvailable() const;

const at::Generator& getDefaultGenerator(c10::DeviceIndex device_index) const override;

at::Generator getNewGenerator(c10::DeviceIndex device_index = -1) const override;

at::Device getDeviceFromPtr(void* data) const override;

bool isPinnedPtr(const void* data) const override;

at::Allocator* getPinnedMemoryAllocator() const override;

bool hasPrimaryContext(c10::DeviceIndex device_index) const override;

void resizePrivateUse1Bytes(const c10::Storage& /*storage*/, size_t /*newsize*/) const override;
};
3 changes: 2 additions & 1 deletion PyTorchSimFrontend/extension_codecache.py
@@ -3,7 +3,8 @@
import shlex
import subprocess

-from torch._inductor.codecache import AsyncCompile, get_lock_dir, get_hash, write
+from torch._inductor.codecache import get_lock_dir, get_hash, write
+from torch._inductor.async_compile import AsyncCompile
from AsmParser.tog_generator import tog_generator
from PyTorchSimFrontend.mlir.mlir_caller_codegen import MLIRKernelCallerCodeGen
from PyTorchSimFrontend import extension_config
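The import move reflects AsyncCompile relocating out of torch._inductor.codecache in newer PyTorch. If supporting both releases during the migration were desired, a version-tolerant sketch:

try:
    from torch._inductor.async_compile import AsyncCompile  # newer PyTorch (2.8)
except ImportError:
    from torch._inductor.codecache import AsyncCompile      # older releases (2.2)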
26 changes: 26 additions & 0 deletions PyTorchSimFrontend/extension_utils.py
@@ -0,0 +1,26 @@
import sympy
import torch

"""
NOTE: Temporary file.

This file contains functions that were removed or changed in newer versions
of PyTorch. It is kept here only to provide temporary compatibility while
upgrading from PyTorch 2.2 to PyTorch 2.8.

These functions will eventually be integrated into the appropriate source files
or removed once they are no longer needed.

This file is not intended to be permanent and should be deleted in the future.
"""

def free_symbol_startswith(index: sympy.Expr, prefix: str):
    return any(v.name.startswith(prefix) for v in index.free_symbols)

def sympy_symbol(name: str) -> sympy.Symbol:
    # This should never be used for creating shape/stride symbols, as those
    # should all be allocated before Inductor.
    assert name[0] != "s"
    # NOTE: shape symbols are positive (> 0), but index variables are only
    # non-negative (>= 0).
    return sympy.Symbol(name, integer=True, nonnegative=True)
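A small illustrative usage of the two shims:

i0, i1 = sympy_symbol("i0"), sympy_symbol("i1")
idx = i0 * 4 + i1
assert free_symbol_startswith(idx, "i")  # both free symbols start with "i"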