Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions examples/a2a3/host_build_graph/paged_attention/golden.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.
# -----------------------------------------------------------------------------------------------------------
"""Paged Attention Golden - host_build_graph example (small scale, float16).
"""Paged Attention Golden - host_build_graph example (small scale, bfloat16).

Args layout: [query, key_cache, value_cache, block_table, context_lens, out, scale]
- Tensors retain original multi-dimensional shapes (ContinuousTensor metadata carries shape/dtype)
Expand All @@ -33,7 +33,7 @@
"block_size": 16,
"context_len": 16,
"max_model_len": 256,
"dtype": "float16",
"dtype": "bfloat16",
},
"Case2": {
"batch": 1,
Expand All @@ -43,7 +43,7 @@
"block_size": 16,
"context_len": 64,
"max_model_len": 256,
"dtype": "float16",
"dtype": "bfloat16",
},
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
//
// Fixed tile size: (16, 16) @ (16, 16) -> (16, 16)
//
// pij is float16 (converted from fp32 in softmax_prepare via TCVT).
// pij is bfloat16 (converted from fp32 in softmax_prepare via TCVT).
// vj is stored as (K, N) = (block_size, head_dim) in row-major (ND) layout.
// Standard non-transposed B pattern: ND GlobalB + ColMajor/RowMajor TileMatB.

