Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
d73a042
Add sgemm_tcu_struct_sparse test
yanggon-kim Feb 5, 2026
93752d2
Add sparse TCU support: VX_tcu_meta module and B-column mux
yanggon-kim Feb 6, 2026
a580a6c
Add sparse TCU support: B-column mux with VX_tcu_sel module
yanggon-kim Feb 6, 2026
aaa4a53
changed the cpu_ref function
yanggon-kim Feb 6, 2026
5164075
randomize the operands, fix the rtl index for b_col_1, b_col_2.
yanggon-kim Feb 6, 2026
de3cd7a
fp16 fp32 printing. This code works for int8 and int32
yanggon-kim Feb 7, 2026
7a125dd
fp16/fp32 done by claude
yanggon-kim Feb 7, 2026
7630e3b
all pass with claude code
yanggon-kim Feb 8, 2026
815ee7c
after all config passes, test the 0101/1010 two pattern sweap pass
yanggon-kim Feb 8, 2026
eda31a1
new instruction working mma_struct_sparse_sync by claude code
yanggon-kim Feb 8, 2026
bfcf24b
prune and compress with fixed mast, fix matmul_cpu
yanggon-kim Feb 8, 2026
833bacf
code minimization with same functionality
yanggon-kim Feb 8, 2026
edd3361
loop code change
yanggon-kim Feb 10, 2026
1164bfe
NT=16 problem
yanggon-kim Feb 13, 2026
6a203ac
meta_store new SRAM feeding instruction
yanggon-kim Feb 16, 2026
5bc74fb
real meta, dynamic meta generation and run
yanggon-kim Feb 16, 2026
348187e
separate the tcu only time using csr hardware count
yanggon-kim Feb 16, 2026
0613bb2
past NT=16 clean up
yanggon-kim Feb 18, 2026
741750a
comment from professor
yanggon-kim Feb 20, 2026
9323558
Merge remote-tracking branch 'upstream/bug_fixes' into rtlsim_260203
yanggon-kim Feb 20, 2026
9a2c32d
fix verilator lint warning for vld_mask after upstream merge
yanggon-kim Feb 20, 2026
8bfc6b6
fix SimX get_barrier_phase for global barriers
yanggon-kim Feb 20, 2026
565bd0f
Merge upstream/bug_fixes into pr_sparse_tcu_merge
yanggon-kim Feb 24, 2026
9cc55a4
fix VX_tcu_meta address: use generate-if bit-concatenation instead of…
yanggon-kim Feb 24, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions hw/rtl/VX_gpu_pkg.sv
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,9 @@ package VX_gpu_pkg;

`ifdef EXT_TCU_ENABLE

localparam INST_TCU_WMMA = 4'h0;
localparam INST_TCU_WMMA = 4'h0;
localparam INST_TCU_WMMA_SP = 4'h1;
localparam INST_TCU_META_STORE = 4'h2;
localparam INST_TCU_BITS = 4;

`endif
Expand Down Expand Up @@ -569,9 +571,10 @@ package VX_gpu_pkg;

