Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "FixedEffects"
uuid = "c8885935-8500-56a7-9867-7708b20db0eb"
version = "2.5.0"
version = "2.5.1"

[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand Down
1 change: 0 additions & 1 deletion benchmarks/.sublime2Terminal.jl

This file was deleted.

21 changes: 10 additions & 11 deletions benchmarks/benchmark_Metal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,29 +15,28 @@ x = Float32.(rand(N))
@time solve_residuals!(deepcopy(x), fes; double_precision = false)
# 0.654833 seconds (1.99 k allocations: 390.841 MiB, 3.71% gc time)
@time solve_residuals!(deepcopy(x), fes; double_precision = false, method = :Metal)
# 0.298326 seconds (129.08 k allocations: 79.208 MiB)
# 1.335206 seconds (3.28 M allocations: 402.660 MiB, 1.80% gc time, 123.64% compilation time: <1% of which was recompilation)
@time solve_residuals!([x x x x], fes)
# 1.604061 seconds (1.25 M allocations: 416.364 MiB, 4.21% gc time, 30.57% compilation time)
# 1.886616 seconds (3.60 M allocations: 731.777 MiB, 1.34% gc time, 122.12% compilation time: 5% of which was recompilation)
@time solve_residuals!([x x x x], fes; method = :Metal)
# 0.790909 seconds (531.78 k allocations: 204.363 MiB, 3.19% compilation time)
# 1.421205 seconds (2.78 M allocations: 497.846 MiB, 1.64% gc time, 110.87% compilation time: <1% of which was recompilation)



# More complicated problem
N = 800000 # number of observations
M = 400000 # number of workers
O = 50000 # number of firms
N = 8000000 # number of observations
M = 4000000 # number of workers
O = 500000 # number of firms
Random.seed!(1234)
pid = rand(1:M, N)
fid = [rand(max(1, div(x, 8)-10):min(O, div(x, 8)+10)) for x in pid]
x = rand(N)
fes = [FixedEffect(pid), FixedEffect(fid)]


@time solve_residuals!([x x x x], fes; double_precision = false)
# 8.294446 seconds (225.13 k allocations: 67.777 MiB, 0.11% gc time)

@time solve_residuals!([x x x x], fes; double_precision = false, method = :Metal)
# 1.605953 seconds (3.25 M allocations: 103.342 MiB, 1.82% gc time)
@time solve_residuals!([x x x x], fes; double_precision = false, maxiter = 100)
# 36.554763 seconds (98.71 M allocations: 5.253 GiB, 1.11% gc time, 114.45% compilation time: 7% of which was recompilation)
@time solve_residuals!([x x x x], fes; double_precision = false, method = :Metal, maxiter = 100)
# 20.652590 seconds (79.33 M allocations: 4.114 GiB, 0.75% gc time, 162.10% compilation time: <1% of which was recompilation)


35 changes: 18 additions & 17 deletions ext/MetalExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,27 +40,28 @@ end

"""
    bucketize_refs(refs::Vector, K::Int, T)

Allocate the Metal buffers used for one fixed effect with `K` groups.

Returns a 3-tuple `(buffer, perm, offsets)`:

- `buffer`: `Metal.zeros(T, length(refs))`.
- `perm`: a `UInt32` device vector listing observation indices grouped by their
  value in `refs` (a counting sort of `eachindex(refs)` keyed by `refs`).
- `offsets`: a `UInt32` device vector of length `K + 1`, equal to
  `vcat(1, cumsum(counts))`, so the observations of group `k` are
  `perm[offsets[k]:(offsets[k + 1] - 1)]`.

The buckets are only built when `K < 100_000` and there are at least 16
observations per group on average (`length(refs) ÷ K >= 16`); otherwise `perm`
and `offsets` are one-element placeholder vectors.

Notes:
- `length(refs) ÷ K` throws a `DivideError` when `K == 0`.
- Indexing is done under `@inbounds`, so every element of `refs` is assumed to
  lie in `1:K` — this is not checked here.
- Indices are stored as `UInt32`; `UInt32(i)` throws an `InexactError` if an
  index does not fit.
"""
function bucketize_refs(refs::Vector, K::Int, T)
    if K < 100_000 && (length(refs) ÷ K >= 16)
        # count the number of obs per group
        counts = zeros(UInt32, K)
        @inbounds for r in refs
            counts[r] += 0x00000001
        end
        # offsets is vcat(1, cumsum(counts))
        offsets = Vector{UInt32}(undef, K + 1)
        offsets[1] = 0x00000001
        @inbounds for k in 1:K
            offsets[k + 1] = offsets[k] + counts[k]
        end
        # next[k] is the position in perm where the next obs of group k is written
        # (range indexing already returns a fresh copy, so offsets is not mutated)
        next = offsets[1:K]
        perm = Vector{UInt32}(undef, length(refs))
        @inbounds for i in eachindex(refs)
            r = refs[i]
            p = next[r]
            perm[p] = UInt32(i)
            next[r] = p + 0x00000001
        end
        # perm and offsets are already Vector{UInt32}: upload them directly
        return Metal.zeros(T, length(refs)), MtlArray(perm), MtlArray(offsets)
    else
        return Metal.zeros(T, length(refs)), Metal.zeros(UInt32, 1), Metal.zeros(UInt32, 1)
    end
end

Expand Down Expand Up @@ -96,28 +97,28 @@ function gather_kernel_bin!(fecoef, refs, α, y, cache, perm, offsets, ::Val{NT}
acc = zero(T)

# each thread walks its portion of the bucket
j = start + Int32(tid - 1)
j = start + UInt32(tid - 1)
while j <= stop
i = @inbounds perm[j]
@inbounds acc += (α * y[i] * cache[i])
j += Int32(nt)
j += UInt32(nt)
end

@inbounds shared[tid] = acc
Metal.threadgroup_barrier(Metal.MemoryFlagThreadGroup) # sync + threadgroup memory fence

# tree reduction in shared memory
offset = Int32(nt ÷ UInt32(2))
offset = UInt32(nt ÷ 0x00000002)
while offset > 0
if tid <= offset
@inbounds shared[tid] += shared[tid + offset]
end
Metal.threadgroup_barrier(Metal.MemoryFlagThreadGroup)
offset ÷= Int32(2)
offset ÷= 0x00000002
end

# one write per coefficient (no atomics needed if groups == K and 1 group per k)
if tid == UInt32(1)
if tid == 0x00000001
@inbounds fecoef[k] += shared[1]
end

Expand Down
7 changes: 4 additions & 3 deletions src/SolverCPU.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,13 +100,14 @@ function update_weights!(feM::FixedEffectSolverCPU, weights::AbstractWeights)
end

"""
    scale!(scale, refs, interaction, weights)

Overwrite `scale` in place with one normalization factor per category.

For each category `k`, the weighted sum of squares
`sum(abs2(interaction[i]) * weights[i] for i with refs[i] == k)` is accumulated
into `scale[k]`, and then replaced by `1 / sqrt(sum)`. Categories whose sum is
not strictly positive get `zero(eltype(scale))` instead of `Inf`.

# Arguments
- `scale`: output vector, indexed by category; its previous contents are discarded.
- `refs`: category index of each observation (used to index `scale`).
- `interaction`, `weights`: per-observation values, indexed like `refs`.

Indexing runs under `@inbounds`, so `refs` values are assumed to be valid
indices of `scale`, and `interaction`/`weights` are assumed to cover
`eachindex(refs)` — neither is checked here.
"""
function scale!(scale::AbstractVector, refs::AbstractVector, interaction::AbstractVector, weights::AbstractVector)
    fill!(scale, 0)
    @fastmath @inbounds @simd for i in eachindex(refs)
        scale[refs[i]] += abs2(interaction[i]) * weights[i]
    end
    # Case of interaction variable equal to zero in the category (issue #97)
    # zero(T) keeps both branches of the ternary at the element type of scale
    T = eltype(scale)
    @fastmath @inbounds @simd for i in eachindex(scale)
        scale[i] = scale[i] > 0 ? (1 / sqrt(scale[i])) : zero(T)
    end
    return nothing
end

Expand Down