Expand All @@ -32,26 +32,26 @@ using namespace pto;
static __aicore__ void pv_matmul_impl(__gm__ uint8_t *pij_raw, __gm__ uint8_t *vj_raw, __gm__ uint8_t *oi_raw) {
constexpr int M = 16, K = 16, N = 16;

__gm__ half *pij = reinterpret_cast<__gm__ half *>(pij_raw);
__gm__ half *vj = reinterpret_cast<__gm__ half *>(vj_raw);
__gm__ bfloat16_t *pij = reinterpret_cast<__gm__ bfloat16_t *>(pij_raw);
__gm__ bfloat16_t *vj = reinterpret_cast<__gm__ bfloat16_t *>(vj_raw);
__gm__ float *oi = reinterpret_cast<__gm__ float *>(oi_raw);

// pij (M, K) fp16, vj (K, N) fp16 in ND (row-major), oi_new (M, N) fp32
using GlobalA = GlobalTensor<half, Shape<1, 1, 1, M, K>, Stride<M * K, M * K, M * K, K, 1>>;
using GlobalB = GlobalTensor<half, Shape<1, 1, 1, K, N>, Stride<K * N, K * N, K * N, N, 1>>;
// pij (M, K) bf16, vj (K, N) bf16 in ND (row-major), oi_new (M, N) fp32
using GlobalA = GlobalTensor<bfloat16_t, Shape<1, 1, 1, M, K>, Stride<M * K, M * K, M * K, K, 1>>;
using GlobalB = GlobalTensor<bfloat16_t, Shape<1, 1, 1, K, N>, Stride<K * N, K * N, K * N, N, 1>>;
using GlobalOut = GlobalTensor<float, Shape<1, 1, 1, M, N>, Stride<M * N, M * N, M * N, N, 1>>;

GlobalA pijGlobal(pij);
GlobalB vjGlobal(vj);
GlobalOut oiGlobal(oi);

// L1 Mat tiles: standard ND pattern for both A and B
using TileMatA = Tile<TileType::Mat, half, M, K, BLayout::ColMajor, M, K, SLayout::RowMajor, 512>;
using TileMatB = Tile<TileType::Mat, half, K, N, BLayout::ColMajor, K, N, SLayout::RowMajor, 512>;
using TileMatA = Tile<TileType::Mat, bfloat16_t, M, K, BLayout::ColMajor, M, K, SLayout::RowMajor, 512>;
using TileMatB = Tile<TileType::Mat, bfloat16_t, K, N, BLayout::ColMajor, K, N, SLayout::RowMajor, 512>;

// L0 tiles
using LeftTile = TileLeft<half, M, K, M, K>;
using RightTile = TileRight<half, K, N, K, N>;
using LeftTile = TileLeft<bfloat16_t, M, K, M, K>;
using RightTile = TileRight<bfloat16_t, K, N, K, N>;
using AccTile = TileAcc<float, M, N, M, N>;

TileMatA aMatTile;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,27 +32,27 @@ using namespace pto;
static __aicore__ void qk_matmul_impl(__gm__ uint8_t *qi_raw, __gm__ uint8_t *kj_raw, __gm__ uint8_t *sij_raw) {
constexpr int M = 16, K = 16, N = 16;

__gm__ half *qi = reinterpret_cast<__gm__ half *>(qi_raw);
__gm__ half *kj = reinterpret_cast<__gm__ half *>(kj_raw);
__gm__ bfloat16_t *qi = reinterpret_cast<__gm__ bfloat16_t *>(qi_raw);
__gm__ bfloat16_t *kj = reinterpret_cast<__gm__ bfloat16_t *>(kj_raw);
__gm__ float *sij = reinterpret_cast<__gm__ float *>(sij_raw);

// qi (M, K) fp16 in ND (row-major) layout
using GlobalA = GlobalTensor<half, Shape<1, 1, 1, M, K>, Stride<M * K, M * K, M * K, K, 1>>;
// qi (M, K) bf16 in ND (row-major) layout
using GlobalA = GlobalTensor<bfloat16_t, Shape<1, 1, 1, M, K>, Stride<M * K, M * K, M * K, K, 1>>;
// kj stored as (N, K) row-major = (K, N) column-major -> DN layout
using GlobalB = GlobalTensor<half, Shape<1, 1, 1, K, N>, Stride<K * N, K * N, K * N, 1, K>, Layout::DN>;
using GlobalB = GlobalTensor<bfloat16_t, Shape<1, 1, 1, K, N>, Stride<K * N, K * N, K * N, 1, K>, Layout::DN>;
using GlobalOut = GlobalTensor<float, Shape<1, 1, 1, M, N>, Stride<M * N, M * N, M * N, N, 1>>;

GlobalA qiGlobal(qi);
GlobalB kjGlobal(kj);
GlobalOut sijGlobal(sij);

// L1 Mat tiles: A is standard ND, B uses transposed-B pattern (RowMajor/ColMajor)
using TileMatA = Tile<TileType::Mat, half, M, K, BLayout::ColMajor, M, K, SLayout::RowMajor, 512>;
using TileMatB = Tile<TileType::Mat, half, K, N, BLayout::RowMajor, K, N, SLayout::ColMajor, 512>;
using TileMatA = Tile<TileType::Mat, bfloat16_t, M, K, BLayout::ColMajor, M, K, SLayout::RowMajor, 512>;
using TileMatB = Tile<TileType::Mat, bfloat16_t, K, N, BLayout::RowMajor, K, N, SLayout::ColMajor, 512>;

// L0 tiles
using LeftTile = TileLeft<half, M, K, M, K>;
using RightTile = TileRight<half, K, N, K, N>;
using LeftTile = TileLeft<bfloat16_t, M, K, M, K>;
using RightTile = TileRight<bfloat16_t, K, N, K, N>;
using AccTile = TileAcc<float, M, N, M, N>;

TileMatA aMatTile;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,38 +38,38 @@ static __aicore__ void softmax_prepare_impl(
constexpr int M = 16, N = 16;

__gm__ float *sij = reinterpret_cast<__gm__ float *>(sij_raw);
__gm__ half *pij = reinterpret_cast<__gm__ half *>(pij_raw);
__gm__ bfloat16_t *pij = reinterpret_cast<__gm__ bfloat16_t *>(pij_raw);
__gm__ float *mij = reinterpret_cast<__gm__ float *>(mij_raw);
__gm__ float *lij = reinterpret_cast<__gm__ float *>(lij_raw);

constexpr int kAlignedRows = ((M * sizeof(float) + 31) / 32) * (32 / sizeof(float));

using GlobalDataMxN = GlobalTensor<float, Shape<1, 1, 1, M, N>, Stride<1, 1, 1, N, 1>>;
using GlobalDataMxN_f16 = GlobalTensor<half, Shape<1, 1, 1, M, N>, Stride<1, 1, 1, N, 1>>;
using GlobalDataMxN_bf16 = GlobalTensor<bfloat16_t, Shape<1, 1, 1, M, N>, Stride<1, 1, 1, N, 1>>;
using GlobalScalarDN = GlobalTensor<float, Shape<1, 1, 1, kAlignedRows, 1>, Stride<1, 1, 1, 1, 1>, Layout::DN>;

GlobalDataMxN sijGlobal(sij);
GlobalDataMxN_f16 pijGlobal(pij);
GlobalDataMxN_bf16 pijGlobal(pij);
GlobalScalarDN mijGlobal(mij);
GlobalScalarDN lijGlobal(lij);

using TileVecMxN = Tile<TileType::Vec, float, M, N, BLayout::RowMajor, M, N>;
using TileVecMxN_f16 = Tile<TileType::Vec, half, M, N, BLayout::RowMajor, M, N>;
using TileVecMxN_bf16 = Tile<TileType::Vec, bfloat16_t, M, N, BLayout::RowMajor, M, N>;
using TileScalarDN = Tile<TileType::Vec, float, kAlignedRows, 1, BLayout::ColMajor, M, 1>;

TileVecMxN sijTile;
TileVecMxN pijTile;
TileVecMxN tmpTile;
TileScalarDN maxTile;
TileScalarDN sumTile;
TileVecMxN_f16 pijF16Tile;
TileVecMxN_bf16 pijBf16Tile;

TASSIGN(sijTile, 0x0);
TASSIGN(pijTile, M * N * sizeof(float));
TASSIGN(tmpTile, 2 * M * N * sizeof(float));
TASSIGN(maxTile, 3 * M * N * sizeof(float));
TASSIGN(sumTile, 3 * M * N * sizeof(float) + kAlignedRows * sizeof(float));
TASSIGN(pijF16Tile, 3 * M * N * sizeof(float) + 2 * kAlignedRows * sizeof(float));
TASSIGN(pijBf16Tile, 3 * M * N * sizeof(float) + 2 * kAlignedRows * sizeof(float));

TLOAD(sijTile, sijGlobal);
set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0);
Expand All @@ -79,16 +79,16 @@ static __aicore__ void softmax_prepare_impl(
TROWMAX(maxTile, sijTile, tmpTile);
TROWEXPANDSUB(pijTile, sijTile, maxTile);
TEXP(pijTile, pijTile);
// Truncate pij to fp16 first, then compute lij from truncated values (matches golden)
TCVT(pijF16Tile, pijTile, RoundMode::CAST_ROUND);
TCVT(pijTile, pijF16Tile, RoundMode::CAST_ROUND);
// Round pij to bf16 first (TCVT with CAST_ROUND), then compute lij from the bf16-rounded values (matches golden)
TCVT(pijBf16Tile, pijTile, RoundMode::CAST_ROUND);
TCVT(pijTile, pijBf16Tile, RoundMode::CAST_ROUND);
TROWSUM(sumTile, pijTile, tmpTile);

set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0);
wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0);
TSTORE(mijGlobal, maxTile);
TSTORE(lijGlobal, sumTile);
TSTORE(pijGlobal, pijF16Tile);
TSTORE(pijGlobal, pijBf16Tile);
}

