Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "FixedEffects"
uuid = "c8885935-8500-56a7-9867-7708b20db0eb"
version = "2.5.1"
version = "2.5.2"

[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand Down
69 changes: 37 additions & 32 deletions ext/MetalExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,44 +38,49 @@ mutable struct FixedEffectLinearMapMetal{T} <: AbstractFixedEffectLinearMap{T}
nthreads::Int
end

function bucketize_refs(refs::Vector, K::Int, T)
if K < 100_000 && (length(refs) ÷ K >= 16)
# count the number of obs per group
counts = zeros(UInt32, K)
@inbounds for r in refs
counts[r] += 0x00000001
end
# offsets is vcat(1, cumsum(counts))
offsets = Vector{UInt32}(undef, K+1)
offsets[1] = 0x00000001
@inbounds for k in 1:K
offsets[k+1] = offsets[k] + counts[k]
end
next = offsets[1:K]
perm = Vector{UInt32}(undef, length(refs))
@inbounds for i in eachindex(refs)
r = refs[i]
p = next[r]
perm[p] = UInt32(i)
next[r] = p + 0x00000001
end
return Metal.zeros(T, length(refs)), MtlArray(UInt32.(perm)), MtlArray(UInt32.(offsets))
else
return Metal.zeros(T, length(refs)), Metal.zeros(UInt32, 1), Metal.zeros(UInt32, 1)
end
function bucketize_refs(refs::Vector, n::Int)
# count the number of obs per group
counts = zeros(UInt32, n)
@inbounds for r in refs
counts[r] += 0x00000001
end
# offsets is vcat(1, cumsum(counts))
offsets = Vector{UInt32}(undef, n + 1)
offsets[1] = 0x00000001
@inbounds for k in 1:n
offsets[k+1] = offsets[k] + counts[k]
end
next = offsets[1:n]
perm = Vector{UInt32}(undef, length(refs))
@inbounds for i in eachindex(refs)
r = refs[i]
p = next[r]
perm[p] = UInt32(i)
next[r] = p + 0x00000001
end
return perm, offsets
end

function FixedEffectLinearMapMetal{T}(fes::Vector{<:FixedEffect}, nthreads) where {T}
fes2 = [_mtl(T, fe) for fe in fes]
scales = [Metal.zeros(T, fe.n) for fe in fes]
caches = [bucketize_refs(fe.refs, fe.n, T) for fe in fes]
caches = [[Metal.zeros(T, length(fe.refs)), Metal.zeros(UInt32, 1), Metal.zeros(UInt32, 1)] for fe in fes]
Threads.@threads for i in 1:length(fes)
refs = fes[i].refs
n = fes[i].n
if n < min(100_000, div(length(refs), 16))
out = bucketize_refs(refs, n)
caches[i][2] = MtlArray(out[1])
caches[i][3] = MtlArray(out[2])
end
end
return FixedEffectLinearMapMetal{T}(fes2, scales, caches, nthreads)
end

function FixedEffects.gather!(fecoef::MtlVector, refs::MtlVector, α::Number, y::MtlVector, cache, nthreads::Integer)
K = length(fecoef)
if K < 100_000 && (length(refs) ÷ K >= 16)
Metal.@sync @metal threads=nthreads groups=K gather_kernel_bin!(fecoef, refs, α, y, cache[1], cache[2], cache[3], Val(nthreads))
function FixedEffects.gather!(fecoef::MtlVector, refs::MtlVector, α::Number, y::MtlVector, cache::Vector, nthreads::Integer)
n = length(fecoef)
if n < min(100_000, div(length(refs), 16))
Metal.@sync @metal threads=nthreads groups=n gather_kernel_bin!(fecoef, refs, α, y, cache[1], cache[2], cache[3], Val(nthreads))
else
nblocks = cld(length(y), nthreads)
Metal.@sync @metal threads=nthreads groups=nblocks gather_kernel!(fecoef, refs, α, y, cache[1])
Expand Down Expand Up @@ -117,7 +122,7 @@ function gather_kernel_bin!(fecoef, refs, α, y, cache, perm, offsets, ::Val{NT}
offset ÷= 0x00000002
end

# one write per coefficient (no atomics needed if groups == K and 1 group per k)
# one write per coefficient (no atomics needed if groups == n and 1 group per k)
if tid == 0x00000001
@inbounds fecoef[k] += shared[1]
end
Expand All @@ -133,7 +138,7 @@ function gather_kernel!(fecoef, refs, α, y, cache)
return nothing
end

function FixedEffects.scatter!(y::MtlVector, α::Number, fecoef::MtlVector, refs::MtlVector, cache, nthreads::Integer)
function FixedEffects.scatter!(y::MtlVector, α::Number, fecoef::MtlVector, refs::MtlVector, cache::Vector, nthreads::Integer)
nblocks = cld(length(y), nthreads)
Metal.@sync @metal threads=nthreads groups=nblocks scatter_kernel!(y, α, fecoef, refs, cache[1])
end
Expand Down