Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/PTO/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ createPlanMemoryPass(const PlanMemoryOptions &planMemoryOption = {});
std::unique_ptr<Pass> createPTORemoveRedundantBarrierPass();
std::unique_ptr<Pass> createPTOViewToMemrefPass();
std::unique_ptr<Pass> createInferPTOLayoutPass();
std::unique_ptr<Pass> createPTOA5NormalizeTMovPass();
// Declare register function
void registerPTOPasses();

Expand Down
13 changes: 13 additions & 0 deletions include/PTO/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,19 @@ def InferPTOLayout : Pass<"pto-infer-layout", "func::FuncOp"> {
let dependentDialects = ["pto::PTODialect", "arith::ArithDialect"];
}

def PTOA5NormalizeTMov : Pass<"pto-a5-normalize-tmov", "func::FuncOp"> {
let summary = "Normalize risky A5 vec->vec col_major TMOV into row-major reinterpret + TMOV";
let description = [{
Rewrites A5 `pto.tmov` patterns that use vec->vec col_major/none_box tiles:
src(col_major) -> pto.treshape(row_major) -> pto.tmov(row_major->row_major)
dst(col_major) -> pto.treshape(row_major)
This avoids unsupported A5 PTO-ISA vec->vec col_major TMOV paths while
preserving alias semantics via SSA treshape (no real data movement).
}];
let constructor = "mlir::pto::createPTOA5NormalizeTMovPass()";
let dependentDialects = ["pto::PTODialect", "func::FuncDialect"];
}