extern "C" __aicore__ void kernel_entry(__gm__ int64_t *args) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
* Paged Attention Orchestration - Small Scale (16x16)
*
* Supports small-scale paged attention with:
* Query: (batch, q_head_num, head_dim) fp16
* Key: (total_blocks, block_size, kv_head_num, head_dim) fp16 (NOT transposed)
* Value: (total_blocks, block_size, kv_head_num, head_dim) fp16
* Query: (batch, q_head_num, head_dim) bf16
* Key: (total_blocks, block_size, kv_head_num, head_dim) bf16 (NOT transposed)
* Value: (total_blocks, block_size, kv_head_num, head_dim) bf16
* Output: (batch, q_head_num, head_dim) float32
*
* Head tiling: q_tile_size = min(num_heads, 128)
Expand Down Expand Up @@ -148,7 +148,7 @@ int build_paged_attention_graph(OrchestrationRuntime *runtime, const ChipStorage
for (uint32_t ht = 0; ht < num_head_tiles; ht++) {
uint32_t cur_offset = ht * q_tile_size;

// Query: (batch, q_head_num, head_dim) fp16
// Query: (batch, q_head_num, head_dim) bf16
// qi points to heads [cur_offset .. cur_offset+q_tile_size) for batch b_idx
uint8_t *qi_ptr = reinterpret_cast<uint8_t *>(dev_query) +
static_cast<int64_t>(b_idx * num_heads + cur_offset) * head_dim * sizeof(uint16_t);
Expand All @@ -171,12 +171,12 @@ int build_paged_attention_graph(OrchestrationRuntime *runtime, const ChipStorage
for (uint32_t bn = 0; bn < bn_this_batch; bn++) {
int cur_block_idx = host_block_table[b_idx * max_num_blocks + bn];

// Key: (total_blocks, block_size, kv_head_num, head_dim) fp16
// Key: (total_blocks, block_size, kv_head_num, head_dim) bf16
uint8_t *kj_ptr = reinterpret_cast<uint8_t *>(dev_key_cache) +
(static_cast<int64_t>(cur_block_idx) * block_size * kv_head_num + kv_head_idx) *
head_dim * sizeof(uint16_t);

// Value: (total_blocks, block_size, kv_head_num, head_dim) fp16
// Value: (total_blocks, block_size, kv_head_num, head_dim) bf16
uint8_t *vj_ptr = reinterpret_cast<uint8_t *>(dev_value_cache) +
(static_cast<int64_t>(cur_block_idx) * block_size * kv_head_num + kv_head_idx) *
head_dim * sizeof(uint16_t);
Expand Down
6 changes: 3 additions & 3 deletions examples/a5/host_build_graph/paged_attention/golden.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.
# -----------------------------------------------------------------------------------------------------------
"""Paged Attention Golden - host_build_graph example (small scale, float16).
"""Paged Attention Golden - host_build_graph example (small scale, bfloat16).

Args layout: [query, key_cache, value_cache, block_table, context_lens, out, scale]
- Tensors retain original multi-dimensional shapes (ContinuousTensor metadata carries shape/dtype)
Expand All @@ -33,7 +33,7 @@
"block_size": 16,
"context_len": 16,
"max_model_len": 256,
"dtype": "float16",
"dtype": "bfloat16",
},
"Case2": {
"batch": 1,
Expand All @@ -43,7 +43,7 @@
"block_size": 16,
"context_len": 64,
"max_model_len": 256,
"dtype": "float16",
"dtype": "bfloat16",
},
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
//
// Fixed tile size: (16, 16) @ (16, 16) -> (16, 16)
//
// pij is float16 (converted from fp32 in softmax_prepare via TCVT).
// pij is bfloat16 (converted from fp32 in softmax_prepare via TCVT).
// vj is stored as (K, N) = (block_size, head_dim) in row-major (ND) layout.
// Standard non-transposed B pattern: ND GlobalB + ColMajor/RowMajor TileMatB.

Expand All @@ -32,26 +32,26 @@ using namespace pto;
static __aicore__ void pv_matmul_impl(__gm__ uint8_t *pij_raw, __gm__ uint8_t *vj_raw, __gm__ uint8_t *oi_raw) {
constexpr int M = 16, K = 16, N = 16;

__gm__ half *pij = reinterpret_cast<__gm__ half *>(pij_raw);
__gm__ half *vj = reinterpret_cast<__gm__ half *>(vj_raw);
__gm__ bfloat16_t *pij = reinterpret_cast<__gm__ bfloat16_t *>(pij_raw);
__gm__ bfloat16_t *vj = reinterpret_cast<__gm__ bfloat16_t *>(vj_raw);
__gm__ float *oi = reinterpret_cast<__gm__ float *>(oi_raw);

// pij (M, K) fp16, vj (K, N) fp16 in ND (row-major), oi_new (M, N) fp32
using GlobalA = GlobalTensor<half, Shape<1, 1, 1, M, K>, pto::Stride<M * K, M * K, M * K, K, 1>>;
using GlobalB = GlobalTensor<half, Shape<1, 1, 1, K, N>, pto::Stride<K * N, K * N, K * N, N, 1>>;
using GlobalOut = GlobalTensor<float, Shape<1, 1, 1, M, N>, pto::Stride<M * N, M * N, M * N, N, 1>>;
// pij (M, K) bf16, vj (K, N) bf16 in ND (row-major), oi_new (M, N) fp32
using GlobalA = GlobalTensor<bfloat16_t, Shape<1, 1, 1, M, K>, Stride<M * K, M * K, M * K, K, 1>>;
using GlobalB = GlobalTensor<bfloat16_t, Shape<1, 1, 1, K, N>, Stride<K * N, K * N, K * N, N, 1>>;
using GlobalOut = GlobalTensor<float, Shape<1, 1, 1, M, N>, Stride<M * N, M * N, M * N, N, 1>>;

GlobalA pijGlobal(pij);
GlobalB vjGlobal(vj);
GlobalOut oiGlobal(oi);

// L1 Mat tiles: standard ND pattern for both A and B
using TileMatA = Tile<TileType::Mat, half, M, K, BLayout::ColMajor, M, K, SLayout::RowMajor, 512>;
using TileMatB = Tile<TileType::Mat, half, K, N, BLayout::ColMajor, K, N, SLayout::RowMajor, 512>;
using TileMatA = Tile<TileType::Mat, bfloat16_t, M, K, BLayout::ColMajor, M, K, SLayout::RowMajor, 512>;
using TileMatB = Tile<TileType::Mat, bfloat16_t, K, N, BLayout::ColMajor, K, N, SLayout::RowMajor, 512>;

// L0 tiles
using LeftTile = TileLeft<half, M, K, M, K>;
using RightTile = TileRight<half, K, N, K, N>;
using LeftTile = TileLeft<bfloat16_t, M, K, M, K>;
using RightTile = TileRight<bfloat16_t, K, N, K, N>;
using AccTile = TileAcc<float, M, N, M, N>;

TileMatA aMatTile;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,27 +32,27 @@ using namespace pto;
static __aicore__ void qk_matmul_impl(__gm__ uint8_t *qi_raw, __gm__ uint8_t *kj_raw, __gm__ uint8_t *sij_raw) {
constexpr int M = 16, K = 16, N = 16;

__gm__ half *qi = reinterpret_cast<__gm__ half *>(qi_raw);
__gm__ half *kj = reinterpret_cast<__gm__ half *>(kj_raw);
__gm__ bfloat16_t *qi = reinterpret_cast<__gm__ bfloat16_t *>(qi_raw);
__gm__ bfloat16_t *kj = reinterpret_cast<__gm__ bfloat16_t *>(kj_raw);
__gm__ float *sij = reinterpret_cast<__gm__ float *>(sij_raw);

// qi (M, K) fp16 in ND (row-major) layout
using GlobalA = GlobalTensor<half, Shape<1, 1, 1, M, K>, pto::Stride<M * K, M * K, M * K, K, 1>>;
// qi (M, K) bf16 in ND (row-major) layout
using GlobalA = GlobalTensor<bfloat16_t, Shape<1, 1, 1, M, K>, Stride<M * K, M * K, M * K, K, 1>>;
// kj stored as (N, K) row-major = (K, N) column-major -> DN layout
using GlobalB = GlobalTensor<half, Shape<1, 1, 1, K, N>, pto::Stride<K * N, K * N, K * N, 1, K>, Layout::DN>;
using GlobalOut = GlobalTensor<float, Shape<1, 1, 1, M, N>, pto::Stride<M * N, M * N, M * N, N, 1>>;
using GlobalB = GlobalTensor<bfloat16_t, Shape<1, 1, 1, K, N>, Stride<K * N, K * N, K * N, 1, K>, Layout::DN>;
using GlobalOut = GlobalTensor<float, Shape<1, 1, 1, M, N>, Stride<M * N, M * N, M * N, N, 1>>;

GlobalA qiGlobal(qi);
GlobalB kjGlobal(kj);
GlobalOut sijGlobal(sij);

// L1 Mat tiles: A is standard ND, B uses transposed-B pattern (RowMajor/ColMajor)
using TileMatA = Tile<TileType::Mat, half, M, K, BLayout::ColMajor, M, K, SLayout::RowMajor, 512>;
using TileMatB = Tile<TileType::Mat, half, K, N, BLayout::RowMajor, K, N, SLayout::ColMajor, 512>;
using TileMatA = Tile<TileType::Mat, bfloat16_t, M, K, BLayout::ColMajor, M, K, SLayout::RowMajor, 512>;
using TileMatB = Tile<TileType::Mat, bfloat16_t, K, N, BLayout::RowMajor, K, N, SLayout::ColMajor, 512>;

// L0 tiles
using LeftTile = TileLeft<half, M, K, M, K>;
using RightTile = TileRight<half, K, N, K, N>;
using LeftTile = TileLeft<bfloat16_t, M, K, M, K>;
using RightTile = TileRight<bfloat16_t, K, N, K, N>;
using AccTile = TileAcc<float, M, N, M, N>;

TileMatA aMatTile;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,38 +38,38 @@ static __aicore__ void softmax_prepare_impl(
constexpr int M = 16, N = 16;

__gm__ float *sij = reinterpret_cast<__gm__ float *>(sij_raw);
__gm__ half *pij = reinterpret_cast<__gm__ half *>(pij_raw);
__gm__ bfloat16_t *pij = reinterpret_cast<__gm__ bfloat16_t *>(pij_raw);
__gm__ float *mij = reinterpret_cast<__gm__ float *>(mij_raw);
__gm__ float *lij = reinterpret_cast<__gm__ float *>(lij_raw);

constexpr int kAlignedRows = ((M * sizeof(float) + 31) / 32) * (32 / sizeof(float));

using GlobalDataMxN = GlobalTensor<float, Shape<1, 1, 1, M, N>, pto::Stride<1, 1, 1, N, 1>>;
using GlobalDataMxN_f16 = GlobalTensor<half, Shape<1, 1, 1, M, N>, pto::Stride<1, 1, 1, N, 1>>;
using GlobalDataMxN_bf16 = GlobalTensor<bfloat16_t, Shape<1, 1, 1, M, N>, Stride<1, 1, 1, N, 1>>;
using GlobalScalarDN = GlobalTensor<float, Shape<1, 1, 1, kAlignedRows, 1>, pto::Stride<1, 1, 1, 1, 1>, Layout::DN>;

GlobalDataMxN sijGlobal(sij);
GlobalDataMxN_f16 pijGlobal(pij);
GlobalDataMxN_bf16 pijGlobal(pij);
GlobalScalarDN mijGlobal(mij);
GlobalScalarDN lijGlobal(lij);

using TileVecMxN = Tile<TileType::Vec, float, M, N, BLayout::RowMajor, M, N>;
using TileVecMxN_f16 = Tile<TileType::Vec, half, M, N, BLayout::RowMajor, M, N>;
using TileVecMxN_bf16 = Tile<TileType::Vec, bfloat16_t, M, N, BLayout::RowMajor, M, N>;
using TileScalarDN = Tile<TileType::Vec, float, kAlignedRows, 1, BLayout::ColMajor, M, 1>;

TileVecMxN sijTile;
TileVecMxN pijTile;
TileVecMxN tmpTile;
TileScalarDN maxTile;
TileScalarDN sumTile;
TileVecMxN_f16 pijF16Tile;
TileVecMxN_bf16 pijBf16Tile;

TASSIGN(sijTile, 0x0);
TASSIGN(pijTile, M * N * sizeof(float));
TASSIGN(tmpTile, 2 * M * N * sizeof(float));
TASSIGN(maxTile, 3 * M * N * sizeof(float));
TASSIGN(sumTile, 3 * M * N * sizeof(float) + kAlignedRows * sizeof(float));
TASSIGN(pijF16Tile, 3 * M * N * sizeof(float) + 2 * kAlignedRows * sizeof(float));
TASSIGN(pijBf16Tile, 3 * M * N * sizeof(float) + 2 * kAlignedRows * sizeof(float));

TLOAD(sijTile, sijGlobal);
set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0);
Expand All @@ -79,16 +79,16 @@ static __aicore__ void softmax_prepare_impl(
TROWMAX(maxTile, sijTile, tmpTile);
TROWEXPANDSUB(pijTile, sijTile, maxTile);
TEXP(pijTile, pijTile);
// Truncate pij to fp16 first, then compute lij from truncated values (matches golden)
TCVT(pijF16Tile, pijTile, RoundMode::CAST_ROUND);
TCVT(pijTile, pijF16Tile, RoundMode::CAST_ROUND);
// Round pij to bf16 first (TCVT with CAST_ROUND), then compute lij from the bf16-rounded values (matches golden)
TCVT(pijBf16Tile, pijTile, RoundMode::CAST_ROUND);
TCVT(pijTile, pijBf16Tile, RoundMode::CAST_ROUND);
TROWSUM(sumTile, pijTile, tmpTile);

set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0);
wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0);
TSTORE(mijGlobal, maxTile);
TSTORE(lijGlobal, sumTile);
TSTORE(pijGlobal, pijF16Tile);
TSTORE(pijGlobal, pijBf16Tile);
}

extern "C" __aicore__ void kernel_entry(__gm__ int64_t *args) {
Expand Down
Loading
Loading