Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion benchmarks/.sublime2Terminal.jl

This file was deleted.

21 changes: 10 additions & 11 deletions benchmarks/benchmark_Metal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,29 +15,28 @@ x = Float32.(rand(N))
@time solve_residuals!(deepcopy(x), fes; double_precision = false)
# 0.654833 seconds (1.99 k allocations: 390.841 MiB, 3.71% gc time)
@time solve_residuals!(deepcopy(x), fes; double_precision = false, method = :Metal)
# 0.298326 seconds (129.08 k allocations: 79.208 MiB)
# 1.335206 seconds (3.28 M allocations: 402.660 MiB, 1.80% gc time, 123.64% compilation time: <1% of which was recompilation)
@time solve_residuals!([x x x x], fes)
# 1.604061 seconds (1.25 M allocations: 416.364 MiB, 4.21% gc time, 30.57% compilation time)
# 1.886616 seconds (3.60 M allocations: 731.777 MiB, 1.34% gc time, 122.12% compilation time: 5% of which was recompilation)
@time solve_residuals!([x x x x], fes; method = :Metal)
# 0.790909 seconds (531.78 k allocations: 204.363 MiB, 3.19% compilation time)
# 1.421205 seconds (2.78 M allocations: 497.846 MiB, 1.64% gc time, 110.87% compilation time: <1% of which was recompilation)



# More complicated problem
N = 800000 # number of observations
M = 400000 # number of workers
O = 50000 # number of firms
N = 8000000 # number of observations
M = 4000000 # number of workers
O = 500000 # number of firms
Random.seed!(1234)
pid = rand(1:M, N)
fid = [rand(max(1, div(x, 8)-10):min(O, div(x, 8)+10)) for x in pid]
x = rand(N)
fes = [FixedEffect(pid), FixedEffect(fid)]


@time solve_residuals!([x x x x], fes; double_precision = false)
# 8.294446 seconds (225.13 k allocations: 67.777 MiB, 0.11% gc time)

@time solve_residuals!([x x x x], fes; double_precision = false, method = :Metal)
# 1.605953 seconds (3.25 M allocations: 103.342 MiB, 1.82% gc time)
@time solve_residuals!([x x x x], fes; double_precision = false, maxiter = 100)
# 36.554763 seconds (98.71 M allocations: 5.253 GiB, 1.11% gc time, 114.45% compilation time: 7% of which was recompilation)
@time solve_residuals!([x x x x], fes; double_precision = false, method = :Metal, maxiter = 100)
# 20.652590 seconds (79.33 M allocations: 4.114 GiB, 0.75% gc time, 162.10% compilation time: <1% of which was recompilation)


7 changes: 4 additions & 3 deletions src/SolverCPU.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,13 +100,14 @@ function update_weights!(feM::FixedEffectSolverCPU, weights::AbstractWeights)
end

"""
    scale!(scale, refs, interaction, weights)

Overwrite `scale` in place with per-category normalization factors.

For every index `i` of `refs`, `abs2(interaction[i]) * weights[i]` is accumulated
into `scale[refs[i]]`. Each strictly positive total is then replaced by its
inverse square root; every other entry is set to `zero(eltype(scale))`.

# Arguments
- `scale`: output vector, indexed by the values stored in `refs`; its previous
  contents are discarded.
- `refs`: category index of each observation (used directly as an index into `scale`).
- `interaction`: per-observation interaction values, indexed like `refs`.
- `weights`: per-observation weights, indexed like `refs`.

Bounds checks are disabled (`@inbounds`), so the caller must guarantee that
`interaction` and `weights` cover `eachindex(refs)` and that every `refs[i]` is a
valid index of `scale`.
"""
function scale!(scale::AbstractVector, refs::AbstractVector, interaction::AbstractVector, weights::AbstractVector)
    T = eltype(scale)
    fill!(scale, 0)
    # Accumulate the weighted squared interaction of every observation into its category.
    @fastmath @inbounds @simd for i in eachindex(refs)
        scale[refs[i]] += abs2(interaction[i]) * weights[i]
    end
    # Case of interaction variable equal to zero in the category (issue #97):
    # a zero total cannot be inverted, so the factor is set to zero instead.
    # `zero(T)` (not the Float64 literal `0.0`) keeps both branches of the ternary
    # the same type, e.g. when `scale` holds Float32 values.
    @fastmath @inbounds @simd for i in eachindex(scale)
        scale[i] = scale[i] > 0 ? (1 / sqrt(scale[i])) : zero(T)
    end
    return nothing
end

Expand Down