Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -550,3 +550,140 @@ function DI.value_gradient_and_hessian!(
)
return fc(x), grad, hess
end

## HVP

struct FiniteDiffHVPPrep{SIG, C1, C2, RG, AG, RH, AH, H} <: DI.HVPPrep{SIG}
_sig::Val{SIG}
gradient_cache::C1
hessian_cache::C2
relstep_g::RG
absstep_g::AG
relstep_h::RH
absstep_h::AH
hess::H
end

function DI.prepare_hvp_nokwarg(
strict::Val, f, backend::AutoFiniteDiff, x, tx::NTuple, contexts::Vararg{DI.Context, C}
) where {C}
_sig = DI.signature(f, backend, x, tx, contexts...; strict)
fc = DI.fix_tail(f, map(DI.unwrap, contexts)...)
y = fc(x)
df = zero(y) .* x
gradient_cache = GradientCache(df, x, fdtype(backend))
hessian_cache = HessianCache(x, fdhtype(backend))
relstep_g = if isnothing(backend.relstep)
default_relstep(fdtype(backend), eltype(x))
else
backend.relstep
end
relstep_h = if isnothing(backend.relstep)
default_relstep(fdhtype(backend), eltype(x))
else
backend.relstep
end
absstep_g = if isnothing(backend.absstep)
relstep_g
else
backend.absstep
end
absstep_h = if isnothing(backend.absstep)
relstep_h
else
backend.absstep
end
hess = similar(x, eltype(x), (length(x), length(x)))
return FiniteDiffHVPPrep(
_sig, gradient_cache, hessian_cache, relstep_g, absstep_g, relstep_h, absstep_h, hess
)
end

function DI.hvp(
f,
prep::FiniteDiffHVPPrep,
backend::AutoFiniteDiff,
x,
tx::NTuple,
contexts::Vararg{DI.Context, C},
) where {C}
DI.check_prep(f, prep, backend, x, tx, contexts...)
(; relstep_h, absstep_h, hess) = prep
fc = DI.fix_tail(f, map(DI.unwrap, contexts)...)
finite_difference_hessian!(
hess, fc, x, prep.hessian_cache; relstep = relstep_h, absstep = absstep_h
)
tg = map(tx) do dx
reshape(hess * vec(dx), size(x))
end
return tg
end

function DI.hvp!(
f,
tg::NTuple,
prep::FiniteDiffHVPPrep,
backend::AutoFiniteDiff,
x,
tx::NTuple,
contexts::Vararg{DI.Context, C},
) where {C}
DI.check_prep(f, prep, backend, x, tx, contexts...)
(; relstep_h, absstep_h, hess) = prep
fc = DI.fix_tail(f, map(DI.unwrap, contexts)...)
finite_difference_hessian!(
hess, fc, x, prep.hessian_cache; relstep = relstep_h, absstep = absstep_h
)
for b in eachindex(tx, tg)
mul!(vec(tg[b]), hess, vec(tx[b]))
end
return tg
end

function DI.gradient_and_hvp(
f,
prep::FiniteDiffHVPPrep,
backend::AutoFiniteDiff,
x,
tx::NTuple,
contexts::Vararg{DI.Context, C},
) where {C}
DI.check_prep(f, prep, backend, x, tx, contexts...)
(; relstep_g, absstep_g, relstep_h, absstep_h, hess) = prep
fc = DI.fix_tail(f, map(DI.unwrap, contexts)...)
grad = finite_difference_gradient(
fc, x, prep.gradient_cache; relstep = relstep_g, absstep = absstep_g
)
finite_difference_hessian!(
hess, fc, x, prep.hessian_cache; relstep = relstep_h, absstep = absstep_h
)
tg = map(tx) do dx
reshape(hess * vec(dx), size(x))
end
return grad, tg
end

