9 changes: 9 additions & 0 deletions docs/wave/shared_memory.rst
@@ -7,6 +7,15 @@ The goal is to assign a starting memory offset to each allocation such that the

The heuristic provides a fast, approximate solution to this problem. It does not guarantee optimality but often performs well in practice.

Performance Implications & Usage Guidelines
-------------------------------------------

This optimization works by merging multiple distinct ``memref.alloc`` operations into a single large allocation with views, as sketched below. While this can reduce peak shared memory usage, it has a significant side effect: **it obscures aliasing information.**

* **The Mechanism:** The ``AMDGPULowerModuleLDS`` pass relies on distinct allocations to generate precise aliasing metadata. When allocations are merged, this metadata is lost.
* **The Consequence:** Without precise aliasing info, the ``SIInsertWaitcnts`` pass conservatively inserts synchronization barriers (e.g., ``s_waitcnt vmcnt(0)``) to prevent potential data hazards. This breaks software pipelining by preventing the intended overlap of global memory loads with computation.
* **When to use:** This optimization is beneficial for kernels with **disjoint buffer lifetimes** (e.g., Extended Attention), where it allows shared memory buffers to be reused over time.
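
For illustration, the two forms look roughly as follows, using the buffer shapes from the extend-attention lit test (a sketch only; value names are hypothetical)::

    // Merged: one backing i8 buffer; the logical buffers become views at byte
    // offsets (4352 = 32 * 1 * 68 elements * 2 bytes for the first f16 buffer).
    %c0 = arith.constant 0 : index
    %c4352 = arith.constant 4352 : index
    %base = memref.alloc() : memref<8704xi8, #gpu.address_space<workgroup>>
    %k_view = memref.view %base[%c0][] : memref<8704xi8, #gpu.address_space<workgroup>>
        to memref<32x1x68xf16, #gpu.address_space<workgroup>>
    %v_view = memref.view %base[%c4352][] : memref<8704xi8, #gpu.address_space<workgroup>>
        to memref<1x32x68xf16, #gpu.address_space<workgroup>>

    // Unmerged: distinct allocations whose accesses the backend can prove disjoint.
    %k_alloc = memref.alloc() : memref<32x1x68xf16, #gpu.address_space<workgroup>>
    %v_alloc = memref.alloc() : memref<1x32x68xf16, #gpu.address_space<workgroup>>

In the merged form the backend only sees a single ``i8`` buffer, so it can no longer prove that accesses through the two views are disjoint.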

The Allocation Data
--------------------

23 changes: 9 additions & 14 deletions lit_tests/kernel/wave/attention/extend_attention.py
@@ -76,11 +76,8 @@ def test_extend_attention():
# CHECK-DAG: memref.reinterpret_cast %{{.*}} to offset: [0], sizes: [%{{.*}}, 16, 64], strides: [1024, 64, 1] : memref<f16> to memref<?x16x64xf16, strided<[1024, 64, 1]>>
# CHECK-DAG: memref.reinterpret_cast %{{.*}} to offset: [0], sizes: [%{{.*}}, 4, 64], strides: [256, 64, 1] : memref<f16> to memref<?x4x64xf16, strided<[256, 64, 1]>>
# CHECK-DAG: memref.reinterpret_cast %{{.*}} to offset: [0], sizes: [%{{.*}}, 4, 64], strides: [256, 64, 1] : memref<f16> to memref<?x4x64xf16, strided<[256, 64, 1]>>
# CHECK-DAG: %[[C4352:.*]] = arith.constant 4352 : index
# CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
# CHECK-DAG: %[[ALLOC0:.*]] = memref.alloc() : memref<8704xi8, #gpu.address_space<workgroup>>
# CHECK-DAG: %[[ALLOC1:.*]] = memref.view %[[ALLOC0]][%[[C0]]][] : memref<8704xi8, #gpu.address_space<workgroup>> to memref<32x1x68xf16, #gpu.address_space<workgroup>>
# CHECK-DAG: %[[ALLOC2:.*]] = memref.view %[[ALLOC0]][%[[C4352]]][] : memref<8704xi8, #gpu.address_space<workgroup>> to memref<1x32x68xf16, #gpu.address_space<workgroup>>
# CHECK-DAG: %[[ALLOC1:.*]] = memref.alloc() : memref<32x1x68xf16, #gpu.address_space<workgroup>>
# CHECK-DAG: %[[ALLOC2:.*]] = memref.alloc() : memref<1x32x68xf16, #gpu.address_space<workgroup>>
# CHECK-COUNT-4: vector.maskedload
# CHECK: scf.for
# 3 masked loads for sequence idx, 2 for k_cache, and 1 for v_cache.
@@ -102,9 +99,9 @@ def test_extend_attention():
# CHECK-NOT: amdgpu.lds_barrier
# CHECK: scf.for
# CHECK-COUNT-1: vector.maskedload
# CHECK-COUNT-1: vector.store %{{.*}}, %[[ALLOC2]]
# CHECK-COUNT-1: vector.store %{{.*}}, {{.*}}memref<1x32x68xf16
# CHECK-COUNT-32: memref.load %{{.*}}
# CHECK-COUNT-8: vector.load %[[ALLOC2]]
# CHECK-COUNT-8: vector.load {{.*}}memref<1x32x68xf16
# CHECK-COUNT-8: amdgpu.mfma
# CHECK-COUNT-2: arith.cmpi slt
# CHECK-COUNT-2: arith.select
@@ -171,11 +168,8 @@ def test_causal_extend_attention():
# CHECK-DAG: memref.reinterpret_cast %{{.*}} to offset: [0], sizes: [%{{.*}}, 16, 64], strides: [1024, 64, 1] : memref<f16> to memref<?x16x64xf16, strided<[1024, 64, 1]>>
# CHECK-DAG: memref.reinterpret_cast %{{.*}} to offset: [0], sizes: [%{{.*}}, 4, 64], strides: [256, 64, 1] : memref<f16> to memref<?x4x64xf16, strided<[256, 64, 1]>>
# CHECK-DAG: memref.reinterpret_cast %{{.*}} to offset: [0], sizes: [%{{.*}}, 4, 64], strides: [256, 64, 1] : memref<f16> to memref<?x4x64xf16, strided<[256, 64, 1]>>
# CHECK-DAG: %[[C4352:.*]] = arith.constant 4352 : index
# CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
# CHECK-DAG: %[[ALLOC0:.*]] = memref.alloc() : memref<8704xi8, #gpu.address_space<workgroup>>
# CHECK-DAG: %[[ALLOC1:.*]] = memref.view %[[ALLOC0]][%[[C0]]][] : memref<8704xi8, #gpu.address_space<workgroup>> to memref<32x1x68xf16, #gpu.address_space<workgroup>>
# CHECK-DAG: %[[ALLOC2:.*]] = memref.view %[[ALLOC0]][%[[C4352]]][] : memref<8704xi8, #gpu.address_space<workgroup>> to memref<1x32x68xf16, #gpu.address_space<workgroup>>
# CHECK-DAG: %[[ALLOC1:.*]] = memref.alloc() : memref<32x1x68xf16, #gpu.address_space<workgroup>>
# CHECK-DAG: %[[ALLOC2:.*]] = memref.alloc() : memref<1x32x68xf16, #gpu.address_space<workgroup>>
# CHECK-COUNT-4: vector.maskedload
# CHECK: scf.for
# 3 masked loads for sequence idx, 2 for k_cache, and 1 for v_cache.
@@ -223,9 +217,10 @@ def test_causal_extend_attention():

# CHECK: scf.for
# CHECK-COUNT-1: vector.maskedload
# CHECK-COUNT-1: vector.store %{{.*}}, %[[ALLOC2]]
# CHECK-COUNT-1: vector.store %{{.*}}, {{.*}}memref<1x32x68xf16
# CHECK-COUNT-1: vector.store %{{.*}}, {{.*}}memref<32x1x68xf16
# CHECK-COUNT-32: memref.load %{{.*}}
# CHECK-COUNT-8: vector.load %[[ALLOC2]]
# CHECK-COUNT-8: vector.load {{.*}}memref<1x32x68xf16
# CHECK-COUNT-8: amdgpu.mfma

# softcap/logitcap modifier:
29 changes: 14 additions & 15 deletions lit_tests/kernel/wave/barriers.py
@@ -407,31 +407,30 @@ def test_bshd_attention_pipelined():
print(base_attention.asm)

# CHECK-LABEL: func.func @base_attention
# CHECK-DAG: memref.alloc()
# CHECK-DAG: %[[V0:.*]] = memref.view {{.*}} #gpu.address_space<workgroup>>
# CHECK-DAG: %[[V1:.*]] = memref.view {{.*}} #gpu.address_space<workgroup>>
# CHECK-DAG: %[[ALLOC0:.*]] = memref.alloc(){{.*}} #gpu.address_space<workgroup>>
# CHECK-DAG: %[[ALLOC1:.*]] = memref.alloc(){{.*}} #gpu.address_space<workgroup>>
# CHECK-DAG: vector.load

# CHECK-DAG: vector.store {{.*}} %[[V1]]{{.*}} #gpu.address_space<workgroup>>
# CHECK-DAG: vector.store {{.*}} %[[ALLOC1]]{{.*}} #gpu.address_space<workgroup>>
# CHECK-DAG: rocdl.s.wait.dscnt 0
# CHECK-DAG: rocdl.s.barrier.signal -1

# CHECK-DAG: vector.extract

# CHECK-DAG: rocdl.s.barrier.wait -1
# CHECK-DAG: vector.load %[[V1]]{{.*}} #gpu.address_space<workgroup>>
# CHECK-DAG: vector.load %[[ALLOC1]]{{.*}} #gpu.address_space<workgroup>>

### loads and stores are operating on different parts of shared buffers -> no barriers need to be inserted here.

# CHECK-DAG: vector.store {{.*}} %[[V0]]{{.*}} #gpu.address_space<workgroup>>
# CHECK-DAG: vector.store {{.*}} %[[ALLOC0]]{{.*}} #gpu.address_space<workgroup>>
# CHECK-DAG: rocdl.s.wait.dscnt 0
# CHECK-DAG: rocdl.s.barrier.signal -1
# CHECK-DAG: rocdl.s.barrier.signal id = -1

# CHECK-DAG: amdgpu.wmma

# CHECK-DAG: rocdl.s.barrier.wait -1
# CHECK-DAG: vector.load %[[V0]]{{.*}} #gpu.address_space<workgroup>>
# CHECK-DAG: vector.load %[[V1]]{{.*}} #gpu.address_space<workgroup>>
# CHECK-DAG: vector.load %[[ALLOC0]]{{.*}} #gpu.address_space<workgroup>>
# CHECK-DAG: vector.load %[[ALLOC1]]{{.*}} #gpu.address_space<workgroup>>
# CHECK-DAG: amdgpu.wmma

# CHECK: scf.for
@@ -441,7 +440,7 @@ def test_bshd_attention_pipelined():
# CHECK-DAG: rocdl.s.wait.dscnt 0
# CHECK-DAG: rocdl.s.barrier.signal -1
# CHECK-DAG: rocdl.s.barrier.wait -1
# CHECK-DAG: vector.store {{.*}} %[[V1]]{{.*}} #gpu.address_space<workgroup>>
# CHECK-DAG: vector.store {{.*}} %[[ALLOC1]]{{.*}} #gpu.address_space<workgroup>>

### signal write to buffer 1 completes.

@@ -451,8 +450,8 @@ def test_bshd_attention_pipelined():
# CHECK-DAG: vector.load

# CHECK-DAG: rocdl.s.barrier.wait -1
# CHECK-DAG: vector.load %[[V1]]{{.*}} #gpu.address_space<workgroup>>
# CHECK-DAG: vector.store {{.*}} %[[V0]]{{.*}} #gpu.address_space<workgroup>>
# CHECK-DAG: vector.load %[[ALLOC1]]{{.*}} #gpu.address_space<workgroup>>
# CHECK-DAG: vector.store {{.*}} %[[ALLOC0]]{{.*}} #gpu.address_space<workgroup>>

### signal here represents 2 things: read from buffer 1 completes, write to buffer 0 completes.

@@ -464,10 +463,10 @@ def test_bshd_attention_pipelined():
### the wait here then waits for the read from buffer 1 and the write to buffer 0 to complete.

# CHECK-DAG: rocdl.s.barrier.wait -1
# CHECK-DAG: vector.load %[[V0]]{{.*}} #gpu.address_space<workgroup>>
# CHECK-DAG: vector.load %[[V1]]{{.*}} #gpu.address_space<workgroup>>
# CHECK-DAG: vector.load %[[ALLOC0]]{{.*}} #gpu.address_space<workgroup>>
# CHECK-DAG: vector.load %[[ALLOC1]]{{.*}} #gpu.address_space<workgroup>>
# CHECK-DAG: amdgpu.wmma
# CHECK-DAG: vector.load %[[V0]]{{.*}} #gpu.address_space<workgroup>>
# CHECK-DAG: vector.load %[[ALLOC0]]{{.*}} #gpu.address_space<workgroup>>

# CHECK-DAG: arith.maximumf
# CHECK-DAG: amdgpu.wmma
9 changes: 5 additions & 4 deletions lit_tests/kernel/wave/codegen.py
@@ -1901,7 +1901,6 @@ def test_block_reduce_sum(
# CHECK-DAG: #[[map2:.+]] = affine_map<()[s0] -> ((s0 floordiv 64) mod 4)>
# CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index
# CHECK: %[[tid:.+]] = gpu.thread_id x
# CHECK: %[[alloc:.+]] = memref.alloc() : memref<8xi8, #gpu.address_space<workgroup>>

# Local Reduce
# CHECK-COUNT-2: vector.extract
@@ -1911,16 +1910,18 @@ def test_block_reduce_sum(
# CHECK-COUNT-6: gpu.shuffle xor
# CHECK-NEXT: %[[global_reduce:.+]] = arith.addf

# Allocate shared memory for partial wave results
# CHECK: %[[alloc:.+]] = memref.alloc() : memref<4xf16, #gpu.address_space<workgroup>>

# Write partial wave result into shared memory to be accessible by other waves.
# CHECK: %[[view:.+]] = memref.view %[[alloc]][%[[c0]]][] : memref<8xi8, #gpu.address_space<workgroup>> to memref<4xf16, #gpu.address_space<workgroup>>
# CHECK: scf.if {{.*}} {
# CHECK: %[[wave_id:.+]] = affine.apply #[[map2]]()[%[[tid]]]
# CHECK: vector.store %[[global_reduce]], %[[view]][%[[wave_id]]] : memref<4xf16, #gpu.address_space<workgroup>>, vector<1xf16>
# CHECK: vector.store %[[global_reduce]], %[[alloc]][%[[wave_id]]] : memref<4xf16, #gpu.address_space<workgroup>>, vector<1xf16>
# CHECK: }
# CHECK-NEXT: amdgpu.lds_barrier

# Get all partial wave results and locally reduce
# CHECK: %[[wave_res:.+]] = vector.load %[[view]]
# CHECK: %[[wave_res:.+]] = vector.load %[[alloc]]
# CHECK: vector.extract %[[wave_res]][0] : f16 from vector<4xf16>
# CHECK: vector.extract %[[wave_res]][1] : f16 from vector<4xf16>
# CHECK-NEXT: arith.addf