Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 147 additions & 11 deletions test/npu_validation/scripts/generate_testcase.py
Original file line number Diff line number Diff line change
Expand Up @@ -788,6 +788,87 @@ def _resolve_pointer_param(name: str) -> Optional[str]:
return hits


def _detect_prefetch_workspace_pointer_params(text: str, pointer_param_names):
if not pointer_param_names:
return set()

def _is_fully_wrapped_by_parentheses(expr: str) -> bool:
if not (expr.startswith("(") and expr.endswith(")")):
return False
depth = 0
for i, ch in enumerate(expr):
if ch == "(":
depth += 1
elif ch == ")":
depth -= 1
if depth == 0 and i != len(expr) - 1:
return False
return depth == 0

def _extract_identifier(expr: str) -> Optional[str]:
cur = expr.strip()
for _ in range(8):
prev = cur
while _is_fully_wrapped_by_parentheses(cur):
cur = cur[1:-1].strip()

m = re.match(r"^(?:reinterpret_cast|static_cast|const_cast|dynamic_cast)\s*<[^>]+>\s*\((.*)\)$", cur, re.S)
if m:
cur = m.group(1).strip()
continue

m = re.match(r"^\(\s*[^()]+\s*\)\s*(.+)$", cur, re.S)
if m:
cur = m.group(1).strip()
continue

if cur == prev:
break

return cur if re.fullmatch(r"[A-Za-z_]\w*", cur) else None

pointer_set = set(pointer_param_names)
alias = {}
for m in re.finditer(r"\b([A-Za-z_]\w*)\s*=\s*([^;]+);", text):
lhs = m.group(1)
rhs = m.group(2).strip()
src = _extract_identifier(rhs)
if src:
alias[lhs] = src

def _resolve_pointer_param(name: str) -> Optional[str]:
cur = name
seen = set()
for _ in range(12):
if cur in seen:
break
seen.add(cur)
if cur in pointer_set:
return cur
nxt = alias.get(cur)
if not nxt:
return None
cur = nxt
return None

hits = set()
for m in re.finditer(r"\bPrefetchAsyncContext\s+\w+\s*=\s*[^;]*\(([^)]*)\)\s*;", text, re.S):
raw_arg = m.group(1).strip()
arg_name = _extract_identifier(raw_arg)
if not arg_name:
continue
resolved = _resolve_pointer_param(arg_name)
if resolved:
hits.add(resolved)

if not hits:
for name in pointer_param_names:
pat = rf"\bPrefetchAsyncContext\b[^\n;]*\b{re.escape(name)}\b"
if re.search(pat, text):
hits.add(name)
return hits


