
Add Fused Multi-Head Attention example #16

Open
AntonOresten wants to merge 10 commits into JuliaGPU:main from AntonOresten:fmha

Conversation

@AntonOresten
Contributor

@AntonOresten AntonOresten commented Jan 10, 2026

See outdated

Seems to fall slightly short of my NNop / ONIONop baseline (no WMMA), although I haven't compared it to the Python version. On my GPU, it compiles and runs fastest with tile_n=32 and tile_m=32:

julia> begin
           T = Float32
           D, QL, KL, H, B = 64, 4096, 4096, 4, 4
           q = CUDA.randn(T, D, QL, H, B)
           k = CUDA.randn(T, D, KL, H, B)  
           v = CUDA.randn(T, D, KL, H, B)
       end;

julia> @b CUDA.@sync ONIONop.flash_attention(q, k, v, causal=false)
9.559 ms (339 allocs: 7.875 KiB)

julia> @b CUDA.@sync cutile_fmha(q, k, v, causal=false, tile_m=32, tile_n=32)
11.058 ms (540 allocs: 23.109 KiB)

EDIT: this is without tensor cores. Simply switching the compute type to TFloat32 / BFloat16 and exploring the optimization and entry-hint landscape makes the forward and backward passes ~10x faster.

Notably, cutile-python has a latency argument for ct.load, as well as num_ctas and occupancy arguments for the kernel, which might affect performance. The Python version also autotunes the kernel config by searching a space of hand-picked configurations. EDIT: fixed in #32 and #27.
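
For illustration, a rough sketch of what a hand-rolled autotune could look like on the Julia side (the configuration list here is made up, and it only searches tile sizes since the other knobs aren't exposed; cutile_fmha is the kernel wrapper from this PR):

function autotune_fmha(q, k, v; causal=false,
                       configs=[(32, 32), (64, 32), (64, 64), (128, 32)])
    # Time a few hand-picked (tile_m, tile_n) configurations and keep the fastest.
    best = (time=Inf, tile_m=0, tile_n=0)
    for (tile_m, tile_n) in configs
        cutile_fmha(q, k, v; causal, tile_m, tile_n)  # warm-up / compile
        t = @elapsed CUDA.@sync cutile_fmha(q, k, v; causal, tile_m, tile_n)
        t < best.time && (best = (time=t, tile_m=tile_m, tile_n=tile_n))
    end
    return best
end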

Another thing that might matter for correctness or edge cases is exposing flush_to_zero, which is used in e.g. exp2. I haven't thought about in which cases this matters.

@AntonOresten
Contributor Author

AntonOresten commented Jan 17, 2026

I'm seeing some weird errors when branching (being fixed in #53):

Click to see snippets

This works:

        qk = if !EVEN_K[] && j >= mask_start
            offs_n = ((j-Int32(1)) * TILE_N[]) .+ offs_n_tile
            mask = ct.full((TILE_N[], TILE_M[]), true, Bool)
            mask = mask .& (offs_n .<= k_seqlen)
            mask = ct.where(mask, ct.zeros((TILE_N[], TILE_M[],), Float32), ct.full((TILE_N[], TILE_M[],), -Inf32, Float32))
            qk .+ mask
        else
            qk
        end

but this doesn't:

        if !EVEN_K[] && j >= mask_start
            offs_n = ((j-Int32(1)) * TILE_N[]) .+ offs_n_tile
            mask = ct.full((TILE_N[], TILE_M[]), true, Bool)
            mask = mask .& (offs_n .<= k_seqlen)
            mask = ct.where(mask, ct.zeros((TILE_N[], TILE_M[],), Float32), ct.full((TILE_N[], TILE_M[],), -Inf32, Float32))
            qk = qk .+ mask
        end

nor does this:

        qk = if !EVEN_K[] && j >= mask_start
            offs_n = ((j-Int32(1)) * TILE_N[]) .+ offs_n_tile
            mask = ct.full((TILE_N[], TILE_M[]), true, Bool)
            if !EVEN_K[]
                mask .& (offs_n .<= k_seqlen)
            end
            mask = ct.where(mask, ct.zeros((TILE_N[], TILE_M[],), Float32), ct.full((TILE_N[], TILE_M[],), -Inf32, Float32))
            qk .+ mask
        else
            qk
        end

In the second and third block, I get "ERROR: SSAValue %___ not found in context"

After removing the second condition, I can suddenly have a nested if block, and I don't need the outer else block:

        if !EVEN_K[]
            offs_n = ((j-Int32(1)) * TILE_N[]) .+ offs_n_tile
            mask = ct.full((TILE_N[], TILE_M[]), true, Bool)
            if !EVEN_K[]
                mask = mask .& (offs_n .<= k_seqlen)
            end
            mask = ct.where(mask, ct.zeros((TILE_N[], TILE_M[],), Float32), ct.full((TILE_N[], TILE_M[],), -Inf32, Float32))
            qk = qk .+ mask
        end

Does the if block need to depend on compile-time constants?

I'd need this to implement the padding and causal masks properly.

@maleadt
Member

maleadt commented Jan 19, 2026

In the second and third block, I get "ERROR: SSAValue %___ not found in context"

That's an IRStructurizer error. Can you provide an MWE?

@AntonOresten
Contributor Author

That's an IRStructurizer error. Can you provide an MWE?

I was able to narrow it down and believe it is covered by #53. See the added tests for an MWE.

With #51 and #53, I now have forward and backward passes working locally!

@AntonOresten AntonOresten marked this pull request as ready for review February 5, 2026 18:28
@AntonOresten
Contributor Author

AntonOresten commented Feb 5, 2026

I currently need to wrap Float32 constants defined outside the kernel in Float32 inside the kernel, because MulF otherwise sees them as nothing:

qk_scale = Float32(qk_scale) * Float32(INV_LOG_2)

@AntonOresten
Contributor Author

Another concern is whether I should convert to Int32 or not, essentially every time I do index arithmetic.

@maleadt
Member

maleadt commented Feb 6, 2026

I currently need to wrap Float32 constants defined outside the kernel in Float32 inside the kernel, because MulF otherwise sees them as nothing:

qk_scale = Float32(qk_scale) * Float32(INV_LOG_2)

Can you elaborate?

Another concern is whether I should convert to Int32 or not, essentially every time I do index arithmetic.

Yeah that's a common Julia pain point. It's why we have One(), and in CUDA.jl you can do e.g. 1i32.
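
For example (after using CUDA, the i32 literal suffix keeps index arithmetic 32-bit):

julia> using CUDA

julia> j = Int32(5);

julia> typeof((j - 1i32) * 32i32)
Int32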

@AntonOresten
Contributor Author

Can you elaborate?

I define const INV_LOG_2 = Float32(1 / log(2)). If I use it without wrapping it in Float32 inside the kernel, I get:

ERROR: LoadError: MethodError: no method matching encode_MulFOp!(::cuTile.CodeBuilder, ::cuTile.TypeId, ::cuTile.Value, ::Nothing)
The function `encode_MulFOp!` exists, but no method is defined for this combination of argument types.

Closest candidates are:
  encode_MulFOp!(::cuTile.CodeBuilder, ::cuTile.TypeId, ::cuTile.Value, ::cuTile.Value; rounding_mode, flush_to_zero)
   @ cuTile ~/.julia/dev/cuTile/src/bytecode/encodings.jl:720

in CUDA.jl you can do e.g. 1i32.

Oh, neat. I didn't know that. I considered a @32 macro (macro var"32" ... end) that wraps all 64-bit integer literals in their 32-bit counterparts. It wouldn't work inside curly brackets, i.e. for type parameters, since e.g. Array{T,Int32(1)} doesn't count as a vector, but the macro doesn't need to descend into :curly expressions. The remaining problem is that some functions only have methods for Int, so it can't be applied to an entire function body.
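
A rough sketch of what I had in mind (untested; the point is just skipping :curly expressions):

# Rewrite Int literals to Int32, but leave type parameters inside {...} alone.
to32(x::Int) = Int32(x)
to32(ex::Expr) = ex.head === :curly ? ex : Expr(ex.head, map(to32, ex.args)...)
to32(x) = x

macro var"32"(ex)
    esc(to32(ex))
end

So @var"32" a[i - 1] turns the literal 1 into Int32(1), while the Int inside e.g. Vector{Int} is left alone.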

@maleadt
Member

maleadt commented Feb 6, 2026

The remaining problem is that some functions only have methods for Int, so it can't be applied to an entire function body.

In general, Julia's array indexing requires Int. In CUDA.jl we've added some additional methods to override part of the getindex chain to support Int32, but it's tricky...
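
Concretely, Base.to_index (part of that getindex chain) just converts the index back to Int:

julia> Base.to_index(Int32(2))
2

julia> typeof(Base.to_index(Int32(2)))
Int64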

@maleadt
Member

maleadt commented Feb 8, 2026

Constants should work without the type conversion now.

@AntonOresten
Contributor Author

See #77 (comment)