`ifdef EXT_TCU_ENABLE
typedef struct packed {
logic [(INST_ARGS_BITS-16)-1:0] __padding;
logic [(INST_ARGS_BITS-20)-1:0] __padding;
logic [3:0] fmt_d;
logic [3:0] fmt_s;
logic [3:0] step_k;
logic [3:0] step_n;
logic [3:0] step_m;
} tcu_args_t;
Expand Down
38 changes: 23 additions & 15 deletions hw/rtl/core/VX_decode.sv
Original file line number Diff line number Diff line change
Expand Up @@ -555,21 +555,29 @@ module VX_decode import VX_gpu_pkg::*; #(
`endif
`ifdef EXT_TCU_ENABLE
7'h02: begin
case (funct3)
3'h0: begin // WMMA_SYNC
ex_type = EX_TCU;
op_type = INST_OP_BITS'(INST_TCU_WMMA);
op_args.tcu.fmt_s = rs1[3:0];
op_args.tcu.fmt_d = rd[3:0];
op_args.tcu.step_m = '0;
op_args.tcu.step_n = '0;
`USED_FREG (rd);
`USED_FREG (rs1);
`USED_FREG (rs2);
`USED_FREG (rs3);
end
default:;
endcase
if (funct3 == 3'h0 || funct3 == 3'h1) begin
ex_type = EX_TCU;
op_type = funct3[0] ? INST_OP_BITS'(INST_TCU_WMMA_SP)
: INST_OP_BITS'(INST_TCU_WMMA);
op_args.tcu.fmt_s = rs1[3:0];
op_args.tcu.fmt_d = rd[3:0];
op_args.tcu.step_m = '0;
op_args.tcu.step_n = '0;
op_args.tcu.step_k = '0;
`USED_FREG (rd);
`USED_FREG (rs1);
`USED_FREG (rs2);
`USED_FREG (rs3);
end else if (funct3 == 3'h2) begin
ex_type = EX_TCU;
op_type = INST_OP_BITS'(INST_TCU_META_STORE);
op_args.tcu.fmt_d = rd[3:0]; // col_idx
op_args.tcu.fmt_s = '0;
op_args.tcu.step_m = '0;
op_args.tcu.step_n = '0;
op_args.tcu.step_k = '0;
`USED_FREG (rs1); // source float register
end
end
`endif
default:;
Expand Down
4 changes: 3 additions & 1 deletion hw/rtl/core/VX_uop_sequencer.sv
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,9 @@ module VX_uop_sequencer import

`ifdef EXT_TCU_ENABLE

assign is_base_uop_input = (input_if.data.ex_type == EX_TCU && input_if.data.op_type == INST_TCU_WMMA);
assign is_base_uop_input = (input_if.data.ex_type == EX_TCU
&& (input_if.data.op_type == INST_TCU_WMMA
|| input_if.data.op_type == INST_TCU_WMMA_SP));

