-
Notifications
You must be signed in to change notification settings - Fork 59
fix(ptobc): support dense executed constants in v0 #717
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -20,7 +20,9 @@ | |||||||||||||||||
| #include <mlir/Dialect/SCF/IR/SCF.h> | ||||||||||||||||||
|
|
||||||||||||||||||
| #include <PTO/IR/PTO.h> | ||||||||||||||||||
| #include <mlir/IR/BuiltinAttributes.h> | ||||||||||||||||||
| #include <mlir/IR/BuiltinOps.h> | ||||||||||||||||||
| #include <mlir/IR/BuiltinTypes.h> | ||||||||||||||||||
| #include <mlir/IR/Location.h> | ||||||||||||||||||
| #include <mlir/IR/Operation.h> | ||||||||||||||||||
| #include <mlir/IR/OpImplementation.h> | ||||||||||||||||||
|
|
@@ -127,6 +129,8 @@ struct ConstEntryParsed { | |||||||||||||||||
| std::vector<uint8_t> floatBytes; | ||||||||||||||||||
| // tag=0x04 wide int bits: type_id + bytes | ||||||||||||||||||
| std::vector<uint8_t> intBytes; | ||||||||||||||||||
| // tag=0x05 dense elements bits: type_id + packed element bytes | ||||||||||||||||||
| std::vector<uint8_t> denseBytes; | ||||||||||||||||||
| }; | ||||||||||||||||||
|
|
||||||||||||||||||
| struct DbgFileEntry { uint64_t pathSid; uint8_t hashKind; std::vector<uint8_t> hashBytes; }; | ||||||||||||||||||
|
|
@@ -236,7 +240,34 @@ static void parseAttrsSection(const std::vector<uint8_t>& data, | |||||||||||||||||
| if (r.p != r.end) throw std::runtime_error("trailing bytes in ATTRS"); | ||||||||||||||||||
| } | ||||||||||||||||||
|
|
||||||||||||||||||
| static std::vector<uint8_t> | ||||||||||||||||||
| readDenseElementBytes(Reader &r, mlir::MLIRContext &ctx, | ||||||||||||||||||
| const std::vector<TypeEntry> &types, uint64_t tid) { | ||||||||||||||||||
| if (tid >= types.size()) | ||||||||||||||||||
| throw std::runtime_error("bad type_id in dense const"); | ||||||||||||||||||
|
|
||||||||||||||||||
| mlir::Type type = parseType(ctx, types[tid].asmStr); | ||||||||||||||||||
| auto shapedType = mlir::dyn_cast<mlir::ShapedType>(type); | ||||||||||||||||||
| if (!shapedType || !shapedType.hasStaticShape()) | ||||||||||||||||||
| throw std::runtime_error("dense const requires static shaped type"); | ||||||||||||||||||
|
|
||||||||||||||||||
| mlir::Type elementType = shapedType.getElementType(); | ||||||||||||||||||
| unsigned bitWidth = 0; | ||||||||||||||||||
| if (auto intType = mlir::dyn_cast<mlir::IntegerType>(elementType)) | ||||||||||||||||||
| bitWidth = intType.getWidth(); | ||||||||||||||||||
| else if (auto floatType = mlir::dyn_cast<mlir::FloatType>(elementType)) | ||||||||||||||||||
| bitWidth = floatType.getWidth(); | ||||||||||||||||||
| else | ||||||||||||||||||
| throw std::runtime_error("dense const requires integer or float element type"); | ||||||||||||||||||
|
|
||||||||||||||||||
| uint64_t numElements = shapedType.getNumElements(); | ||||||||||||||||||
| uint64_t byteLen = (bitWidth + 7) / 8; | ||||||||||||||||||
| return r.readBytes(size_t(numElements * byteLen)); | ||||||||||||||||||
| } | ||||||||||||||||||
|
|
||||||||||||||||||
| static void parseConstPoolSection(const std::vector<uint8_t>& data, | ||||||||||||||||||
| mlir::MLIRContext &ctx, | ||||||||||||||||||
| const std::vector<TypeEntry> &types, | ||||||||||||||||||
| std::vector<ConstEntryParsed>& consts) { | ||||||||||||||||||
| Reader r{data.data(), data.data() + data.size()}; | ||||||||||||||||||
| uint64_t cnt = r.readULEB(); | ||||||||||||||||||
|
|
@@ -279,6 +310,14 @@ static void parseConstPoolSection(const std::vector<uint8_t>& data, | |||||||||||||||||
| e.typeId = tid; | ||||||||||||||||||
| e.intBytes = std::move(bytes); | ||||||||||||||||||
| consts.push_back(std::move(e)); | ||||||||||||||||||
| } else if (tag == 0x05) { | ||||||||||||||||||
| uint64_t tid = r.readULEB(); | ||||||||||||||||||
| auto bytes = readDenseElementBytes(r, ctx, types, tid); | ||||||||||||||||||
| ConstEntryParsed e; | ||||||||||||||||||
| e.tag = tag; | ||||||||||||||||||
| e.typeId = tid; | ||||||||||||||||||
| e.denseBytes = std::move(bytes); | ||||||||||||||||||
| consts.push_back(std::move(e)); | ||||||||||||||||||
| } else { | ||||||||||||||||||
| throw std::runtime_error("unknown ConstEntry tag"); | ||||||||||||||||||
| } | ||||||||||||||||||
|
|
@@ -360,6 +399,56 @@ static mlir::Attribute buildIntegerConstAttr(BuildCtx &bc, | |||||||||||||||||
| return mlir::IntegerAttr::get(intType, bits); | ||||||||||||||||||
| } | ||||||||||||||||||
|
|
||||||||||||||||||
| static mlir::Attribute buildDenseConstAttr(BuildCtx &bc, | ||||||||||||||||||
| const ConstEntryParsed &entry) { | ||||||||||||||||||
| auto type = getType(bc, entry.typeId); | ||||||||||||||||||
| auto shapedType = mlir::dyn_cast<mlir::ShapedType>(type); | ||||||||||||||||||
| if (!shapedType || !shapedType.hasStaticShape()) | ||||||||||||||||||
| throw std::runtime_error("ConstDenseBits type is not static shaped type"); | ||||||||||||||||||
|
|
||||||||||||||||||
| mlir::Type elementType = shapedType.getElementType(); | ||||||||||||||||||
| uint64_t numElements = shapedType.getNumElements(); | ||||||||||||||||||
| if (auto intType = mlir::dyn_cast<mlir::IntegerType>(elementType)) { | ||||||||||||||||||
| unsigned bitWidth = intType.getWidth(); | ||||||||||||||||||
| unsigned byteLen = (bitWidth + 7) / 8; | ||||||||||||||||||
| if (entry.denseBytes.size() != size_t(numElements) * byteLen) | ||||||||||||||||||
| throw std::runtime_error("ConstDenseBits integer byte_len mismatch"); | ||||||||||||||||||
|
Comment on lines
+413
to
+415
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Security Vulnerability: Integer Overflow leading to Out-of-Bounds Read
Similarly to the decoding phase, we must guard against integer overflow when validating the byte length of integer dense constants during attribute reconstruction. If `size_t(numElements) * byteLen` overflows, the size comparison can pass even though the buffer is too small. Add an overflow check before validating the buffer size.
Suggested change
|
||||||||||||||||||
|
|
||||||||||||||||||
| llvm::SmallVector<llvm::APInt, 8> values; | ||||||||||||||||||
| values.reserve(numElements); | ||||||||||||||||||
| for (uint64_t i = 0; i < numElements; ++i) { | ||||||||||||||||||
| size_t offset = size_t(i) * byteLen; | ||||||||||||||||||
| values.push_back(rebuildAPIntFromBytes( | ||||||||||||||||||
| llvm::ArrayRef<uint8_t>(entry.denseBytes.data() + offset, byteLen), | ||||||||||||||||||
| bitWidth)); | ||||||||||||||||||
| } | ||||||||||||||||||
| return mlir::DenseElementsAttr::get(shapedType, | ||||||||||||||||||
| llvm::ArrayRef<llvm::APInt>(values)); | ||||||||||||||||||
| } | ||||||||||||||||||
|
|
||||||||||||||||||
| if (auto floatType = mlir::dyn_cast<mlir::FloatType>(elementType)) { | ||||||||||||||||||
| unsigned bitWidth = floatType.getWidth(); | ||||||||||||||||||
| unsigned byteLen = (bitWidth + 7) / 8; | ||||||||||||||||||
| if (entry.denseBytes.size() != size_t(numElements) * byteLen) | ||||||||||||||||||
| throw std::runtime_error("ConstDenseBits float byte_len mismatch"); | ||||||||||||||||||
|
Comment on lines
+431
to
+433
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Security Vulnerability: Integer Overflow leading to Out-of-Bounds Read
Similarly to the integer path, we must guard against integer overflow when validating the byte length of float dense constants during attribute reconstruction. Add an overflow check before validating the buffer size.
Suggested change
|
||||||||||||||||||
|
|
||||||||||||||||||
| llvm::SmallVector<llvm::APFloat, 8> values; | ||||||||||||||||||
| values.reserve(numElements); | ||||||||||||||||||
| for (uint64_t i = 0; i < numElements; ++i) { | ||||||||||||||||||
| size_t offset = size_t(i) * byteLen; | ||||||||||||||||||
| llvm::APInt bits = rebuildAPIntFromBytes( | ||||||||||||||||||
| llvm::ArrayRef<uint8_t>(entry.denseBytes.data() + offset, byteLen), | ||||||||||||||||||
| bitWidth); | ||||||||||||||||||
| values.emplace_back(floatType.getFloatSemantics(), bits); | ||||||||||||||||||
| } | ||||||||||||||||||
| return mlir::DenseElementsAttr::get(shapedType, | ||||||||||||||||||
| llvm::ArrayRef<llvm::APFloat>(values)); | ||||||||||||||||||
| } | ||||||||||||||||||
|
|
||||||||||||||||||
| throw std::runtime_error( | ||||||||||||||||||
| "ConstDenseBits element type is not integer or float"); | ||||||||||||||||||
| } | ||||||||||||||||||
|
|
||||||||||||||||||
| static mlir::Attribute buildConstAttr(BuildCtx &bc, uint64_t constId) { | ||||||||||||||||||
| if (!bc.consts) throw std::runtime_error("constpool not available"); | ||||||||||||||||||
| if (constId >= bc.consts->size()) throw std::runtime_error("const_id out of range"); | ||||||||||||||||||
|
|
@@ -381,6 +470,9 @@ static mlir::Attribute buildConstAttr(BuildCtx &bc, uint64_t constId) { | |||||||||||||||||
| if (e.tag == 0x04) | ||||||||||||||||||
| return buildIntegerConstAttr(bc, e); | ||||||||||||||||||
|
|
||||||||||||||||||
| if (e.tag == 0x05) | ||||||||||||||||||
| return buildDenseConstAttr(bc, e); | ||||||||||||||||||
|
|
||||||||||||||||||
| throw std::runtime_error("unsupported const tag"); | ||||||||||||||||||
| } | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
@@ -775,7 +867,7 @@ static mlir::ModuleOp decodeToModule(mlir::MLIRContext& ctx, | |||||||||||||||||
|
|
||||||||||||||||||
| Reader r{moduleBytes.data(), moduleBytes.data() + moduleBytes.size()}; | ||||||||||||||||||
| std::vector<ConstEntryParsed> consts; | ||||||||||||||||||
| parseConstPoolSection(constPool, consts); | ||||||||||||||||||
| parseConstPoolSection(constPool, ctx, types, consts); | ||||||||||||||||||
| BuildCtx bc{&ctx, &strings, &types, &attrs, &consts, {}, nullptr, nullptr}; | ||||||||||||||||||
| uint64_t moduleAttrId = readModuleHeader(r, dbg); | ||||||||||||||||||
| std::vector<FuncDecl> decls = readFunctionDecls(bc, r, dbg); | ||||||||||||||||||
|
|
||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,24 @@ | ||
| // Copyright (c) 2026 Huawei Technologies Co., Ltd. | ||
| // This program is free software, you can redistribute it and/or modify it under the terms and conditions of | ||
| // CANN Open Software License Agreement Version 2.0 (the "License"). | ||
| // Please refer to the License for details. You may not use this file except in compliance with the License. | ||
| // THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, | ||
| // INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. | ||
| // See LICENSE in the root of the software repository for the full text of the License. | ||
|
|
||
| module { | ||
| func.func @mrgsort_dense_const_v0() { | ||
| %src0 = pto.alloc_tile : !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=128, v_row=1, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0> | ||
| %src1 = pto.alloc_tile : !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=128, v_row=1, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0> | ||
| %tmp = pto.alloc_tile : !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=256, v_row=1, v_col=256, blayout=row_major, slayout=none_box, fractal=512, pad=0> | ||
| %dst = pto.alloc_tile : !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=256, v_row=1, v_col=256, blayout=row_major, slayout=none_box, fractal=512, pad=0> | ||
| %executed = arith.constant dense<0> : vector<4xi16> | ||
| pto.tmrgsort ins(%src0, %src1, %tmp {exhausted = false} : !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=128, v_row=1, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>, !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=128, v_row=1, v_col=128, blayout=row_major, slayout=none_box, fractal=512, pad=0>, !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=256, v_row=1, v_col=256, blayout=row_major, slayout=none_box, fractal=512, pad=0>) outs(%dst, %executed : !pto.tile_buf<loc=vec, dtype=f32, rows=1, cols=256, v_row=1, v_col=256, blayout=row_major, slayout=none_box, fractal=512, pad=0>, vector<4xi16>) | ||
|
|
||
| %reduce_src = pto.alloc_tile : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=32, v_row=16, v_col=32, blayout=row_major, slayout=none_box, fractal=512, pad=0> | ||
| %reduce_tmp = pto.alloc_tile : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=32, v_row=16, v_col=32, blayout=row_major, slayout=none_box, fractal=512, pad=0> | ||
| %reduce_dst = pto.alloc_tile : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=1, v_row=16, v_col=1, blayout=col_major, slayout=none_box, fractal=512, pad=0> | ||
| pto.trowprod ins(%reduce_src, %reduce_tmp : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=32, v_row=16, v_col=32, blayout=row_major, slayout=none_box, fractal=512, pad=0>, !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=32, v_row=16, v_col=32, blayout=row_major, slayout=none_box, fractal=512, pad=0>) outs(%reduce_dst : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=1, v_row=16, v_col=1, blayout=col_major, slayout=none_box, fractal=512, pad=0>) | ||
| return | ||
| } | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,36 @@ | ||
| #!/usr/bin/env bash | ||
| # Copyright (c) 2026 Huawei Technologies Co., Ltd. | ||
| # This program is free software, you can redistribute it and/or modify it under the terms and conditions of | ||
| # CANN Open Software License Agreement Version 2.0 (the "License"). | ||
| # Please refer to the License for details. You may not use this file except in compliance with the License. | ||
| # THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, | ||
| # INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. | ||
| # See LICENSE in the root of the software repository for the full text of the License. | ||
|
|
||
| set -euo pipefail | ||
|
|
||
| PTOBC_BIN=${PTOBC_BIN:-} | ||
| if [[ -z "${PTOBC_BIN}" ]]; then | ||
| echo "error: PTOBC_BIN not set" >&2 | ||
| exit 2 | ||
| fi | ||
|
|
||
| TESTDATA_DIR=${TESTDATA_DIR:-} | ||
| if [[ -z "${TESTDATA_DIR}" ]]; then | ||
| echo "error: TESTDATA_DIR not set" >&2 | ||
| exit 2 | ||
| fi | ||
|
|
||
| IN="${TESTDATA_DIR}/mrgsort_dense_const_v0_roundtrip.pto" | ||
| OUT_DIR=${OUT_DIR:-"${PWD}/ptobc_mrgsort_dense_const_out"} | ||
| mkdir -p "${OUT_DIR}" | ||
|
|
||
| BC="${OUT_DIR}/mrgsort_dense_const_v0_roundtrip.ptobc" | ||
| ROUNDTRIP="${OUT_DIR}/mrgsort_dense_const_v0_roundtrip.roundtrip.pto" | ||
|
|
||
| "${PTOBC_BIN}" encode "${IN}" -o "${BC}" | ||
| "${PTOBC_BIN}" decode "${BC}" -o "${ROUNDTRIP}" | ||
|
|
||
| grep -F "pto.tmrgsort ins(" "${ROUNDTRIP}" >/dev/null | ||
| grep -E "arith.constant dense<(\[0, 0, 0, 0\]|0)> : vector<4xi16>" "${ROUNDTRIP}" >/dev/null | ||
| grep -F "pto.trowprod ins(" "${ROUNDTRIP}" >/dev/null |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Security Vulnerability: Integer Overflow leading to Out-of-Bounds Read
When decoding a dense constant, `numElements` and `byteLen` are multiplied to determine the total number of bytes to read. If a maliciously crafted `.ptobc` file specifies an extremely large shape, this multiplication can overflow `size_t`. This causes `r.readBytes` to allocate and read a much smaller buffer than expected, which subsequently leads to out-of-bounds heap reads and potential crashes during attribute reconstruction.
To prevent this, we must check for integer overflow before performing the multiplication.