Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ jobs:
fail-fast: false
matrix:
version:
- '1.6'
- '1.9'
- '1'
# - 'nightly'
os:
Expand Down
7 changes: 4 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "FixedEffects"
uuid = "c8885935-8500-56a7-9867-7708b20db0eb"
version = "2.4.1"
version = "2.5.0"

[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand All @@ -11,7 +11,7 @@ GroupedArrays = "6407cd72-fade-4a84-8a1e-56e431fc1533"
[compat]
StatsBase = "0.33, 0.34"
GroupedArrays = "0.3"
julia = "1.6"
julia = "1.9"

[extensions]
CUDAExt = "CUDA"
Expand All @@ -24,10 +24,11 @@ Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
[extras]
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
PooledArrays = "2dfb63ee-cc39-5dd5-95bd-886bf059d720"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["CategoricalArrays", "CUDA", "Pkg", "PooledArrays", "Test"]
test = ["CategoricalArrays", "CUDA", "Metal", "Pkg", "PooledArrays", "Test"]

8 changes: 6 additions & 2 deletions benchmarks/benchmark_Metal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,14 @@ id2 = rand(1:K, N)
fes = [FixedEffect(id1), FixedEffect(id2)]
x = Float32.(rand(N))


# What takes time here is the second fixed effect, where K is very small, so a lot of threads want to write to the same location. In that case, it would probably be good to pre-compute a permutation for each fixed effect once, and then use as many groups as permutations, etc.


# simple problem
@time solve_residuals!(deepcopy(x), fes)
@time solve_residuals!(deepcopy(x), fes; double_precision = false)
# 0.654833 seconds (1.99 k allocations: 390.841 MiB, 3.71% gc time)
@time solve_residuals!(deepcopy(x), fes; method = :Metal)
@time solve_residuals!(deepcopy(x), fes; double_precision = false, method = :Metal)
# 0.298326 seconds (129.08 k allocations: 79.208 MiB)
@time solve_residuals!([x x x x], fes)
# 1.604061 seconds (1.25 M allocations: 416.364 MiB, 4.21% gc time, 30.57% compilation time)
Expand Down
26 changes: 17 additions & 9 deletions ext/CUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,12 @@ function FixedEffects.gather!(fecoef::CuVector, refs::CuVector, α::Number, y::C
end

function gather_kernel!(fecoef, refs, α, y, cache)
index = (blockIdx().x - 1) * blockDim().x + threadIdx().x
index = (blockIdx().x - Int32(1)) * blockDim().x + threadIdx().x
stride = blockDim().x * gridDim().x
@inbounds for i = index:stride:length(y)
i = index
@inbounds while i <= length(y)
CUDA.@atomic fecoef[refs[i]] += α * y[i] * cache[i]
i += stride
end
end

Expand All @@ -65,10 +67,12 @@ function FixedEffects.scatter!(y::CuVector, α::Number, fecoef::CuVector, refs::
end

function scatter_kernel!(y, α, fecoef, refs, cache)
index = (blockIdx().x - 1) * blockDim().x + threadIdx().x
index = (blockIdx().x - Int32(1)) * blockDim().x + threadIdx().x
stride = blockDim().x * gridDim().x
@inbounds for i = index:stride:length(y)
i = index
@inbounds while i <= length(y)
y[i] += α * fecoef[refs[i]] * cache[i]
i += stride
end
end

Expand Down Expand Up @@ -124,14 +128,16 @@ function scale!(scale::CuVector, refs::CuVector, interaction::CuVector, weights:
nblocks = cld(length(refs), nthreads)
fill!(scale, 0)
@cuda threads=nthreads blocks=nblocks scale_kernel!(scale, refs, interaction, weights)
map!(x -> x > 0 ? 1 / sqrt(x) : 0, scale, scale)
map!(x -> x > 0 ? 1 / sqrt(x) : zero(eltype(scale)), scale, scale)
end

function scale_kernel!(scale, refs, interaction, weights)
index = (blockIdx().x - 1) * blockDim().x + threadIdx().x
index = (blockIdx().x - Int32(1)) * blockDim().x + threadIdx().x
stride = blockDim().x * gridDim().x
@inbounds for i = index:stride:length(interaction)
i = index
@inbounds while i <= length(interaction)
CUDA.@atomic scale[refs[i]] += abs2(interaction[i]) * weights[i]
i += stride
end
end

Expand All @@ -141,10 +147,12 @@ function cache!(cache::CuVector, refs::CuVector, interaction::CuVector, weights:
end

function cache!_kernel!(cache, refs, interaction, weights, scale)
index = (blockIdx().x - 1) * blockDim().x + threadIdx().x
index = (blockIdx().x - Int32(1)) * blockDim().x + threadIdx().x
stride = blockDim().x * gridDim().x
@inbounds for i = index:stride:length(cache)
i = index
@inbounds while i <= length(cache)
cache[i] = interaction[i] * sqrt(weights[i]) * scale[refs[i]]
i += stride
end
end

Expand Down
112 changes: 93 additions & 19 deletions ext/MetalExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,39 +34,113 @@ _mtl(T::Type, w::AbstractVector) = MtlVector{T}(convert(Vector{T}, w))
mutable struct FixedEffectLinearMapMetal{T} <: AbstractFixedEffectLinearMap{T}
fes::Vector{<:FixedEffect}
scales::Vector{<:AbstractVector}
caches::Vector{<:AbstractVector}
caches::Vector
nthreads::Int
end

# Build the cache tuple `(cache, perm, offsets)` for one fixed effect.
#
# - `cache`   : zero-filled Metal vector with eltype `T` and `length(refs)` entries.
# - `perm`    : when bucketing is enabled, the indices of `refs` grouped by their
#               value (a counting sort, preserving the original order in a group).
# - `offsets` : when bucketing is enabled, `K + 1` entries such that
#               `perm[offsets[k]:(offsets[k + 1] - 1)]` lists the indices `i`
#               with `refs[i] == k`.
#
# Bucketing is enabled only when there are fewer than 100_000 groups holding at
# least 16 observations each on average; otherwise `perm` and `offsets` are
# 1-element placeholders. Entries of `refs` are assumed to lie in `1:K` (they
# are used as indices under `@inbounds`).
function bucketize_refs(refs::Vector, K::Int, T)
    N = length(refs)
    # `K > 0` keeps the integer division below from throwing on an empty effect.
    if 0 < K < 100_000 && N ÷ K >= 16
        # Number of observations falling in each group.
        counts = zeros(Int, K)
        @inbounds for r in refs
            counts[r] += 1
        end
        # Exclusive prefix sum (1-based): start position of each bucket in `perm`.
        offsets = Vector{Int}(undef, K + 1)
        offsets[1] = 1
        @inbounds for k in 1:K
            offsets[k + 1] = offsets[k] + counts[k]
        end
        # Per-bucket write heads; slicing already yields an independent copy.
        next = offsets[1:K]
        # Filled directly as Int32 (the type uploaded to the device); the
        # conversion on assignment still throws if an index does not fit.
        perm = Vector{Int32}(undef, N)
        @inbounds for i in eachindex(refs)
            r = refs[i]
            p = next[r]
            perm[p] = i
            next[r] = p + 1
        end
        return Metal.zeros(T, N), MtlArray(perm), MtlArray(Int32.(offsets))
    else
        return Metal.zeros(T, N), Metal.zeros(Int32, 1), Metal.zeros(Int32, 1)
    end
end

function FixedEffectLinearMapMetal{T}(fes::Vector{<:FixedEffect}, nthreads) where {T}
fes = [_mtl(T, fe) for fe in fes]
fes2 = [_mtl(T, fe) for fe in fes]
scales = [Metal.zeros(T, fe.n) for fe in fes]
caches = [Metal.zeros(T, length(fes[1].interaction)) for fe in fes]
return FixedEffectLinearMapMetal{T}(fes, scales, caches, nthreads)
caches = [bucketize_refs(fe.refs, fe.n, T) for fe in fes]
return FixedEffectLinearMapMetal{T}(fes2, scales, caches, nthreads)
end

function FixedEffects.gather!(fecoef::MtlVector, refs::MtlVector, α::Number, y::MtlVector, cache::MtlVector, nthreads::Integer)
nblocks = cld(length(y), nthreads)
Metal.@sync @metal threads=nthreads groups=nblocks gather_kernel!(fecoef, refs, α, y, cache)
# Accumulate `α * y[i] * cache[1][i]` into `fecoef[refs[i]]` on the Metal device.
#
# `cache` is indexed as `(values, perm, offsets)`. When `offsets` holds
# `length(fecoef) + 1` entries the references were bucketized, and the binned
# kernel is launched with one threadgroup per coefficient. Otherwise `perm` and
# `offsets` are placeholders and the atomic kernel is launched with one thread
# per observation.
function FixedEffects.gather!(fecoef::MtlVector, refs::MtlVector, α::Number, y::MtlVector, cache, nthreads::Integer)
    K = length(fecoef)
    # Decide from the cache itself rather than re-deriving the bucketing
    # heuristic: this cannot drift out of sync with how the cache was built and
    # avoids an integer division by zero when `K == 0`.
    if K > 0 && length(cache[3]) == K + 1
        Metal.@sync @metal threads=nthreads groups=K gather_kernel_bin!(fecoef, refs, α, y, cache[1], cache[2], cache[3], Val(nthreads))
    else
        nblocks = cld(length(y), nthreads)
        Metal.@sync @metal threads=nthreads groups=nblocks gather_kernel!(fecoef, refs, α, y, cache[1])
    end
end

# Metal kernel: threadgroup `k` sums `α * y[i] * cache[i]` over the indices `i`
# listed in `perm[offsets[k]:(offsets[k + 1] - 1)]` and adds the total to
# `fecoef[k]`.
#
# `NT` is the threadgroup size, passed as a type parameter so it can size the
# threadgroup scratch array. `refs` is not read here.
function gather_kernel_bin!(fecoef, refs, α, y, cache, perm, offsets, ::Val{NT}) where {NT}
    k = threadgroup_position_in_grid().x        # 1..K (1-based)
    tid = thread_position_in_threadgroup().x    # 1..nthreads
    nt = threads_per_threadgroup().x            # nthreads

    # threadgroup scratch: one partial sum per thread
    T = eltype(fecoef)
    shared = Metal.MtlThreadGroupArray(T, NT)

    start = @inbounds offsets[k]
    stop = @inbounds offsets[k+1] - Int32(1)

    acc = zero(T)

    # each thread walks its portion of the bucket, `nt` entries apart
    j = start + Int32(tid - 1)
    while j <= stop
        i = @inbounds perm[j]
        @inbounds acc += (α * y[i] * cache[i])
        j += Int32(nt)
    end

    @inbounds shared[tid] = acc
    Metal.threadgroup_barrier(Metal.MemoryFlagThreadGroup) # sync + threadgroup fence

    # Tree reduction in shared memory. Start from the largest power of two
    # strictly below `nt` (or 1) and skip partners past the end of the group, so
    # no partial sum is dropped when `nt` is not a power of two.
    offset = Int32(1)
    while offset * Int32(2) < nt
        offset *= Int32(2)
    end
    while offset > 0
        if tid <= offset && tid + offset <= nt
            @inbounds shared[tid] += shared[tid + offset]
        end
        # reached by every thread of the group: the loop bounds depend only on `nt`
        Metal.threadgroup_barrier(Metal.MemoryFlagThreadGroup)
        offset ÷= Int32(2)
    end

    # one write per coefficient (no atomics needed if groups == K and 1 group per k)
    if tid == UInt32(1)
        @inbounds fecoef[k] += shared[1]
    end

    return nothing
end

function gather_kernel!(fecoef, refs, α, y, cache)
i = thread_position_in_grid_1d()
if i <= length(refs)
Metal.atomic_fetch_add_explicit(pointer(fecoef, refs[i]), α * y[i] * cache[i])
@inbounds Metal.atomic_fetch_add_explicit(pointer(fecoef, refs[i]), α * y[i] * cache[i])
end
return nothing
end

function FixedEffects.scatter!(y::MtlVector, α::Number, fecoef::MtlVector, refs::MtlVector, cache::MtlVector, nthreads::Integer)
function FixedEffects.scatter!(y::MtlVector, α::Number, fecoef::MtlVector, refs::MtlVector, cache, nthreads::Integer)
nblocks = cld(length(y), nthreads)
Metal.@sync @metal threads=nthreads groups=nblocks scatter_kernel!(y, α, fecoef, refs, cache)
Metal.@sync @metal threads=nthreads groups=nblocks scatter_kernel!(y, α, fecoef, refs, cache[1])
end

function scatter_kernel!(y, α, fecoef, refs, cache)
i = thread_position_in_grid_1d()
if i <= length(y)
y[i] += α * fecoef[refs[i]] * cache[i]
@inbounds y[i] += α * fecoef[refs[i]] * cache[i]
end
return nothing
end
Expand Down Expand Up @@ -121,34 +195,34 @@ function scale!(scale::MtlVector, refs::MtlVector, interaction::MtlVector, weigh
nblocks = cld(length(refs), nthreads)
fill!(scale, 0)
Metal.@sync @metal threads=nthreads groups=nblocks scale_kernel!(scale, refs, interaction, weights)
Metal.@sync @metal threads=nthreads groups=nblocks inv_kernel!(scale)
Metal.@sync @metal threads=nthreads groups=nblocks inv_kernel!(scale, eltype(scale))
end

function scale_kernel!(scale, refs, interaction, weights)
i = thread_position_in_grid_1d()
if i <= length(refs)
Metal.atomic_fetch_add_explicit(pointer(scale, refs[i]), interaction[i]^2 * weights[i])
@inbounds Metal.atomic_fetch_add_explicit(pointer(scale, refs[i]), interaction[i]^2 * weights[i])
end
return nothing
end

function inv_kernel!(scale)
function inv_kernel!(scale, T)
i = thread_position_in_grid_1d()
if i <= length(scale)
scale[i] = (scale[i] > 0) ? (1 / sqrt(scale[i])) : 0.0
@inbounds scale[i] = (scale[i] > 0) ? (1 / sqrt(scale[i])) : zero(T)
end
return nothing
end

function cache!(cache::MtlVector, refs::MtlVector, interaction::MtlVector, weights::MtlVector, scale::MtlVector, nthreads::Integer)
nblocks = cld(length(cache), nthreads)
Metal.@sync @metal threads=nthreads groups=nblocks cache!_kernel!(cache, refs, interaction, weights, scale)
function cache!(cache, refs::MtlVector, interaction::MtlVector, weights::MtlVector, scale::MtlVector, nthreads::Integer)
nblocks = cld(length(cache[1]), nthreads)
Metal.@sync @metal threads=nthreads groups=nblocks cache!_kernel!(cache[1], refs, interaction, weights, scale)
end

function cache!_kernel!(cache, refs, interaction, weights, scale)
i = thread_position_in_grid_1d()
if i <= length(cache)
cache[i] = interaction[i] * sqrt(weights[i]) * scale[refs[i]]
@inbounds cache[i] = interaction[i] * sqrt(weights[i]) * scale[refs[i]]
end
return nothing
end
Expand Down
10 changes: 0 additions & 10 deletions src/SolverCPU.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,6 @@ function FixedEffectLinearMapCPU{T}(fes::Vector{<:FixedEffect}, ::Type{Val{:cpu}
return FixedEffectLinearMapCPU{T}(fes, scales, caches, nthreads)
end

# In-place 5-argument `mul!` for the adjoint of a CPU fixed-effect linear map:
# scale `fecoefs` by `β`, then call `gather!` once per fixed effect with the
# matching coefficient vector, references and cache.
# NOTE(review): presumably each `gather!` call adds its `α`-weighted
# contribution of `y` to the coefficient vector — confirm against `gather!`.
function LinearAlgebra.mul!(fecoefs::FixedEffectCoefficients,
                            Cfem::Adjoint{T, FixedEffectLinearMapCPU{T}},
                            y::AbstractVector, α::Number, β::Number) where {T}
    # Unwrap the lazy adjoint to reach the underlying map.
    fem = adjoint(Cfem)
    rmul!(fecoefs, β)
    foreach(fecoefs.x, fem.fes, fem.caches) do coef, effect, effect_cache
        gather!(coef, effect.refs, α, y, effect_cache, fem.nthreads)
    end
    return fecoefs
end

# multithreaded gather seems to be slower
function gather!(fecoef::AbstractVector, refs::AbstractVector, α::Number,
Expand Down
9 changes: 1 addition & 8 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,7 @@
tests = ["types.jl", "solve.jl"]
println("Running tests:")

# A work around for tests to run on older versions of Julia
using Pkg
if VERSION >= v"1.8"
Pkg.add("Metal")
using Metal
end

using Test, StatsBase, CUDA, FixedEffects, PooledArrays, CategoricalArrays
using Test, StatsBase, CUDA, Metal, FixedEffects, PooledArrays, CategoricalArrays

for test in tests
try
Expand Down
8 changes: 4 additions & 4 deletions test/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,12 @@ method_s = [:cpu]
if CUDA.functional()
push!(method_s, :CUDA)
end
#if Metal.functional()
# push!(method_s, :Metal)
#end
if Metal.functional()
push!(method_s, :Metal)
end
for method in method_s
println("$method Float32")
local (r, iter, conv) = solve_residuals!(deepcopy(x),fes, method=method, double_precision = false)
local (r, iter, conv) = solve_residuals!(deepcopy(x), fes, method=method, double_precision = false)
@test Float32.(r) ≈ Float32.(r_ols)
end

Expand Down