function DI.gradient_and_hvp!(
f,
grad,
tg::NTuple,
prep::FiniteDiffHVPPrep,
backend::AutoFiniteDiff,
x,
tx::NTuple,
contexts::Vararg{DI.Context, C},
) where {C}
DI.check_prep(f, prep, backend, x, tx, contexts...)
(; relstep_g, absstep_g, relstep_h, absstep_h, hess) = prep
fc = DI.fix_tail(f, map(DI.unwrap, contexts)...)
finite_difference_gradient!(
grad, fc, x, prep.gradient_cache; relstep = relstep_g, absstep = absstep_g
)
finite_difference_hessian!(
hess, fc, x, prep.hessian_cache; relstep = relstep_h, absstep = absstep_h
)
for b in eachindex(tx, tg)
mul!(vec(tg[b]), hess, vec(tx[b]))
end
return grad, tg
end
58 changes: 56 additions & 2 deletions DifferentiationInterface/test/Back/FiniteDiff/test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ end
use_tuples = true,
include_smaller = true,
);
excluded = [:second_derivative, :hvp],
excluded = [:second_derivative],
logging = LOGGING,
)

Expand All @@ -43,7 +43,7 @@ end
AutoFiniteDiff(; relstep = cbrt(eps(Float64)), absstep = cbrt(eps(Float64))),
AutoFiniteDiff(; dir = 0.5),
];
excluded = [:second_derivative, :hvp],
excluded = [:second_derivative],
logging = LOGGING,
)
end
Expand Down Expand Up @@ -90,6 +90,11 @@ end;
@test prep.absstep_h == 1000
@test prep.relstep_g == 0.1
@test prep.relstep_h == 0.1
prep = prepare_hvp(sum, backend, [1.0], ([1.0],))
@test prep.absstep_g == 1000
@test prep.absstep_h == 1000
@test prep.relstep_g == 0.1
@test prep.relstep_h == 0.1

backend = AutoFiniteDiff(; relstep = 0.1)
preps = [
Expand All @@ -110,6 +115,55 @@ end;
@test prep.absstep_h == 0.1
@test prep.relstep_g == 0.1
@test prep.relstep_h == 0.1
prep = prepare_hvp(sum, backend, [1.0], ([1.0],))
@test prep.absstep_g == 0.1
@test prep.absstep_h == 0.1
@test prep.relstep_g == 0.1
@test prep.relstep_h == 0.1
end

@testset "HVP accuracy (issue 1012)" begin
# hvp should match hessian * v for default AutoFiniteDiff()
# Previously, hvp used fdtype (forward) while hessian used fdhtype (central),
# causing significant accuracy differences
backend = AutoFiniteDiff()

for (f, x, v) in [
(x -> sum(x .^ 2), [1.0, 2.0, 3.0], [1.0, 0.0, 0.0]),
(x -> sum(x .^ 3), [1.0, 2.0, 3.0], [1.0, 0.0, 0.0]),
(x -> sum(x .^ 4), [1.0, 2.0, 3.0], [1.0, 0.0, 0.0]),
(x -> x' * [1 2; 3 4] * x, [1.0, 2.0], [1.0, 0.0]),
]
H = hessian(f, backend, x)
Hv_direct = H * v
Hv_hvp = hvp(f, backend, x, (v,))[1]
@test Hv_hvp ≈ Hv_direct rtol = 1e-10
end

# Also test hvp!, gradient_and_hvp, gradient_and_hvp!
f(x) = sum(x .^ 2)
x = [1.0, 2.0, 3.0]
v = [1.0, 0.0, 0.0]
H = hessian(f, backend, x)
expected_Hv = H * v
expected_grad = [2.0, 4.0, 6.0]

# hvp!
tg = (similar(x),)
hvp!(f, tg, backend, x, (v,))
@test tg[1] ≈ expected_Hv rtol = 1e-10

# gradient_and_hvp
grad, tg = gradient_and_hvp(f, backend, x, (v,))
@test grad ≈ expected_grad rtol = 1e-6
@test tg[1] ≈ expected_Hv rtol = 1e-10

# gradient_and_hvp!
grad = similar(x)
tg = (similar(x),)
gradient_and_hvp!(f, grad, tg, backend, x, (v,))
@test grad ≈ expected_grad rtol = 1e-6
@test tg[1] ≈ expected_Hv rtol = 1e-10
end

include("benchmark.jl")
Loading