Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions unified-runtime/source/adapters/level_zero/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,13 @@ template <>
ze_structure_type_t getZeStructureType<ze_device_memory_ext_properties_t>() {
return ZE_STRUCTURE_TYPE_DEVICE_MEMORY_EXT_PROPERTIES;
}
#ifdef ZE_DEVICE_USABLEMEM_SIZE_PROPERTIES_EXT_NAME
template <>
ze_structure_type_t
getZeStructureType<ze_device_usablemem_size_ext_properties_t>() {
return ZE_STRUCTURE_TYPE_DEVICE_USABLEMEM_SIZE_EXT_PROPERTIES;
}
#endif
template <>
ze_structure_type_t getZeStructureType<ze_device_ip_version_ext_t>() {
return ZE_STRUCTURE_TYPE_DEVICE_IP_VERSION_EXT;
Expand Down
59 changes: 59 additions & 0 deletions unified-runtime/source/adapters/level_zero/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,46 @@ uint64_t calculateGlobalMemSize(ur_device_handle_t Device) {
return Device->ZeGlobalMemSize.get().value;
}

static bool
supportsDeviceUsableMemSizeExtension(ur_platform_handle_t Platform) {
#ifdef ZE_DEVICE_USABLEMEM_SIZE_PROPERTIES_EXT_NAME
constexpr const char *ExtensionName =
ZE_DEVICE_USABLEMEM_SIZE_PROPERTIES_EXT_NAME;
constexpr uint32_t MinVersion = ZE_MAKE_VERSION(1, 0);

auto Extension = Platform->zeDriverExtensionMap.find(ExtensionName);
return Extension != Platform->zeDriverExtensionMap.end() &&
Extension->second >= MinVersion;
#else
std::ignore = Platform;
return false;
#endif
}

static std::optional<uint64_t>
getDeviceUsableMemSizeFromCore(ur_device_handle_t Device) {
#ifdef ZE_DEVICE_USABLEMEM_SIZE_PROPERTIES_EXT_NAME
if (!supportsDeviceUsableMemSizeExtension(Device->Platform)) {
return std::nullopt;
}

ZeStruct<ze_device_properties_t> DeviceProperties;
ZeStruct<ze_device_usablemem_size_ext_properties_t> UsableMemProperties;
DeviceProperties.pNext = &UsableMemProperties;

auto ZeResult = ZE_CALL_NOCHECK(zeDeviceGetProperties,
(Device->ZeDevice, &DeviceProperties));
if (ZeResult != ZE_RESULT_SUCCESS) {
return std::nullopt;
}

return UsableMemProperties.currUsableMemSize;
#else
std::ignore = Device;
return std::nullopt;
#endif
}

// Return the Sysman device handle and correpsonding data for the given UR
// device.
static std::tuple<zes_device_handle_t, ur_zes_device_handle_data_t, ur_result_t>
Expand Down Expand Up @@ -863,6 +903,21 @@ ur_result_t urDeviceGetInfo(
}

case UR_DEVICE_INFO_GLOBAL_MEM_FREE: {
if (!ParamValue && pSize) {
if (supportsDeviceUsableMemSizeExtension(Device->Platform)) {
return ReturnValue(uint64_t{0});
}

auto [ZesDevice, ZesDeviceData, Result] = getZesDeviceData(Device);
(void)ZesDevice;
(void)ZesDeviceData;
if (Result != UR_RESULT_SUCCESS) {
return Result;
}

return ReturnValue(uint64_t{0});
}

// Calculate the global memory size as the max limit that can be reported as
// "free" memory for the user to allocate.
uint64_t GlobalMemSize = calculateGlobalMemSize(Device);
Expand All @@ -871,6 +926,10 @@ ur_result_t urDeviceGetInfo(
uint64_t FreeMemory = 0;
uint32_t MemCount = 0;

if (auto CoreUsableMemSize = getDeviceUsableMemSizeFromCore(Device)) {
return ReturnValue(std::min(GlobalMemSize, *CoreUsableMemSize));
}

auto [ZesDevice, ZesDeviceData, Result] = getZesDeviceData(Device);
if (Result != UR_RESULT_SUCCESS)
return Result;
Expand Down
Loading