Skip to content

add ChainRulesCore rules#3

Open
mileslucas wants to merge 12 commits intomainfrom
ml/grads
Open

add ChainRulesCore rules#3
mileslucas wants to merge 12 commits intomainfrom
ml/grads

Conversation

@mileslucas
Copy link
Copy Markdown
Member

This PR adds analytical gradients using ChainRulesCore.jl

@codecov
Copy link
Copy Markdown

codecov Bot commented Sep 11, 2021

Codecov Report

Merging #3 (eb7de41) into main (1399529) will not change coverage.
The diff coverage is n/a.

❗ Current head eb7de41 differs from pull request most recent head d46340a. Consider uploading reports for the commit d46340a to get more accurate results
Impacted file tree graph

@@           Coverage Diff           @@
##             main       #3   +/-   ##
=======================================
  Coverage   98.80%   98.80%           
=======================================
  Files           6        6           
  Lines          84       84           
=======================================
  Hits           83       83           
  Misses          1        1           

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 1399529...d46340a. Read the comment docs.

@mileslucas
Copy link
Copy Markdown
Member Author

I don't understand why the chain rule tests are failing. Let's look at the isotropic Gaussian PSF as an example

Here is the definition of the gradient

# isotropic
function fgrad(g::Gaussian, point::AbstractVector)
f = g(point)
xdiff = first(point) - first(g.pos)
ydiff = last(point) - last(g.pos)
dfdpos = -2 * GAUSS_PRE * f / g.fwhm^2 .* SA[xdiff, ydiff]
dfdfwhm = -2 * GAUSS_PRE * f * (xdiff^2 + ydiff^2) / g.fwhm^3
dfdamp = f / g.amp
return f, dfdpos, dfdfwhm, dfdamp
end

which I wrote out by hand and can be verified with this derivation http://umdberg.pbworks.com/w/page/88516931/Example%3A%20Gradient%20of%20a%20Gaussian

here are the chain rules

function frule((Δpsf, Δp), g::Gaussian, point::AbstractVector)
f, dfdpos, dfdfwhm, dfda = fgrad(g, point)
Δf = dot(dfdpos, Δpsf.pos) + dot(dfdfwhm, Δpsf.fwhm) + dfda * Δpsf.amp
Δf -= dot(dfdpos, Δp)
return f, Δf
end
function rrule(g::G, point::AbstractVector) where {G<:Gaussian}
f, dfdpos, dfdfwhm, dfda = fgrad(g, point)
function Gaussian_pullback(Δf)
∂pos = dfdpos .* Δf
∂fwhm = dfdfwhm .* Δf
∂g = Tangent{G}(pos=∂pos, fwhm=∂fwhm, amp=dfda * Δf, indices=ZeroTangent())
∂pos = dfdpos .* -Δf
return ∂g, ∂pos
end
return f, Gaussian_pullback
end

using them works as intended-

using ChainRulescore, PSFModels
psf = PSFModels.Gaussian(fwhm=10)
point = [1, 2]
f, pullback = rrule(psf, point)
Δpsf, Δpoint = pullback(1.0)
f2, Δf = frule((Δpsf, Δpoint), psf, point)

# output
(0.8705505632961241, 0.7817442466933209)

but using test_frule and test_rrule consistently fails

@testset "gradients" begin
# have to make sure PSFs are all floating point so tangents don't have type issues
psf_iso = Gaussian(fwhm=10.0, pos=zeros(2))
psf_tang = Tangent{Gaussian}(fwhm=rand(rng), pos=rand(rng, 2), amp=rand(rng), indices=ZeroTangent())
point = Float64[1, 2]
test_frule(psf_iso psf_tang, point)
test_rrule(psf_iso psf_tang, point)
psf_diag = Gaussian(fwhm=Float64[10, 8], pos=zeros(2))
psf_tang = Tangent{Gaussian}(fwhm=rand(rng, 2), pos=rand(rng, 2), amp=rand(rng), indices=ZeroTangent())
test_frule(psf_diag psf_tang, point)
test_rrule(psf_diag psf_tang, point)
end

@abhro
Copy link
Copy Markdown
Member

abhro commented Jan 10, 2026

The merge commits mostly tried to make the code runnable, but I don't think it still works with ChainRuleCore's newer API. The code needs to be reworked to comply

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants