Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 96 additions & 59 deletions ext/TensorOperationsChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,28 @@ function ChainRulesCore.rrule(
return output, tensoralloc_pullback
end

# Roughly equivalent to `fill!(similar(x), y)`, but the allocation and the fill are wrapped
# in one function call so that higher-order derivatives can be supported.
function similar_and_fill(x, y)
    structure = TensorOperations.tensorstructure(x)
    filled = fill!(TensorOperations.tensoralloc(typeof(x), structure), y)
    return filled
end
function ChainRulesCore.rrule(::typeof(similar_and_fill), x, y)
similar_and_fill_pullback(Δx) = NoTangent(), ZeroTangent(), tensorscalar(unthunk(Δx))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am confused by this rule, in particular the adjoint of y: I think I can reinterpret the output of similar_and_fill(x, y) as just y * similar_and_fill(x, 1).

To avoid confusion, let's say x = y * similar_and_fill(some_other_x, 1). Then clearly forward derivatives satisfy ẋ = ẏ * similar_and_fill(some_other_x, 1), where the last factor is completely constant.

Equating dot(Δx, ẋ) = ẏ * dot(Δx, similar_and_fill(some_other_x, 1)) with Δy' * ẏ, I then obtain

Δy = dot(similar_and_fill(some_other_x, 1), Δx)

Maybe I have to read further first, and similar_and_fill is only ever called on tensor arguments x that are equivalent to scalars and thus have only a single entry. In principle, however, the definition makes sense for general tensors, and then the reverse rule clearly cannot be correct, since tensorscalar(Δx) would fail.

return similar_and_fill(x, y), similar_and_fill_pullback
end
function ChainRulesCore.rrule(::typeof(tensorscalar), C)
projectC = ProjectTo(C)
function tensorscalar_pullback(Δc)
_Δc = unthunk(Δc)
return NoTangent(), projectC(_Δc)
end
tensorscalar_pullback(Δc) = NoTangent(), similar_and_fill(C, unthunk(Δc))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I see, so similar_and_fill is indeed only called on tensors C for which tensorscalar makes sense.

return tensorscalar(C), tensorscalar_pullback
end

# To avoid computing rrules for α and β when these aren't needed, we want to have a
# type-stable quick bail-out: the decision is made purely by dispatch on the argument type.
# Values are forwarded to the method for their type.
_needs_tangent(x) = _needs_tangent(typeof(x))
# Conservative fallback for types not covered below: report that a tangent is needed.
# Without this method a non-`Number` argument would keep re-entering the value method
# above (`typeof` of a type is again a type) until a `StackOverflowError` is thrown.
_needs_tangent(::Type) = true
_needs_tangent(::Type{<:Number}) = true
_needs_tangent(::Type{<:Integer}) = false
_needs_tangent(::Type{<:Union{One, Zero}}) = false

# The current `rrule` design makes sure that the implementation for custom types does
# not need to support the backend or allocator arguments
# function ChainRulesCore.rrule(::typeof(TensorOperations.tensoradd!),
Expand Down Expand Up @@ -99,26 +112,34 @@ function _rrule_tensoradd!(C, A, pA, conjA, α, β, ba)
_dA = tensoradd!(_dA, ΔC, (ipA, ()), conjA, conjA ? α : conj(α), Zero(), ba...)
projectA(_dA)
end
dα = @thunk let
_dα = tensorscalar(
tensorcontract(
A, ((), linearize(pA)), !conjA,
ΔC, (trivtuple(numind(pA)), ()), false,
((), ()), One(), ba...
dα = if _needs_tangent(α)
@thunk let
_dα = tensorscalar(
tensorcontract(
A, ((), linearize(pA)), !conjA,
ΔC, (trivtuple(numind(pA)), ()), false,
((), ()), One(), ba...
)
)
)
projectα(_dα)
projectα(_dα)
end
else
ZeroTangent()
end
dβ = @thunk let
# TODO: consider using `inner`
_dβ = tensorscalar(
tensorcontract(
C, ((), trivtuple(numind(pA))), true,
ΔC, (trivtuple(numind(pA)), ()), false,
((), ()), One(), ba...
dβ = if _needs_tangent(β)
@thunk let
# TODO: consider using `inner`
_dβ = tensorscalar(
tensorcontract(
C, ((), trivtuple(numind(pA))), true,
ΔC, (trivtuple(numind(pA)), ()), false,
((), ()), One(), ba...
)
)
)
projectβ(_dβ)
projectβ(_dβ)
end
else
ZeroTangent()
end
dba = map(_ -> NoTangent(), ba)
return NoTangent(), dC, dA, NoTangent(), NoTangent(), dα, dβ, dba...
Expand Down Expand Up @@ -212,28 +233,36 @@ function _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba)
)
projectB(_dB)
end
dα = @thunk let
C_αβ = tensorcontract(A, pA, conjA, B, pB, conjB, pAB, One(), ba...)
# TODO: consider using `inner`
_dα = tensorscalar(
tensorcontract(
C_αβ, ((), trivtuple(numind(pAB))), true,
ΔC, (trivtuple(numind(pAB)), ()), false,
((), ()), One(), ba...
dα = if _needs_tangent(α)
@thunk let
C_αβ = tensorcontract(A, pA, conjA, B, pB, conjB, pAB, One(), ba...)
# TODO: consider using `inner`
_dα = tensorscalar(
tensorcontract(
C_αβ, ((), trivtuple(numind(pAB))), true,
ΔC, (trivtuple(numind(pAB)), ()), false,
((), ()), One(), ba...
)
)
)
projectα(_dα)
projectα(_dα)
end
else
ZeroTangent()
end
dβ = @thunk let
# TODO: consider using `inner`
_dβ = tensorscalar(
tensorcontract(
C, ((), trivtuple(numind(pAB))), true,
ΔC, (trivtuple(numind(pAB)), ()), false,
((), ()), One(), ba...
dβ = if _needs_tangent(β)
@thunk let
# TODO: consider using `inner`
_dβ = tensorscalar(
tensorcontract(
C, ((), trivtuple(numind(pAB))), true,
ΔC, (trivtuple(numind(pAB)), ()), false,
((), ()), One(), ba...
)
)
)
projectβ(_dβ)
projectβ(_dβ)
end
else
ZeroTangent()
end
dba = map(_ -> NoTangent(), ba)
return NoTangent(), dC,
Expand Down Expand Up @@ -301,27 +330,35 @@ function _rrule_tensortrace!(C, A, p, q, conjA, α, β, ba)
)
projectA(_dA)
end
dα = @thunk let
C_αβ = tensortrace(A, p, q, false, One(), ba...)
_dα = tensorscalar(
tensorcontract(
C_αβ, ((), trivtuple(numind(p))),
!conjA,
ΔC, (trivtuple(numind(p)), ()), false,
((), ()), One(), ba...
dα = if _needs_tangent(α)
@thunk let
C_αβ = tensortrace(A, p, q, false, One(), ba...)
_dα = tensorscalar(
tensorcontract(
C_αβ, ((), trivtuple(numind(p))),
!conjA,
ΔC, (trivtuple(numind(p)), ()), false,
((), ()), One(), ba...
)
)
)
projectα(_dα)
projectα(_dα)
end
else
ZeroTangent()
end
dβ = @thunk let
_dβ = tensorscalar(
tensorcontract(
C, ((), trivtuple(numind(p))), true,
ΔC, (trivtuple(numind(p)), ()), false,
((), ()), One(), ba...
dβ = if _needs_tangent(β)
@thunk let
_dβ = tensorscalar(
tensorcontract(
C, ((), trivtuple(numind(p))), true,
ΔC, (trivtuple(numind(p)), ()), false,
((), ()), One(), ba...
)
)
)
projectβ(_dβ)
projectβ(_dβ)
end
else
ZeroTangent()
end
dba = map(_ -> NoTangent(), ba)
return NoTangent(), dC, dA, NoTangent(), NoTangent(), NoTangent(), dα, dβ, dba...
Expand Down
Loading