Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions kernel/include/vx_intrinsics.h
Original file line number Diff line number Diff line change
Expand Up @@ -287,34 +287,34 @@ inline __attribute__((const)) int vx_shfl_idx(size_t value, int bval, int cval,

// TILE LOAD T: Load 1KB from ptr[TILE] to tile register index 'dst_treg'
// Each load uses I-type encoding: rd=dst tile index, rs1=src_gpr, imm=ptr immediate
Comment on lines 288 to 289
Copy link

Copilot AI Dec 2, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comment says "Load 1KB" but T-regs are 2KB (2048 bytes) as defined in common.h. Update the comment to reflect the correct size: "Load 2KB from ptr[TILE] to tile register index 'dst_treg'".

Copilot uses AI. Check for mistakes.
inline void vx_lt(int dst_treg, int src_gpr, size_t ptr_imm) {
inline void vx_lt(int dst_treg, size_t src_gpr, size_t ptr_imm) {
__asm__ volatile (".insn i %0, 0, x%1, %2, %3"
:: "i"(RISCV_CUSTOM1), "i"(dst_treg), "r"(src_gpr), "i"(ptr_imm) : "memory");
}

// TILE LOAD U: Load 1KB from ptr[TILE] to ureg index 'dst_ureg'
Copy link

Copilot AI Dec 2, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comment says "Load 1KB" but U-regs are 4KB (4096 bytes) as defined in common.h. Update the comment to reflect the correct size: "Load 4KB from ptr[TILE] to ureg index 'dst_ureg'".

Suggested change
// TILE LOAD U: Load 1KB from ptr[TILE] to ureg index 'dst_ureg'
// TILE LOAD U: Load 4KB from ptr[TILE] to ureg index 'dst_ureg'

Copilot uses AI. Check for mistakes.
inline void vx_lu(int dst_ureg, int src_gpr, size_t ptr_imm) {
inline void vx_lu(int dst_ureg, size_t src_gpr, size_t ptr_imm) {
__asm__ volatile (".insn i %0, 1, x%1, %2, %3"
:: "i"(RISCV_CUSTOM1), "i"(dst_ureg), "r"(src_gpr), "i"(ptr_imm) : "memory");
}

// TILE LOAD V: Load 1KB from ptr[TILE] to vreg index 'dst_vreg'
Copy link

Copilot AI Dec 2, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comment says "Load 1KB" but V-regs are 8KB (8192 bytes) as defined in common.h. Update the comment to reflect the correct size: "Load 8KB from ptr[TILE] to vreg index 'dst_vreg'".

Suggested change
// TILE LOAD V: Load 1KB from ptr[TILE] to vreg index 'dst_vreg'
// TILE LOAD V: Load 8KB from ptr[TILE] to vreg index 'dst_vreg'

Copilot uses AI. Check for mistakes.
inline void vx_lv(int dst_vreg, int src_gpr, size_t ptr_imm) {
inline void vx_lv(int dst_vreg, size_t src_gpr, size_t ptr_imm) {
__asm__ volatile (".insn i %0, 2, x%1, %2, %3"
:: "i"(RISCV_CUSTOM1), "i"(dst_vreg), "r"(src_gpr), "i"(ptr_imm) : "memory");
}

// TILE LOAD M: Load 1KB from ptr[TILE] to mreg index 'dst_mreg'
Copy link

Copilot AI Dec 2, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comment says "Load 1KB" but M-regs are 256 bytes as defined in common.h. Update the comment to reflect the correct size: "Load 256B from ptr[TILE] to mreg index 'dst_mreg'".

Suggested change
// TILE LOAD M: Load 1KB from ptr[TILE] to mreg index 'dst_mreg'
// TILE LOAD M: Load 256B from ptr[TILE] to mreg index 'dst_mreg'

Copilot uses AI. Check for mistakes.
inline void vx_lm(int dst_mreg, int src_gpr, size_t ptr_imm) {
inline void vx_lm(int dst_mreg, size_t src_gpr, size_t ptr_imm) {
__asm__ volatile (".insn i %0, 3, x%1, %2, %3"
:: "i"(RISCV_CUSTOM1), "i"(dst_mreg), "r"(src_gpr), "i"(ptr_imm) : "memory");
}

// TILE STORE T: Store 1KB from treg index 'src_treg' to ptr[TILE]
Copy link

Copilot AI Dec 2, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comment says "Store 1KB" but T-regs are 2KB (2048 bytes) as defined in common.h. Update the comment to reflect the correct size: "Store 2KB from treg index 'src_treg' to ptr[TILE]".

Suggested change
// TILE STORE T: Store 1KB from treg index 'src_treg' to ptr[TILE]
// TILE STORE T: Store 2KB from treg index 'src_treg' to ptr[TILE]

Copilot uses AI. Check for mistakes.
// Store uses S-type encoding: rs1=src_gpr, rs2=src_treg index, imm=ptr immediate
inline void vx_st(int src_gpr, size_t ptr_imm, int src_treg) {
__asm__ volatile (".insn s %0, 0, %1, x%2, %3"
:: "i"(RISCV_CUSTOM2), "r"(src_gpr), "i"(src_treg), "i"(ptr_imm) : "memory");
inline void vx_st(size_t src_gpr, size_t ptr_imm, int src_treg) {
__asm__ volatile (".insn s %0, 0, x%3, %2(%1)"
:: "i"(RISCV_CUSTOM2), "r"(src_gpr), "i"(ptr_imm), "i"(src_treg) : "memory");
}

// -----------------------------------------------------------------------------
Expand Down
127 changes: 86 additions & 41 deletions kernel/include/vx_sparse.h
Original file line number Diff line number Diff line change
Expand Up @@ -200,57 +200,102 @@ struct wmma_context {
if constexpr (src_layout == col_major) {
std::swap(block_row, block_col);
}
// For sparse format: when meta_src is provided, data stride is K/2 (not K)
// because each row has K/2 values (2 per block of 4)
size_t data_ldm = (meta_src != nullptr) ? (ldm / 2) : ldm;
auto base = reinterpret_cast<const input_t*>(src) + block_row * data_ldm + block_col;

// Metadata pointer is pre-offset to tile position (like data pointer)
// For metadata: stride is based on number of K-blocks per row in the FULL matrix
// This is ldm/4 (K/4), not affected by tile boundaries
const uint32_t* meta_base = meta_src ? reinterpret_cast<const uint32_t*>(meta_src) : nullptr;
// NOTE: meta_ldm uses full matrix K for stride, not tile dimensions
uint32_t meta_ldm = meta_src ? (ldm / 4) : 0;

detail::unroll_for<Frag::NR>([&](auto r) {
uint32_t block_m = r / cfg::k_steps;
uint32_t block_k = r % cfg::k_steps;
uint32_t elem_row = block_m * m_stride;
uint32_t elem_col = block_k * k_stride;
uint32_t meta_value = 0;

if (meta_base) {
// Metadata uses ABSOLUTE matrix positions (not tile-relative)
// meta_row_base = tile_row (absolute row offset for this tile)
// meta_col_base = k_tile (absolute K offset for this tile)
uint32_t abs_row = meta_row_base + block_row + elem_row;
uint32_t abs_k_block = (meta_col_base / 4) + block_k; // K-block index in full matrix
if (meta_src != nullptr) {
// SPARSE LOADING: Use metadata to place values in correct k_step registers
// data_ldm is K/2 for sparse (compressed values)
size_t data_ldm = ldm / 2;
// For sparse, don't add block_col to base - we compute sparse_idx separately
auto data_base = reinterpret_cast<const input_t*>(src) + block_row * data_ldm;
// First, load metadata for each M row that this thread handles
// and distribute sparse values to the correct k_step registers
detail::unroll_for<Frag::NR>([&](auto r) {
uint32_t block_m = r / cfg::k_steps;
uint32_t block_k = r % cfg::k_steps;
uint32_t elem_row = block_m * m_stride;

// Metadata is stored in row-major format with meta_ldm entries per row
// Get metadata for this row (absolute position in matrix)
uint32_t abs_row = meta_row_base + block_row + elem_row;
uint32_t abs_k_block = (meta_col_base / 4); // K-block index for this tile
const uint32_t *meta_ptr = meta_base + static_cast<size_t>(abs_row) * meta_ldm + abs_k_block;
meta_value = *meta_ptr;
}

if constexpr (Frag::Use == matrix_a) {
uint32_t meta_value = *meta_ptr;
dst.metadata[r] = meta_value;
}
if constexpr (src_layout == col_major) {
static_assert(input_is_subbyte == false, "col_major layout is not supported for sub-byte matrix_a");
std::swap(elem_row, elem_col);
auto ptr = base + elem_row * data_ldm + elem_col;
if constexpr (sizeof(vreg_t) == sizeof(input_t) && !input_is_subbyte) {
dst.data[r] = *reinterpret_cast<const vreg_t*>(ptr);

// meta_value is a bitmask: bits 0-3 indicate which of 4 K positions have values
// block_k indicates which pair of K positions this register is for:
// block_k=0 -> K positions 0,1 (bits 0,1)
// block_k=1 -> K positions 2,3 (bits 2,3)
uint8_t meta_byte = meta_value & 0xFF;
uint32_t k_start = block_k * cfg::tcK; // Start K position for this register
uint32_t k_end = k_start + cfg::tcK; // End K position

// Count how many sparse values come BEFORE this k_step for this row
uint32_t sparse_offset = 0;
for (uint32_t pos = 0; pos < k_start; ++pos) {
if (meta_byte & (1u << pos)) {
sparse_offset++;
}
}

// For fp32 with tcK=2, each register holds 1 fp32 value
// block_col determines which position within the tcK pair: 0 or 1
// So the target K position is: k_start + block_col
uint32_t target_pos = k_start + block_col;

vreg_t loaded_val = 0.0f;

if (target_pos < 4) {
// Count sparse values before target_pos to get the sparse index
uint32_t sparse_idx = sparse_offset;
for (uint32_t pos = k_start; pos < target_pos; ++pos) {
if (meta_byte & (1u << pos)) {
sparse_idx++;
}
}

// Check if target position has a sparse value
if (meta_byte & (1u << target_pos)) {
auto ptr = data_base + elem_row * data_ldm + sparse_idx;
loaded_val = *ptr;
}
// else: loaded_val stays 0.0f (position was pruned)
}

dst.data[r] = loaded_val;
});
} else {
// DENSE LOADING: Original non-sparse path
auto base = reinterpret_cast<const input_t*>(src) + block_row * ldm + block_col;

detail::unroll_for<Frag::NR>([&](auto r) {
uint32_t block_m = r / cfg::k_steps;
uint32_t block_k = r % cfg::k_steps;
uint32_t elem_row = block_m * m_stride;
uint32_t elem_col = block_k * k_stride;

dst.metadata[r] = 0;

if constexpr (src_layout == col_major) {
static_assert(input_is_subbyte == false, "col_major layout is not supported for sub-byte matrix_a");
std::swap(elem_row, elem_col);
auto ptr = base + elem_row * ldm + elem_col;
if constexpr (sizeof(vreg_t) == sizeof(input_t) && !input_is_subbyte) {
dst.data[r] = *reinterpret_cast<const vreg_t*>(ptr);
} else {
dst.data[r] = input_acessor_t::pack_row(ptr, ldm);
}
} else {
dst.data[r] = input_acessor_t::pack_row(ptr, data_ldm);
auto ptr = base + elem_row * ldm + elem_col;
assert(reinterpret_cast<uintptr_t>(ptr) % alignof(vreg_t) == 0 && "pointer must be aligned to 4 bytes");
dst.data[r] = *reinterpret_cast<const vreg_t *>(ptr);
}
} else {
// row_major layout
// For sparse format, use data_ldm (K/2) instead of ldm (K)
auto ptr = base + elem_row * data_ldm + elem_col;
assert(reinterpret_cast<uintptr_t>(ptr) % alignof(vreg_t) == 0 && "pointer must be aligned to 4 bytes");
dst.data[r] = *reinterpret_cast<const vreg_t *>(ptr);
}
});
});
}
} else if constexpr (Frag::Use == matrix_b) {
// Load column-major matrix B
uint32_t block_idx = (cfg::b_block_size == NT) ? 0 : (lane / cfg::b_block_size);
Expand Down
3 changes: 3 additions & 0 deletions sim/simx/core.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,9 @@ void Core::issue() {
#endif
#ifdef EXT_TCU_ENABLE
case FUType::TCU: ++perf_stats_.scrb_tcu; break;
#endif
#ifdef EXT_VEGETA_ENABLE
case FUType::VEGETA: ++perf_stats_.scrb_vegeta; break;
#endif
default: assert(false);
}
Expand Down
6 changes: 6 additions & 0 deletions sim/simx/core.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,9 @@ class Core : public SimObject<Core> {
#endif
#ifdef EXT_TCU_ENABLE
uint64_t scrb_tcu;
#endif
#ifdef EXT_VEGETA_ENABLE
uint64_t scrb_vegeta;
#endif
uint64_t ifetches;
uint64_t loads;
Expand Down Expand Up @@ -93,6 +96,9 @@ class Core : public SimObject<Core> {
#endif
#ifdef EXT_TCU_ENABLE
, scrb_tcu(0)
#endif
#ifdef EXT_VEGETA_ENABLE
, scrb_vegeta(0)
#endif
, ifetches(0)
, loads(0)
Expand Down
2 changes: 0 additions & 2 deletions sim/simx/decode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1149,7 +1149,6 @@ void Emulator::decode(uint32_t code, uint32_t wid, uint64_t uuid) {
} break;
#endif
#ifdef EXT_VEGETA_ENABLE

case 3: {
switch (funct3) {
case 0: { // WMMA
Expand Down Expand Up @@ -1190,7 +1189,6 @@ void Emulator::decode(uint32_t code, uint32_t wid, uint64_t uuid) {
std::abort();
}
} break;

#endif
default:
std::abort();
Expand Down
3 changes: 3 additions & 0 deletions sim/simx/emulator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,9 @@ Word Emulator::get_csr(uint32_t addr, uint32_t wid, uint32_t tid) {
#endif
#ifdef EXT_VPU_ENABLE
CSR_READ_64(VX_CSR_MPM_SCRB_TCU, core_perf.scrb_vpu);
#endif
#ifdef EXT_VEGETA_ENABLE
CSR_READ_64(VX_CSR_MPM_SCRB_TCU, core_perf.scrb_vegeta);
Copy link

Copilot AI Dec 2, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using wrong CSR identifier. This line reads from VX_CSR_MPM_SCRB_TCU but stores to core_perf.scrb_vegeta. All three conditional blocks (lines 500, 503, 506) are reading from the same CSR, which will cause incorrect performance counter values. Change to use a dedicated VEGETA CSR identifier, such as VX_CSR_MPM_SCRB_VEGETA (if it exists), or verify the correct CSR address for VEGETA scoreboard statistics.

Suggested change
CSR_READ_64(VX_CSR_MPM_SCRB_TCU, core_perf.scrb_vegeta);
CSR_READ_64(VX_CSR_MPM_SCRB_VEGETA, core_perf.scrb_vegeta);

Copilot uses AI. Check for mistakes.
#endif
CSR_READ_64(VX_CSR_MPM_SCRB_CSRS, core_perf.scrb_csrs);
CSR_READ_64(VX_CSR_MPM_SCRB_WCTL, core_perf.scrb_wctl);
Expand Down
72 changes: 64 additions & 8 deletions sim/simx/execute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,9 +149,15 @@ instr_trace_t* Emulator::execute(const Instr &instr, uint32_t wid) {
<< ", PC=0x" << std::hex << warp.PC << std::dec << " (#" << instr.getUUID() << ")");

// fetch register values
#ifdef EXT_VEGETA_ENABLE
if (rsrc0.type != RegType::None && rsrc0.type != RegType::Tile) fetch_registers(rs1_data, wid, 0, rsrc0);
if (rsrc1.type != RegType::None && rsrc1.type != RegType::Tile) fetch_registers(rs2_data, wid, 1, rsrc1);
if (rsrc2.type != RegType::None && rsrc2.type != RegType::Tile) fetch_registers(rs3_data, wid, 2, rsrc2);
#else
if (rsrc0.type != RegType::None) fetch_registers(rs1_data, wid, 0, rsrc0);
if (rsrc1.type != RegType::None) fetch_registers(rs2_data, wid, 1, rsrc1);
if (rsrc2.type != RegType::None) fetch_registers(rs3_data, wid, 2, rsrc2);
#endif

uint32_t thread_start = 0;
for (; thread_start < num_threads; ++thread_start) {
Expand Down Expand Up @@ -1546,16 +1552,66 @@ instr_trace_t* Emulator::execute(const Instr &instr, uint32_t wid) {
}
},
[&](VegetaTcuType tcu_type) {
auto tpuArgs = std::get<IntrVegetaTcuArgs>(instrArgs);
switch (tcu_type) {
case VegetaTcuType::TILE_GEMM_T:
case VegetaTcuType::TILE_GEMM_U:
case VegetaTcuType::TILE_GEMM_V:
case VegetaTcuType::TILE_GEMM_R:
// TODO: Implement TILE_GEMM execution
std::abort();
break;
case VegetaTcuType::TILE_GEMM_T: {
auto trace_data = std::make_shared<SparseUnit::ExeTraceData>();
trace->data = trace_data;
assert(warp.tmask.count() == num_threads);

// Extract tile register indices from instruction
uint32_t dst_reg = rdest.idx;
uint32_t src1_reg = instr.getSrcReg(0).idx;
uint32_t src2_reg = instr.getSrcReg(1).idx;

// Dense tile × Dense tile → Tile (T × T → T)
sparse_unit_->tile_gemm_t(dst_reg, src1_reg, src2_reg);
rd_write = false; // Writes to tile registers, not scalar registers
} break;
case VegetaTcuType::TILE_GEMM_U: {
auto trace_data = std::make_shared<SparseUnit::ExeTraceData>();
trace->data = trace_data;
assert(warp.tmask.count() == num_threads);

// Extract tile register indices from instruction
uint32_t dst_reg = rdest.idx;
uint32_t src1_reg = instr.getSrcReg(0).idx;
uint32_t src2_reg = instr.getSrcReg(1).idx;

// Sparse tile (2:4) × Dense tile → Tile (T × U → T)
// Metadata assumed to be in corresponding m-register (same index as src1)
sparse_unit_->tile_gemm_u(dst_reg, src1_reg, src2_reg, src1_reg);
rd_write = false;
} break;
case VegetaTcuType::TILE_GEMM_V: {
auto trace_data = std::make_shared<SparseUnit::ExeTraceData>();
trace->data = trace_data;
assert(warp.tmask.count() == num_threads);

// Extract tile register indices from instruction
uint32_t dst_reg = rdest.idx;
uint32_t src1_reg = instr.getSrcReg(0).idx;
uint32_t src2_reg = instr.getSrcReg(1).idx;

// Sparse tile (1:4) × Dense tile → Tile (T × V → T)
sparse_unit_->tile_gemm_v(dst_reg, src1_reg, src2_reg, src1_reg);
rd_write = false;
} break;
case VegetaTcuType::TILE_GEMM_R: {
auto trace_data = std::make_shared<SparseUnit::ExeTraceData>();
trace->data = trace_data;
assert(warp.tmask.count() == num_threads);

// Extract tile register indices from instruction
uint32_t dst_reg = rdest.idx;
uint32_t src1_reg = instr.getSrcReg(0).idx;
uint32_t src2_reg = instr.getSrcReg(1).idx;

// Row-wise sparse tile × Dense tile → Tile (T × U → U)
sparse_unit_->tile_gemm_r(dst_reg, src1_reg, src2_reg, src1_reg);
rd_write = false;
} break;
case VegetaTcuType::WMMA: {
auto tpuArgs = std::get<IntrVegetaTcuArgs>(instrArgs);
auto trace_data = std::make_shared<SparseUnit::ExeTraceData>();
trace->data = trace_data;
assert(warp.tmask.count() == num_threads);
Expand Down
Loading