def _parse_kernel_params(text: str):
match = re.search(r"__global__\s+(?:\w+\s+)*void\s+\w+\s*\(([^)]*)\)", text, re.S)
if not match:
Expand Down Expand Up @@ -1507,6 +1588,12 @@ def generate_testcase(
has_vec_only_section = has_dav_vec and not has_dav_cube

is_mixed_kernel = kernel_info["kind"] == "mixed"
raw_params = kernel_info["raw_params"]
pointer_param_names = [_extract_cpp_name(p) for p in raw_params if _is_gm_pointer_param(p)]
prefetch_workspace_param_names = _detect_prefetch_workspace_pointer_params(
raw_kernel_for_analysis, pointer_param_names
)
uses_prefetch_async_runtime = bool(prefetch_workspace_param_names) and "TPREFETCH_ASYNC(" in raw_kernel_for_analysis

if aicore_arch is None:
if is_mixed_kernel:
Expand Down Expand Up @@ -1557,6 +1644,18 @@ def generate_testcase(
else:
aicore_arch = _infer_aicore_arch(raw_kernel, soc_version)

# TPREFETCH_ASYNC currently uses the A2/A3 SDMA implementation in pto-isa.
# Compiling generated board payloads as dav-c310/REGISTER_BASE on
# Ascend910B3 boards can fault in the UB scratch path. Keep these kernels
# on the A2/A3 compile mode even when the board's SOC string contains 910B.
if uses_prefetch_async_runtime and aicore_arch.startswith("dav-c310"):
if aicore_arch.endswith("-cube"):
aicore_arch = "dav-c220-cube"
elif aicore_arch == "dav-c310":
aicore_arch = "dav-c220"
else:
aicore_arch = "dav-c220-vec"

# For single-section kernels, force-define DAV macro(s) to keep section
# bodies visible to the selected compile arch.
# For mix-kernel arch (dav-c310/dav-c220), do not force-define macros.
Expand All @@ -1571,10 +1670,7 @@ def generate_testcase(
rows, cols = _parse_shape(kernel_info["call_text"])
logical_elem_count = rows * cols
kernel_name = kernel_info["kernel_name"]
raw_params = kernel_info["raw_params"]
mrgsort_block_len = _infer_mrgsort_block_len(raw_kernel_for_analysis) if "TMRGSORT" in raw_kernel_for_analysis else None

pointer_param_names = [_extract_cpp_name(p) for p in raw_params if _is_gm_pointer_param(p)]
inferred_void_ptr_types = {}
for raw in raw_params:
if not _is_gm_pointer_param(raw):
Expand All @@ -1587,17 +1683,21 @@ def generate_testcase(
inferred_void_ptr_types[name] = inferred

ffts_param_names = _detect_set_ffts_pointer_params(raw_kernel_for_analysis, pointer_param_names)
non_ffts_pointer_param_names = [n for n in pointer_param_names if n not in ffts_param_names]
non_runtime_pointer_param_names = [
n
for n in pointer_param_names
if n not in ffts_param_names and n not in prefetch_workspace_param_names
]

output_param_names = []
for writer_text in kernel_info["writer_texts"]:
output_param_names.extend(_detect_output_pointer_params(writer_text, non_ffts_pointer_param_names))
output_param_names.extend(_detect_output_pointer_params(writer_text, non_runtime_pointer_param_names))
output_param_names = _ordered_unique(output_param_names)
if not output_param_names and non_ffts_pointer_param_names:
if not output_param_names and non_runtime_pointer_param_names:
output_param_names = [
non_ffts_pointer_param_names[0]
if len(non_ffts_pointer_param_names) == 1
else non_ffts_pointer_param_names[-1]
non_runtime_pointer_param_names[0]
if len(non_runtime_pointer_param_names) == 1
else non_runtime_pointer_param_names[-1]
]
output_param_name_set = set(output_param_names)

Expand All @@ -1618,7 +1718,11 @@ def generate_testcase(
"role": (
"ffts"
if name in ffts_param_names
else ("output" if name in output_param_name_set else "input")
else (
"prefetch_workspace"
if name in prefetch_workspace_param_names
else ("output" if name in output_param_name_set else "input")
)
),
}
)
Expand All @@ -1639,8 +1743,11 @@ def generate_testcase(
# - Some kernels are in-place (single pointer param) or may read from an
# "output" pointer as scratch. Leaving buffers uninitialized leads to
# non-determinism between CPU golden and real NPU.
data_ptrs = [p for p in params if p["kind"] == "ptr" and p["role"] != "ffts"]
data_ptrs = [p for p in params if p["kind"] == "ptr" and p["role"] not in {"ffts", "prefetch_workspace"}]
ffts_ptrs = [p for p in params if p["kind"] == "ptr" and p["role"] == "ffts"]
prefetch_workspace_ptrs = [
p for p in params if p["kind"] == "ptr" and p["role"] == "prefetch_workspace"
]
init_ptrs = list(data_ptrs)
output_ptrs = [p for p in data_ptrs if p["role"] == "output"]

Expand Down Expand Up @@ -1757,6 +1864,8 @@ def generate_testcase(
param_decls_lines.append(f" {p['host_type']} *{p['name']}Device = nullptr;")
param_decls_lines.append(f" uint64_t {p['name']}FftsAddr = 0;")
param_decls_lines.append(f" uint32_t {p['name']}FftsLen = 0;")
elif p["role"] == "prefetch_workspace":
param_decls_lines.append(f" {p['host_type']} *{p['name']}Device = nullptr;")
else:
param_decls_lines.append(f" {p['host_type']} *{p['name']}Host = nullptr;")
param_decls_lines.append(f" {p['host_type']} *{p['name']}Device = nullptr;")
Expand Down Expand Up @@ -1789,6 +1898,25 @@ def generate_testcase(
init_runtime_ptrs.append(
f" {p['name']}Device = reinterpret_cast<{p['host_type']} *>({p['name']}FftsAddr);"
)
if prefetch_workspace_ptrs:
param_decls_lines.append(" pto::comm::sdma::SdmaWorkspaceManager sdmaMgr;")
init_runtime_ptrs.append(" if (!sdmaMgr.Init()) {")
init_runtime_ptrs.append(' std::fprintf(stderr, "[ERROR] SdmaWorkspaceManager::Init failed\\n");')
init_runtime_ptrs.append(" rc = 1;")
init_runtime_ptrs.append(" goto cleanup;")
init_runtime_ptrs.append(" }")
for p in prefetch_workspace_ptrs:
init_runtime_ptrs.append(
f" {p['name']}Device = reinterpret_cast<{p['host_type']} *>(sdmaMgr.GetWorkspaceAddr());"
)
init_runtime_ptrs.append(f" if ({p['name']}Device == nullptr) {{")
init_runtime_ptrs.append(
f' std::fprintf(stderr, "[ERROR] SDMA workspace address is null for {p["name"]}\\n");'
)
init_runtime_ptrs.append(" rc = 1;")
init_runtime_ptrs.append(" goto cleanup;")
init_runtime_ptrs.append(" }")
free_device.append(" sdmaMgr.Finalize();")

read_inputs = []
copy_inputs = []
Expand Down Expand Up @@ -1824,6 +1952,12 @@ def generate_testcase(
# header here instead of `runtime/rt.h` to avoid environment-specific
# include path issues on some board images.
runtime_rt_include = '#include <stdint.h>\n#include <ccelib/common/runtime.h>'
if prefetch_workspace_ptrs:
runtime_rt_include = (
runtime_rt_include + '\n#include "pto/npu/comm/async/sdma/sdma_workspace_manager.hpp"'
if runtime_rt_include
else '#include "pto/npu/comm/async/sdma/sdma_workspace_manager.hpp"'
)
main_cpp = (
template
.replace("@RUNTIME_RT_INCLUDE@", runtime_rt_include)
Expand Down Expand Up @@ -2107,6 +2241,8 @@ def generate_testcase(
sv = (soc_version or "").lower()
if "910b" in sv or "950" in sv or "a5" in sv:
mem_base_define = "REGISTER_BASE"
if uses_prefetch_async_runtime:
mem_base_define = "MEMORY_BASE"

# CCE printing support is gated behind `--cce-enable-print` on some bisheng
# toolchains. Only enable it when kernels emit printf.
Expand Down
15 changes: 6 additions & 9 deletions test/samples/SyncAll/syncall_binding.pto
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,12 @@ module {
func.func @syncall_binding_kernel(%arg0: !pto.ptr<i32>, %arg1: i32) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c64 = arith.constant 64 : index
%0 = pto.make_tensor_view %arg0, shape = [%c64], strides = [%c1] : !pto.tensor_view<64xi32>
%1 = pto.partition_view %0, offsets = [%c0], sizes = [%c64] : !pto.tensor_view<64xi32>
%2 = pto.alloc_tile : !pto.tile_buf<vec, 1x64xi32>
%3 = pto.alloc_tile : !pto.tile_buf<mat, 1x64xi32>
pto.syncall(%1, %2, %arg1 : !pto.partition_tensor_view<64xi32>, !pto.tile_buf<vec, 1x64xi32>, i32) mode = #pto.sync_all_mode<soft>, core_type = #pto.sync_core_type<aiv_only>
pto.syncall(%1, %2, %3, %arg1 : !pto.partition_tensor_view<64xi32>, !pto.tile_buf<vec, 1x64xi32>, !pto.tile_buf<mat, 1x64xi32>, i32) mode = #pto.sync_all_mode<soft>, core_type = #pto.sync_core_type<mix>
pto.syncall() mode = #pto.sync_all_mode<hard>, core_type = #pto.sync_core_type<mix>
%c384 = arith.constant 384 : index
%c12288_i64 = arith.constant 12288 : i64
%0 = pto.make_tensor_view %arg0, shape = [%c384], strides = [%c1] : !pto.tensor_view<384xi32>
%1 = pto.partition_view %0, offsets = [%c0], sizes = [%c384] : !pto.tensor_view<384xi32>
%2 = pto.alloc_tile addr = %c12288_i64 : !pto.tile_buf<vec, 1x64xi32>
pto.syncall(%1, %2, %arg1 : !pto.partition_tensor_view<384xi32>, !pto.tile_buf<vec, 1x64xi32>, i32) mode = #pto.sync_all_mode<soft>, core_type = #pto.sync_core_type<aiv_only>
return
}
}

28 changes: 9 additions & 19 deletions test/samples/SyncAll/syncall_binding.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,19 @@ def build():
module = Module.create()

i32 = IntegerType.get_signless(32, ctx)
i64 = IntegerType.get_signless(64, ctx)
idx = IndexType.get(ctx)
ptr_i32 = pto.PtrType.get(i32, ctx)
tv_i32 = pto.TensorViewType.get([64], i32, ctx)
pv_i32 = pto.PartitionTensorViewType.get([64], i32, ctx)
workspace_elems = 48 * 8
tv_i32 = pto.TensorViewType.get([workspace_elems], i32, ctx)
pv_i32 = pto.PartitionTensorViewType.get([workspace_elems], i32, ctx)

vec = pto.AddressSpaceAttr.get(pto.AddressSpace.VEC, ctx)
mat = pto.AddressSpaceAttr.get(pto.AddressSpace.MAT, ctx)
bl = pto.BLayoutAttr.get(pto.BLayout.RowMajor, ctx)
sl = pto.SLayoutAttr.get(pto.SLayout.NoneBox, ctx)
pd = pto.PadValueAttr.get(pto.PadValue.Null, ctx)
cfg = pto.TileBufConfigAttr.get(bl, sl, pto.TileConfig.fractalABSize, pd, ctx)
ub_i32 = pto.TileBufType.get([1, 64], i32, vec, [1, 64], cfg, ctx)
l1_i32 = pto.TileBufType.get([1, 64], i32, mat, [1, 64], cfg, ctx)

fn_ty = func.FunctionType.get([ptr_i32, i32], [])
with InsertionPoint(module.body):
Expand All @@ -49,31 +49,21 @@ def build():
gm_workspace_ptr, used_cores = entry.arguments
c0 = arith.ConstantOp(idx, 0).result
c1 = arith.ConstantOp(idx, 1).result
c64 = arith.ConstantOp(idx, 64).result
c384 = arith.ConstantOp(idx, workspace_elems).result
c0x3000 = arith.ConstantOp(i64, 0x3000).result

gm_view = pto.MakeTensorViewOp(tv_i32, gm_workspace_ptr, [c64], [c1]).result
gm_view = pto.MakeTensorViewOp(tv_i32, gm_workspace_ptr, [c384], [c1]).result
gm_workspace = pto.PartitionViewOp(
pv_i32, gm_view, offsets=[c0], sizes=[c64]
pv_i32, gm_view, offsets=[c0], sizes=[c384]
).result
ub_workspace = pto.AllocTileOp(ub_i32).result
l1_workspace = pto.AllocTileOp(l1_i32).result

ub_workspace = pto.AllocTileOp(ub_i32, addr=c0x3000).result
pto.syncall(
_mode("soft"),
_core_type("aiv_only"),
gm_workspace=gm_workspace,
ub_workspace=ub_workspace,
used_cores=used_cores,
)
pto.syncall(
_mode("soft"),
_core_type("mix"),
gm_workspace=gm_workspace,
ub_workspace=ub_workspace,
l1_workspace=l1_workspace,
used_cores=used_cores,
)
pto.syncall(_mode("hard"), _core_type("mix"))
func.ReturnOp([])

module.operation.verify()
Expand Down
18 changes: 11 additions & 7 deletions test/samples/TPrefetchAsync/tprefetch_async_binding.pto
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
module {
func.func @tprefetch_async_binding_kernel(%arg0: !pto.ptr<f32>, %arg1: !pto.ptr<i8>) {
func.func @tprefetch_async_binding_kernel(%arg0: !pto.ptr<f32>, %arg1: !pto.ptr<f32>, %arg2: !pto.ptr<i8>) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c128 = arith.constant 128 : index
%0 = pto.make_tensor_view %arg0, shape = [%c128], strides = [%c1] : !pto.tensor_view<128xf32>
%1 = pto.partition_view %0, offsets = [%c0], sizes = [%c128] : !pto.tensor_view<128xf32>
%2 = pto.make_prefetch_async_context(%arg1 : !pto.ptr<i8>) -> !pto.prefetch_async_context
%3 = pto.tprefetch_async(%1, %2 : !pto.partition_tensor_view<128xf32>, !pto.prefetch_async_context) -> !pto.async_event
%4 = pto.get_prefetch_async_session %2 : !pto.prefetch_async_context -> !pto.async_session
%5 = pto.comm.wait_async_event(%3, %4 : !pto.async_event, !pto.async_session) -> i1
%1 = pto.make_tensor_view %arg1, shape = [%c128], strides = [%c1] : !pto.tensor_view<128xf32>
%2 = pto.partition_view %0, offsets = [%c0], sizes = [%c128] : !pto.tensor_view<128xf32>
%3 = pto.partition_view %1, offsets = [%c0], sizes = [%c128] : !pto.tensor_view<128xf32>
%4 = pto.make_prefetch_async_context(%arg2 : !pto.ptr<i8>) -> !pto.prefetch_async_context
%5 = pto.tprefetch_async(%2, %4 : !pto.partition_tensor_view<128xf32>, !pto.prefetch_async_context) -> !pto.async_event
%6 = pto.get_prefetch_async_session %4 : !pto.prefetch_async_context -> !pto.async_session
%7 = pto.comm.wait_async_event(%5, %6 : !pto.async_event, !pto.async_session) -> i1
%8 = pto.alloc_tile : !pto.tile_buf<vec, 1x128xf32>
pto.tload ins(%2 : !pto.partition_tensor_view<128xf32>) outs(%8 : !pto.tile_buf<vec, 1x128xf32>)
pto.tstore ins(%8 : !pto.tile_buf<vec, 1x128xf32>) outs(%3 : !pto.partition_tensor_view<128xf32>)
return
}
}

15 changes: 13 additions & 2 deletions test/samples/TPrefetchAsync/tprefetch_async_binding.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,25 +24,36 @@ def build():
ptr_i8 = pto.PtrType.get(i8, ctx)
tv_f32 = pto.TensorViewType.get([128], f32, ctx)
pv_f32 = pto.PartitionTensorViewType.get([128], f32, ctx)
vec = pto.AddressSpaceAttr.get(pto.AddressSpace.VEC, ctx)
bl = pto.BLayoutAttr.get(pto.BLayout.RowMajor, ctx)
sl = pto.SLayoutAttr.get(pto.SLayout.NoneBox, ctx)
pd = pto.PadValueAttr.get(pto.PadValue.Null, ctx)
cfg = pto.TileBufConfigAttr.get(bl, sl, pto.TileConfig.fractalABSize, pd, ctx)
tile_f32 = pto.TileBufType.get([1, 128], f32, vec, [1, 128], cfg, ctx)

fn_ty = func.FunctionType.get([ptr_f32, ptr_i8], [])
fn_ty = func.FunctionType.get([ptr_f32, ptr_f32, ptr_i8], [])
with InsertionPoint(module.body):
fn = func.FuncOp("tprefetch_async_binding_kernel", fn_ty)
entry = fn.add_entry_block()

with InsertionPoint(entry):
src_ptr, workspace_ptr = entry.arguments
src_ptr, dst_ptr, workspace_ptr = entry.arguments
c0 = arith.ConstantOp(idx, 0).result
c1 = arith.ConstantOp(idx, 1).result
c128 = arith.ConstantOp(idx, 128).result

src_view = pto.MakeTensorViewOp(tv_f32, src_ptr, [c128], [c1]).result
dst_view = pto.MakeTensorViewOp(tv_f32, dst_ptr, [c128], [c1]).result
src = pto.PartitionViewOp(pv_f32, src_view, offsets=[c0], sizes=[c128]).result
dst = pto.PartitionViewOp(pv_f32, dst_view, offsets=[c0], sizes=[c128]).result

prefetch_ctx = pto.MakePrefetchAsyncContextOp(workspace_ptr).result
event = pto.TPrefetchAsyncOp(src, prefetch_ctx).result
session = pto.GetPrefetchAsyncSessionOp(prefetch_ctx).result
pto.WaitAsyncEventOp(event, session)
tile = pto.AllocTileOp(tile_f32).result
pto.TLoadOp(None, src, tile)
pto.TStoreOp(None, tile, dst)
func.ReturnOp([])

module.operation.verify()
Expand Down
Loading
Loading