VX_tcu_uops tcu_uops (
.clk (clk),
Expand Down
92 changes: 83 additions & 9 deletions hw/rtl/tcu/VX_tcu_core.sv
Original file line number Diff line number Diff line change
Expand Up @@ -56,17 +56,41 @@ module VX_tcu_core import VX_gpu_pkg::*, VX_tcu_pkg::*; #(
localparam PIPE_LATENCY = FEDP_LATENCY + 1;
localparam MDATA_QUEUE_DEPTH = 1 << $clog2(PIPE_LATENCY);

localparam LG_A_BS = $clog2(TCU_A_BLOCK_SIZE);
localparam LG_B_BS = $clog2(TCU_B_BLOCK_SIZE);
localparam OFF_W = $clog2(TCU_BLOCK_CAP);
localparam LG_A_BS = $clog2(TCU_A_BLOCK_SIZE);
localparam LG_B_BS = $clog2(TCU_B_BLOCK_SIZE);
localparam LG_B_BS_SP = $clog2(TCU_B_BLOCK_SIZE_SP);
localparam OFF_W = $clog2(TCU_BLOCK_CAP);

wire is_sparse = (execute_if.data.op_type == INST_TCU_WMMA_SP);
wire is_meta_store = (execute_if.data.op_type == INST_TCU_META_STORE);

wire [3:0] step_m = execute_if.data.op_args.tcu.step_m;
wire [3:0] step_n = execute_if.data.op_args.tcu.step_n;
wire [3:0] step_k = execute_if.data.op_args.tcu.step_k;

wire [3:0] fmt_s = execute_if.data.op_args.tcu.fmt_s;
wire [3:0] fmt_d = execute_if.data.op_args.tcu.fmt_d;

`UNUSED_VAR ({step_m, step_n, fmt_s, fmt_d, execute_if.data});
wire [`LOG2UP(`NUM_WARPS)-1:0] wid = execute_if.data.header.wid;

// meta_store: extract per-row write data from rs1_data lanes
localparam PER_WARP_DEPTH = TCU_M_STEPS * (TCU_K_STEPS / 2);
wire meta_wr_en = execute_fire && is_meta_store;
wire [PER_WARP_DEPTH-1:0][31:0] meta_wr_data;
for (genvar r = 0; r < PER_WARP_DEPTH; ++r) begin : g_meta_wr
assign meta_wr_data[r] = 32'(execute_if.data.rs1_data[r]);
end

// meta_store: force rd=0 in mdata_queue header (x0 write is harmless)
tcu_header_t mdata_queue_in;
always_comb begin
mdata_queue_in = execute_if.data.header;
if (is_meta_store) begin
mdata_queue_in.rd = '0;
end
end

`UNUSED_VAR ({step_m, step_n, step_k, fmt_s, fmt_d, execute_if.data});

wire mdata_queue_full;

Expand Down Expand Up @@ -103,7 +127,7 @@ module VX_tcu_core import VX_gpu_pkg::*, VX_tcu_pkg::*; #(
.reset (reset),
.push (execute_fire),
.pop (result_fire),
.data_in(execute_if.data.header),
.data_in(mdata_queue_in),
.data_out(result_if.data.header),
`UNUSED_PIN(empty),
`UNUSED_PIN(alm_empty),
Expand All @@ -113,18 +137,68 @@ module VX_tcu_core import VX_gpu_pkg::*, VX_tcu_pkg::*; #(
);

wire [OFF_W-1:0] a_off = (OFF_W'(step_m) & OFF_W'(TCU_A_SUB_BLOCKS-1)) << LG_A_BS;
wire [OFF_W-1:0] b_off = (OFF_W'(step_n) & OFF_W'(TCU_B_SUB_BLOCKS-1)) << LG_B_BS;
wire [OFF_W-1:0] b_off = is_sparse
? (OFF_W'(step_n) & OFF_W'(TCU_B_SUB_BLOCKS_SP-1)) << LG_B_BS_SP
: (OFF_W'(step_n) & OFF_W'(TCU_B_SUB_BLOCKS-1)) << LG_B_BS;

wire [TCU_TC_M-1:0][TCU_TC_N-1:0][31:0] d_val;

// 2:4 sparsity metadata
`ifndef TCU_ITYPE_BITS
`define TCU_ITYPE_BITS 8
`endif
localparam I_RATIO = 32 / `TCU_ITYPE_BITS; // Elements per 32-bit word
localparam META_BLOCK_WIDTH = TCU_NT * 2 * I_RATIO;
localparam META_ROW_WIDTH = TCU_TC_K * 2 * I_RATIO;
localparam ELT_W = 32 / I_RATIO; // bits per element (8 for int8)
wire [META_BLOCK_WIDTH-1:0] vld_meta_block;

VX_tcu_meta #(
.INSTANCE_ID (INSTANCE_ID),
.META_BLOCK_WIDTH(META_BLOCK_WIDTH),
.PER_WARP_DEPTH (PER_WARP_DEPTH)
) tcu_meta (
.clk (clk),
.reset (reset),
.raddr_wid (wid),
.step_m (step_m),
.step_k (step_k),
.vld_meta_block(vld_meta_block),
.wr_en (meta_wr_en),
.wr_wid (wid),
.wr_col_idx (fmt_d),
.wr_data (meta_wr_data)
);