def InferPTOMemScope : Pass<"pto-infer-mem-scope"> {
let summary = "Infer memory scope for PTO Ops";
Expand Down
1 change: 1 addition & 0 deletions lib/PTO/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ add_mlir_dialect_library(PTOTransforms
PTOPlanMemory.cpp
PTORemoveRedundantBarrier.cpp
InferPTOLayout.cpp
PTOA5NormalizeTMovPass.cpp
BufferizableOpInterfaceImpl.cpp
ConvertToPTOOp.cpp
PTOLowerFrontendPipeOpsPass.cpp
Expand Down
176 changes: 176 additions & 0 deletions lib/PTO/Transforms/PTOA5NormalizeTMovPass.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
// Copyright (c) 2026 Huawei Technologies Co., Ltd.
// This program is free software, you can redistribute it and/or modify it under the terms and conditions of
// CANN Open Software License Agreement Version 2.0 (the "License").
// Please refer to the License for details. You may not use this file except in compliance with the License.
// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
// See LICENSE in the root of the software repository for the full text of the License.

#include "PTO/IR/PTO.h"
#include "PTO/Transforms/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"

namespace mlir {
namespace pto {
#define GEN_PASS_DEF_PTOA5NORMALIZETMOV
#include "PTO/Transforms/Passes.h.inc"
} // namespace pto
} // namespace mlir

using namespace mlir;
using namespace mlir::pto;

namespace {

static bool isVecTileType(pto::TileBufType type) {
auto asAttr = dyn_cast_or_null<pto::AddressSpaceAttr>(type.getMemorySpace());
return asAttr && asAttr.getAddressSpace() == pto::AddressSpace::VEC;
}

static bool isColMajorNoneBox(pto::TileBufType type) {
return type.getBLayoutValueI32() == static_cast<int32_t>(pto::BLayout::ColMajor) &&
type.getSLayoutValueI32() == static_cast<int32_t>(pto::SLayout::NoneBox);
}

static bool isA5RiskyVecVecColMajorTMov(pto::TMovOp op) {
auto srcTb = dyn_cast<pto::TileBufType>(op.getSrc().getType());
auto dstTb = dyn_cast<pto::TileBufType>(op.getDst().getType());
if (!srcTb || !dstTb)
return false;
if (!isVecTileType(srcTb) || !isVecTileType(dstTb))
return false;
return isColMajorNoneBox(srcTb) && isColMajorNoneBox(dstTb);
}

template <typename CfgT>
static auto buildRowMajorConfigImpl(int, MLIRContext *ctx,
pto::BLayoutAttr rowMajor, CfgT cfg)
-> decltype(pto::TileBufConfigAttr::get(ctx, rowMajor, cfg.getSLayout(),
cfg.getSFractalSize(), cfg.getPad(),
cfg.getCompactMode())) {
return pto::TileBufConfigAttr::get(ctx, rowMajor, cfg.getSLayout(),
cfg.getSFractalSize(), cfg.getPad(),
cfg.getCompactMode());
}

template <typename CfgT>
static auto buildRowMajorConfigImpl(long, MLIRContext *ctx,
pto::BLayoutAttr rowMajor, CfgT cfg)
-> decltype(pto::TileBufConfigAttr::get(ctx, rowMajor, cfg.getSLayout(),
cfg.getSFractalSize(),
cfg.getPad())) {
return pto::TileBufConfigAttr::get(ctx, rowMajor, cfg.getSLayout(),
cfg.getSFractalSize(), cfg.getPad());
}

static pto::TileBufConfigAttr buildRowMajorConfig(MLIRContext *ctx,
pto::TileBufConfigAttr cfg) {
auto rowMajor = pto::BLayoutAttr::get(ctx, pto::BLayout::RowMajor);
return buildRowMajorConfigImpl(0, ctx, rowMajor, cfg);
}

static FailureOr<pto::TileBufType>
buildRowMajorReinterpretType(MLIRContext *ctx, pto::TileBufType srcType) {
ArrayRef<int64_t> shape = srcType.getShape();
if (shape.size() != 2)
return failure();
if (shape[0] == ShapedType::kDynamic || shape[1] == ShapedType::kDynamic)
return failure();

SmallVector<int64_t, 2> swappedShape{shape[1], shape[0]};

SmallVector<int64_t, 2> swappedValid;
ArrayRef<int64_t> validShape = srcType.getValidShape();
if (validShape.empty()) {
swappedValid = swappedShape;
} else if (validShape.size() == 2) {
swappedValid.assign({validShape[1], validShape[0]});
} else {
return failure();
}

auto cfg = srcType.getConfigAttr();
if (!cfg)
cfg = pto::TileBufConfigAttr::getDefault(ctx);
auto newCfg = buildRowMajorConfig(ctx, cfg);

return pto::TileBufType::get(ctx, swappedShape, srcType.getElementType(),
srcType.getMemorySpace(), swappedValid, newCfg);
}

struct PTOA5NormalizeTMovPass
: public mlir::pto::impl::PTOA5NormalizeTMovBase<PTOA5NormalizeTMovPass> {
void runOnOperation() override {
func::FuncOp func = getOperation();
if (!isTargetArchA5(func.getOperation()))
return;

SmallVector<pto::TMovOp, 8> riskyOps;
func.walk([&](pto::TMovOp op) {
if (isA5RiskyVecVecColMajorTMov(op))
riskyOps.push_back(op);
});

IRRewriter rewriter(func.getContext());
for (pto::TMovOp op : riskyOps) {
auto srcTb = cast<pto::TileBufType>(op.getSrc().getType());
auto dstTb = cast<pto::TileBufType>(op.getDst().getType());

FailureOr<pto::TileBufType> srcRowTy =
buildRowMajorReinterpretType(func.getContext(), srcTb);
FailureOr<pto::TileBufType> dstRowTy =
buildRowMajorReinterpretType(func.getContext(), dstTb);
if (failed(srcRowTy) || failed(dstRowTy)) {
op.emitOpError(
"cannot normalize A5 vec->vec col_major TMOV: requires static 2D "
"tile_buf shape/valid_shape for treshape reinterpret");
signalPassFailure();
return;
}

rewriter.setInsertionPoint(op);
auto srcRow =
rewriter.create<pto::TReshapeOp>(op.getLoc(), *srcRowTy, op.getSrc());
auto dstRow =
rewriter.create<pto::TReshapeOp>(op.getLoc(), *dstRowTy, op.getDst());
SmallVector<Value, 4> newOperands(op->operand_begin(), op->operand_end());
if (newOperands.size() < 2) {
op.emitOpError("unexpected operand count while normalizing TMOV");
signalPassFailure();
return;
}
newOperands[0] = srcRow.getResult();
newOperands[1] = dstRow.getResult();

OperationState state(op.getLoc(), pto::TMovOp::getOperationName());
state.addOperands(newOperands);
state.addTypes(op->getResultTypes());
state.addAttributes(op->getAttrs());
auto *created = rewriter.create(state);
auto newTmov = cast<pto::TMovOp>(created);
(void)newTmov;
rewriter.eraseOp(op);
}

bool hasResidualRisk = false;
func.walk([&](pto::TMovOp op) {
if (!isA5RiskyVecVecColMajorTMov(op))
return WalkResult::advance();
op.emitOpError(
"A5 vec->vec TMOV on col_major/none_box tile is unsupported; "
"expected normalization to row_major via pto.treshape");
hasResidualRisk = true;
return WalkResult::interrupt();
});
if (hasResidualRisk)
signalPassFailure();
}
};

} // namespace

std::unique_ptr<Pass> mlir::pto::createPTOA5NormalizeTMovPass() {
return std::make_unique<PTOA5NormalizeTMovPass>();
}
37 changes: 36 additions & 1 deletion test/npu_validation/scripts/generate_testcase.py
Original file line number Diff line number Diff line change
Expand Up @@ -1111,6 +1111,23 @@ def generate_testcase(
for p in data_ptrs:
inferred = inferred_counts.get(p["name"])
ptr_elem_counts[p["name"]] = int(inferred) if inferred and int(inferred) > 0 else logical_elem_count
if testcase in {"rmsnorm_incore_0", "decode_projection_incore_0"}:
# These repro kernels partition a [16, hidden] ND view with a row
# offset. Board validation runs a single-block case, so keep bf16
# input/output buffers large enough for the full 16xhidden window.
required_elems = 16 * (5120 if testcase == "rmsnorm_incore_0" else 8192)
for p in data_ptrs:
if p["host_type"] != "uint16_t":
continue
cur = int(ptr_elem_counts.get(p["name"], logical_elem_count))
ptr_elem_counts[p["name"]] = max(cur, required_elems)
if testcase == "decode_projection_incore_0":
# decode_projection_incore_0 also reads gamma as f32[1, 8192].
for p in data_ptrs:
if p["host_type"] != "float":
continue
cur = int(ptr_elem_counts.get(p["name"], logical_elem_count))
ptr_elem_counts[p["name"]] = max(cur, 8192)

templates_root = Path(__file__).resolve().parents[1] / "templates"
template = (templates_root / "main_template.cpp").read_text(encoding="utf-8")
Expand Down Expand Up @@ -1141,6 +1158,24 @@ def generate_testcase(
if p["kind"] != "scalar":
continue
t = p["host_type"]
if testcase in {"rmsnorm_incore_0", "decode_projection_incore_0"} and t in {
"int8_t",
"uint8_t",
"int16_t",
"uint16_t",
"int32_t",
"uint32_t",
"int64_t",
"uint64_t",
"int",
"unsigned",
"size_t",
}:
# These kernels use this scalar as row offset (%arg3).
# Keep it at 0 for single-block validation to avoid shifted windows.
value = "0"
param_decls_lines.append(f" {t} {p['name']} = {value};")
continue
# Some PTO-ISA APIs use small POD structs as scalar parameters.
# Example: pto::MrgSortExecutedNumList (used by TMRGSORT multi-list variants).
if t.endswith("MrgSortExecutedNumList"):
Expand Down Expand Up @@ -1455,7 +1490,7 @@ def generate_testcase(
mem_base_define = "REGISTER_BASE"

# CCE printing support is gated behind `--cce-enable-print` on some bisheng
# toolchains. Only enable it for kernels that actually emit printf.
# toolchains. Only enable it when kernels emit printf.
needs_cce_print = bool(re.search(r"\b(?:bisheng::)?cce::printf\s*\(", raw_kernel_for_analysis))
cce_enable_print_opt = " --cce-enable-print" if needs_cce_print else ""
cce_print_define_opt = " -DPTOAS_ENABLE_CCE_PRINT=1" if needs_cce_print else ""
Expand Down
Loading
Loading