for (genvar i = 0; i < TCU_TC_M; ++i) begin : g_i
for (genvar j = 0; j < TCU_TC_N; ++j) begin : g_j
wire [TCU_TC_K-1:0][31:0] a_row, b_col;
wire [TCU_TC_K-1:0][31:0] a_row, b_col, b_col_dense, b_col_sparse, b_col_1, b_col_2;
for (genvar k_idx = 0; k_idx < TCU_TC_K; ++k_idx) begin : g_slice_assign
assign a_row[k_idx] = 32'(execute_if.data.rs1_data[a_off + i * TCU_TC_K + k_idx]);
assign b_col[k_idx] = 32'(execute_if.data.rs2_data[b_off + j * TCU_TC_K + k_idx]);
assign a_row[k_idx] = 32'(execute_if.data.rs1_data[a_off + i * TCU_TC_K + k_idx]);
assign b_col_dense[k_idx] = 32'(execute_if.data.rs2_data[b_off + j * TCU_TC_K + k_idx]);
assign b_col_1[k_idx] = 32'(execute_if.data.rs2_data[b_off + j * TCU_TC_K * 2 + k_idx * 2]);
assign b_col_2[k_idx] = 32'(execute_if.data.rs2_data[b_off + j * TCU_TC_K * 2 + k_idx * 2 + 1]);
end
wire [31:0] c_val = 32'(execute_if.data.rs3_data[i * TCU_TC_N + j]);
/* verilator lint_off UNUSEDSIGNAL */
wire [TCU_MAX_INPUTS-1:0] vld_mask = '1; // TODO: should connect to input source
/* verilator lint_on UNUSEDSIGNAL */
wire [META_ROW_WIDTH-1:0] vld_meta_row = vld_meta_block[META_ROW_WIDTH*i +: META_ROW_WIDTH];

VX_tcu_sel #(
.INSTANCE_ID (INSTANCE_ID),
.META_ROW_WIDTH (META_ROW_WIDTH),
.I_RATIO (I_RATIO),
.ELT_W (ELT_W)
) tcu_sel (
.b_col_1 (b_col_1),
.b_col_2 (b_col_2),
.vld_meta_row (vld_meta_row),
.b_col (b_col_sparse)
);

// Select dense or sparse B column
assign b_col = is_sparse ? b_col_sparse : b_col_dense;

wire [3:0] fmt_s_r, fmt_d_r;
wire [TCU_TC_K-1:0][31:0] a_row_r, b_col_r;
Expand Down
100 changes: 100 additions & 0 deletions hw/rtl/tcu/VX_tcu_meta.sv
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
// Copyright 2019-2023
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

`include "VX_define.vh"

/* verilator lint_off UNUSEDSIGNAL */

module VX_tcu_meta import VX_gpu_pkg::*, VX_tcu_pkg::*; #(
parameter `STRING INSTANCE_ID = "",
parameter META_BLOCK_WIDTH = 64,
parameter PER_WARP_DEPTH = 4
) (
input wire clk,
input wire reset,

// Read port (from FEDP path)
input wire [`LOG2UP(`NUM_WARPS)-1:0] raddr_wid,
input wire [3:0] step_m,
input wire [3:0] step_k,
output wire [META_BLOCK_WIDTH-1:0] vld_meta_block,

// Write port (meta_store instruction)
input wire wr_en,
input wire [`LOG2UP(`NUM_WARPS)-1:0] wr_wid,
input wire [3:0] wr_col_idx,
input wire [PER_WARP_DEPTH-1:0][31:0] wr_data
);
`UNUSED_SPARAM (INSTANCE_ID)

// Local parameters
localparam HALF_K_STEPS = TCU_K_STEPS / 2;
localparam TOTAL_DEPTH = `NUM_WARPS * PER_WARP_DEPTH;
localparam ADDRW = `CLOG2(TOTAL_DEPTH);
localparam ADDRW_PW = `CLOG2(PER_WARP_DEPTH);
localparam NUM_COLS = META_BLOCK_WIDTH / 32;

// Metadata register array (per-warp partitioned)
reg [META_BLOCK_WIDTH-1:0] meta_mem [0:TOTAL_DEPTH-1];

// Read address: bit-concatenation of step_m and step_k (pure wire routing, zero delay)
// Use generate-if to avoid zero-width bit-selects when a dimension has only 1 step
localparam M_STEP_BITS = `CLOG2(TCU_M_STEPS);
localparam K_STEP_BITS = `CLOG2(HALF_K_STEPS);

wire [ADDRW_PW-1:0] per_warp_raddr;
generate
if (K_STEP_BITS > 0 && M_STEP_BITS > 0) begin : g_addr_mk
assign per_warp_raddr = {step_m[M_STEP_BITS-1:0], step_k[K_STEP_BITS-1:0]};
end else if (K_STEP_BITS > 0) begin : g_addr_k
assign per_warp_raddr = step_k[K_STEP_BITS-1:0];
end else if (M_STEP_BITS > 0) begin : g_addr_m
assign per_warp_raddr = step_m[M_STEP_BITS-1:0];
end else begin : g_addr_zero
assign per_warp_raddr = '0;
end
endgenerate
wire [ADDRW-1:0] read_addr = {raddr_wid, per_warp_raddr};

// Combinational read
assign vld_meta_block = meta_mem[read_addr];

// Post-reset init counter: fills all warps with alternating patterns
reg [ADDRW:0] init_counter;
wire init_active = ~init_counter[ADDRW];
wire [ADDRW-1:0] init_addr = init_counter[ADDRW-1:0];
wire [META_BLOCK_WIDTH-1:0] init_data = init_addr[0] ?
{(META_BLOCK_WIDTH/4){4'b1010}} :
{(META_BLOCK_WIDTH/4){4'b0101}};

// Write logic: init or runtime meta_store
always_ff @(posedge clk) begin
if (reset) begin
init_counter <= 0;
end else if (init_active) begin
meta_mem[init_addr] <= init_data;
init_counter <= init_counter + 1;
end else if (wr_en) begin
for (int row = 0; row < PER_WARP_DEPTH; row++) begin
for (int col = 0; col < NUM_COLS; col++) begin
if (col == int'(wr_col_idx)) begin
meta_mem[{wr_wid, ADDRW_PW'(row)}][col*32 +: 32] <= wr_data[row];
end
end
end
end
end

endmodule

/* verilator lint_on UNUSEDSIGNAL */
14 changes: 11 additions & 3 deletions hw/rtl/tcu/VX_tcu_pkg.sv
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,14 @@ package VX_tcu_pkg;
localparam TCU_A_BLOCK_SIZE = TCU_TC_M * TCU_TC_K;
localparam TCU_A_SUB_BLOCKS = TCU_BLOCK_CAP / TCU_A_BLOCK_SIZE;

// B micro-tiling
// B micro-tiling (dense)
localparam TCU_B_BLOCK_SIZE = TCU_TC_K * TCU_TC_N;
localparam TCU_B_SUB_BLOCKS = TCU_BLOCK_CAP / TCU_B_BLOCK_SIZE;

// B micro-tiling (sparse 2:4)
localparam TCU_B_BLOCK_SIZE_SP = (TCU_TC_K * TCU_TC_N) * 2;
localparam TCU_B_SUB_BLOCKS_SP = TCU_BLOCK_CAP / TCU_B_BLOCK_SIZE_SP;

// Register counts
//localparam TCU_NRA = (TCU_TILE_M * TCU_TILE_K) / TCU_NT;
localparam TCU_NRB = (TCU_TILE_N * TCU_TILE_K) / TCU_NT;
Expand Down Expand Up @@ -191,13 +195,17 @@ package VX_tcu_pkg;
input op_args_t op_args
);
case (INST_TCU_BITS'(op_type))
INST_TCU_WMMA: begin
`TRACE(level, ("WMMA."));
INST_TCU_WMMA,
INST_TCU_WMMA_SP: begin
`TRACE(level, (INST_TCU_BITS'(op_type) == INST_TCU_WMMA_SP ? "WMMA_SP." : "WMMA."));
trace_fmt(level, op_args.tcu.fmt_s);
`TRACE(level, ("."));
trace_fmt(level, op_args.tcu.fmt_d);
`TRACE(level, (".%0d.%0d", op_args.tcu.step_m, op_args.tcu.step_n));
end
INST_TCU_META_STORE: begin
`TRACE(level, ("META_STORE.col%0d", op_args.tcu.fmt_d));
end
default: `TRACE(level, ("?"))
endcase
endtask
Expand Down
Loading
Loading