8 changes: 6 additions & 2 deletions Project.toml
@@ -1,6 +1,6 @@
name = "TensorOperations"
uuid = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2"
version = "5.4"
version = "5.5.0"
authors = ["Lukas Devos <lukas.devos@ugent.be>", "Maarten Van Damme <maartenvd1994@gmail.com>", "Jutho Haegeman <jutho.haegeman@ugent.be>"]

[deps]
@@ -23,11 +23,13 @@ Bumper = "8ce10254-0962-460f-a3d8-1f77fea1446e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"

[extensions]
TensorOperationsBumperExt = "Bumper"
TensorOperationsChainRulesCoreExt = "ChainRulesCore"
TensorOperationscuTENSORExt = ["cuTENSOR", "CUDA"]
TensorOperationsMooncakeExt = "Mooncake"

[compat]
Aqua = "0.6, 0.7, 0.8"
@@ -39,6 +41,7 @@ DynamicPolynomials = "0.5, 0.6"
LRUCache = "1"
LinearAlgebra = "1.6"
Logging = "1.6"
Mooncake = "0.4.195"
PackageExtensionCompat = "1"
PrecompileTools = "1.1"
Preferences = "1.4"
@@ -59,9 +62,10 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
DynamicPolynomials = "7c1d4256-1411-5781-91ec-d7bc3513ac07"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"

[targets]
test = ["Test", "Random", "DynamicPolynomials", "ChainRulesTestUtils", "CUDA", "cuTENSOR", "Aqua", "Logging", "Bumper"]
test = ["Test", "Random", "DynamicPolynomials", "ChainRulesTestUtils", "CUDA", "cuTENSOR", "Aqua", "Logging", "Bumper", "Mooncake"]
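
Mooncake.jl is added as a weak dependency tied to the new `TensorOperationsMooncakeExt` extension, so it is only loaded once Mooncake itself is present in the environment. A minimal sketch of verifying this (the check is illustrative, not part of this PR):

```julia
using TensorOperations        # the extension is not loaded yet
using Mooncake                # loading Mooncake triggers TensorOperationsMooncakeExt

# `Base.get_extension` (Julia ≥ 1.9) returns the extension module once loaded:
@assert Base.get_extension(TensorOperations, :TensorOperationsMooncakeExt) !== nothing
```
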
15 changes: 11 additions & 4 deletions docs/src/man/autodiff.md
@@ -1,25 +1,32 @@
# Automatic differentiation

TensorOperations offers experimental support for reverse-mode automatic differentiation (AD)
through the use of [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl). As the basic
through the use of [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl)
and [Mooncake.jl](https://github.com/chalk-lab/Mooncake.jl). As the basic
operations are multi-linear, the vector-Jacobian products thereof can all be expressed in
terms of the operations defined in VectorInterface and TensorOperations. Thus, any custom
type whose tangent type also supports these interfaces will automatically inherit
reverse-mode AD support.

As the [`@tensor`](@ref) macro rewrites everything in terms of the basic tensor operations,
the reverse-mode rules for these methods are supplied. However, because most AD-engines do
the reverse-mode rules for these methods are supplied. However, because ChainRules.jl does
not support in-place mutation, effectively these operations will be replaced with a
non-mutating version. This is similar to the behaviour found in
[BangBang.jl](https://github.com/JuliaFolds/BangBang.jl): the operations remain
in-place, except within the pieces of code that are being differentiated. In effect, this
amounts to replacing all assignments (`=`) with definitions (`:=`) within the context of
[`@tensor`](@ref).
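
For instance, a sketch of what this means in practice, assuming Zygote.jl as the ChainRules-consuming AD engine (this example is editorial, not part of the diff):

```julia
using TensorOperations, Zygote

# Inside differentiated code, `:=` allocates a fresh output instead of mutating:
f(A, B) = @tensor C[i, j] := A[i, k] * B[k, j]

A, B = randn(3, 3), randn(3, 3)
loss(A, B) = sum(abs2, f(A, B))
gA, gB = Zygote.gradient(loss, A, B)   # dispatches to the ChainRulesCore rrules
```
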

Mooncake.jl *does* support in-place mutation, and as a result, on the reverse pass
all mutated input variables should be restored to the state they had before the
forward-pass function was called. Currently, this is **not done** for buffers you
provide to various TensorOperations functions, so relying on the state of such a
buffer (e.g. a Bumper.jl buffer) being restored will **silently** produce incorrect
results.
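
As an illustration of driving Mooncake directly, a sketch assuming Mooncake's `build_rrule`/`value_and_gradient!!` entry points (editorial, not part of the diff):

```julia
using TensorOperations, Mooncake

g(A, B) = sum(@tensor C[i, j] := A[i, k] * B[k, j])

A, B = randn(3, 3), randn(3, 3)
rule = Mooncake.build_rrule(g, A, B)
val, (dg, dA, dB) = Mooncake.value_and_gradient!!(rule, g, A, B)

# Caveat from above: if `g` handed a manually managed buffer to the tensor
# functions, its contents would not be restored on the reverse pass.
```
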

!!! warning "Experimental"

While some rudimentary tests are run, the AD support is currently not extensively
tested. Because of the way it is implemented, the use of AD will tacitly replace
mutating operations with a non-mutating variant. This might lead to unwanted bugs that
are hard to track down. Additionally, for mixed scalar types their also might be
unexpected or unwanted behaviour.
are hard to track down. Additionally, for mixed scalar types there also might be
unexpected or unwanted behaviour.
20 changes: 2 additions & 18 deletions ext/TensorOperationsChainRulesCoreExt.jl
@@ -1,8 +1,8 @@
module TensorOperationsChainRulesCoreExt

using TensorOperations
using TensorOperations: numind, numin, numout, promote_contract
using TensorOperations: DefaultBackend, DefaultAllocator
using TensorOperations: numind, numin, numout, promote_contract, _needs_tangent
using TensorOperations: DefaultBackend, DefaultAllocator, _kron
using ChainRulesCore
using TupleTools
using VectorInterface
@@ -55,13 +55,6 @@ function ChainRulesCore.rrule(::typeof(tensorscalar), C)
return tensorscalar(C), tensorscalar_pullback
end

# To avoid computing rrules for α and β when these aren't needed, we want to have a
# type-stable quick bail-out
_needs_tangent(x) = _needs_tangent(typeof(x))
_needs_tangent(::Type{<:Number}) = true
_needs_tangent(::Type{<:Integer}) = false
_needs_tangent(::Type{<:Union{One, Zero}}) = false

# The current `rrule` design makes sure that the implementation for custom types does
# not need to support the backend or allocator arguments
function ChainRulesCore.rrule(
@@ -309,15 +302,6 @@ function _rrule_tensortrace!(C, A, p, q, conjA, α, β, ba)
return C′, pullback
end

_kron(Es::NTuple{1}, ba) = Es[1]
function _kron(Es::NTuple{N, Any}, ba) where {N}
E1 = Es[1]
E2 = _kron(Base.tail(Es), ba)
p2 = ((), trivtuple(2 * N - 2))
p = ((1, (2 .+ trivtuple(N - 1))...), (2, ((N + 1) .+ trivtuple(N - 1))...))
return tensorproduct(p, E1, ((1, 2), ()), false, E2, p2, false, One(), ba...)
end

# NCON functions
@non_differentiable TensorOperations.ncontree(args...)
@non_differentiable TensorOperations.nconoutput(args...)
265 changes: 265 additions & 0 deletions ext/TensorOperationsMooncakeExt/TensorOperationsMooncakeExt.jl
@@ -0,0 +1,265 @@
module TensorOperationsMooncakeExt

using TensorOperations
# Mooncake imports ChainRulesCore as CRC to avoid name conflicts;
# here we import it ourselves to ensure that the rules from the
# ChainRulesCore extension are in fact loaded
using Mooncake, Mooncake.CRC
using TensorOperations: AbstractBackend, DefaultAllocator, CUDAAllocator, ManualAllocator
using TensorOperations: tensoralloc, tensoradd!, tensorcontract!, tensortrace!, _kron, numind, _needs_tangent, numin, numout
using Mooncake: ReverseMode, DefaultCtx, CoDual, NoRData, arrayify, @zero_derivative, primal, tangent
using VectorInterface, TupleTools

Mooncake.tangent_type(::Type{Index2Tuple}) = Mooncake.NoTangent
Mooncake.tangent_type(::Type{<:AbstractBackend}) = Mooncake.NoTangent
Mooncake.tangent_type(::Type{DefaultAllocator}) = Mooncake.NoTangent
Mooncake.tangent_type(::Type{CUDAAllocator}) = Mooncake.NoTangent
Mooncake.tangent_type(::Type{ManualAllocator}) = Mooncake.NoTangent
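# The declarations above mark purely structural arguments (index tuples,
# backends, allocators) as tangent-free, so Mooncake never differentiates them.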

trivtuple(N) = ntuple(identity, N)

@zero_derivative DefaultCtx Tuple{typeof(TensorOperations.tensorstructure), Any}
@zero_derivative DefaultCtx Tuple{typeof(TensorOperations.tensoradd_structure), Any}
@zero_derivative DefaultCtx Tuple{typeof(TensorOperations.tensoradd_type), Any}
@zero_derivative DefaultCtx Tuple{typeof(TensorOperations.tensoralloc_add), Any}
@zero_derivative DefaultCtx Tuple{typeof(TensorOperations.tensorcontract_structure), Any}
@zero_derivative DefaultCtx Tuple{typeof(TensorOperations.tensorcontract_type), Any}
@zero_derivative DefaultCtx Tuple{typeof(TensorOperations.tensoralloc_contract), Any}
@zero_derivative DefaultCtx Tuple{typeof(TensorOperations.promote_contract), Any}
@zero_derivative DefaultCtx Tuple{typeof(TensorOperations.promote_add), Any}
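# The `@zero_derivative` declarations above cover shape, type, and allocation
# queries, whose outputs do not depend smoothly on any array data.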

Mooncake.@from_rrule Mooncake.DefaultCtx Tuple{typeof(TensorOperations.tensorfree!), Any}
Mooncake.@from_rrule Mooncake.DefaultCtx Tuple{typeof(TensorOperations.tensoralloc), Any, Any, Any, Any}
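# `@from_rrule` reuses the existing ChainRulesCore rules for allocation and
# deallocation rather than re-deriving them for Mooncake.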

Mooncake.@is_primitive DefaultCtx ReverseMode Tuple{typeof(tensorcontract!), AbstractArray, AbstractArray, Index2Tuple, Bool, AbstractArray, Index2Tuple, Bool, Index2Tuple, Number, Number, Vararg{Any}}
function Mooncake.rrule!!(
::CoDual{typeof(tensorcontract!)},
C_dC::CoDual{<:AbstractArray{TC}},
A_dA::CoDual{<:AbstractArray{TA}},
pA_dpA::CoDual{<:Index2Tuple},
conjA_dconjA::CoDual{Bool},
B_dB::CoDual{<:AbstractArray{TB}},
pB_dpB::CoDual{<:Index2Tuple},
conjB_dconjB::CoDual{Bool},
pAB_dpAB::CoDual{<:Index2Tuple},
α_dα::CoDual{Tα},
β_dβ::CoDual{Tβ},
ba_dba::CoDual...,
) where {Tα <: Number, Tβ <: Number, TA <: Number, TB <: Number, TC <: Number}
C, dC = arrayify(C_dC)
A, dA = arrayify(A_dA)
B, dB = arrayify(B_dB)
pA = primal(pA_dpA)
pB = primal(pB_dpB)
pAB = primal(pAB_dpAB)
conjA = primal(conjA_dconjA)
conjB = primal(conjB_dconjB)
α = primal(α_dα)
β = primal(β_dβ)
ba = primal.(ba_dba)
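    # Cache the primal C: tensorcontract! overwrites it, and the pullback must restore it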
C_cache = copy(C)
TensorOperations.tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba...)
function contract_pb(::NoRData)
scale!(C, C_cache, One())
if Tα == Zero && Tβ == Zero
scale!(dC, zero(TC))
return ntuple(i -> NoRData(), 11 + length(ba))
end
ipAB = invperm(linearize(pAB))
pdC = (
TupleTools.getindices(ipAB, trivtuple(numout(pA))),
TupleTools.getindices(ipAB, numout(pA) .+ trivtuple(numin(pB))),
)
ipA = (invperm(linearize(pA)), ())
ipB = (invperm(linearize(pB)), ())
conjΔC = conjA
conjB′ = conjA ? conjB : !conjB
dA = tensorcontract!(
dA,
dC, pdC, conjΔC,
B, reverse(pB), conjB′,
ipA,
conjA ? α : conj(α), One(), ba...
)
conjΔC = conjB
conjA′ = conjB ? conjA : !conjA
dB = tensorcontract!(
dB,
A, reverse(pA), conjA′,
dC, pdC, conjΔC,
ipB,
conjB ? α : conj(α), One(), ba...
)
dα = if _needs_tangent(Tα)
C_αβ = tensorcontract(A, pA, conjA, B, pB, conjB, pAB, One(), ba...)
# TODO: consider using `inner`
Mooncake._rdata(
tensorscalar(
tensorcontract(
C_αβ, ((), trivtuple(numind(pAB))), true,
dC, (trivtuple(numind(pAB)), ()), false,
((), ()), One(), ba...
)
)
)
else
NoRData()
end
dβ = if _needs_tangent(Tβ)
# TODO: consider using `inner`
Mooncake._rdata(
tensorscalar(
tensorcontract(
C, ((), trivtuple(numind(pAB))), true,
dC, (trivtuple(numind(pAB)), ()), false,
((), ()), One(), ba...
)
)
)
else
NoRData()
end
if β === Zero()
scale!(dC, β)
else
scale!(dC, conj(β))
end
return NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), dα, dβ, map(ba_ -> NoRData(), ba)...
end
return C_dC, contract_pb
end
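
For reference (an editorial summary, not part of the diff): writing the primal call schematically as C = β C₀ + α (A ⋆ B), where ⋆ denotes the pairwise index contraction, the cotangents accumulated by the pullback above are, up to the index permutations and conjugation flags handled in the code:

```latex
\begin{aligned}
\frac{\partial L}{\partial A} &\mathrel{+}= \overline{\alpha}\,\frac{\partial L}{\partial C} \star B^{\dagger},
\qquad
\frac{\partial L}{\partial B} \mathrel{+}= \overline{\alpha}\,A^{\dagger} \star \frac{\partial L}{\partial C}, \\
\frac{\partial L}{\partial \alpha} &= \Big\langle A \star B,\ \frac{\partial L}{\partial C} \Big\rangle,
\qquad
\frac{\partial L}{\partial \beta} = \Big\langle C_0,\ \frac{\partial L}{\partial C} \Big\rangle,
\qquad
\frac{\partial L}{\partial C} \leftarrow \overline{\beta}\,\frac{\partial L}{\partial C}.
\end{aligned}
```

The `tensoradd!` and `tensortrace!` pullbacks below follow the same pattern, with the contraction replaced by a permutation and a partial trace, respectively.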

Mooncake.@is_primitive DefaultCtx ReverseMode Tuple{typeof(tensoradd!), AbstractArray, AbstractArray, Index2Tuple, Bool, Number, Number, Vararg{Any}}
function Mooncake.rrule!!(
::CoDual{typeof(tensoradd!)},
C_dC::CoDual{<:AbstractArray{TC}},
A_dA::CoDual{<:AbstractArray{TA}},
pA_dpA::CoDual{<:Index2Tuple},
conjA_dconjA::CoDual{Bool},
α_dα::CoDual{Tα},
β_dβ::CoDual{Tβ},
ba_dba::CoDual...,
) where {Tα <: Number, Tβ <: Number, TA <: Number, TC <: Number}
C, dC = arrayify(C_dC)
A, dA = arrayify(A_dA)
pA = primal(pA_dpA)
conjA = primal(conjA_dconjA)
α = primal(α_dα)
β = primal(β_dβ)
ba = primal.(ba_dba)
C_cache = copy(C)
TensorOperations.tensoradd!(C, A, pA, conjA, α, β, ba...)
function add_pb(::NoRData)
scale!(C, C_cache, One())
ipA = invperm(linearize(pA))
dA = tensoradd!(dA, dC, (ipA, ()), conjA, conjA ? α : conj(α), One(), ba...)
dα = if _needs_tangent(Tα)
tensorscalar(
tensorcontract(
A, ((), linearize(pA)), !conjA,
dC, (trivtuple(numind(pA)), ()), false,
((), ()), One(), ba...
)
)
else
Mooncake.NoRData()
end
dβ = if _needs_tangent(Tβ)
tensorscalar(
tensorcontract(
C, ((), trivtuple(numind(pA))), true,
dC, (trivtuple(numind(pA)), ()), false,
((), ()), One(), ba...
)
)
else
Mooncake.NoRData()
end
if β === Zero()
scale!(dC, β)
else
scale!(dC, conj(β))
end
return NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), dα, dβ, map(ba_ -> NoRData(), ba)...
end
return C_dC, add_pb
end

Mooncake.@is_primitive DefaultCtx ReverseMode Tuple{typeof(tensortrace!), AbstractArray, AbstractArray, Index2Tuple, Index2Tuple, Bool, Number, Number, Vararg{Any}}
function Mooncake.rrule!!(
::CoDual{typeof(tensortrace!)},
C_dC::CoDual{<:AbstractArray{TC}},
A_dA::CoDual{<:AbstractArray{TA}},
p_dp::CoDual{<:Index2Tuple},
q_dq::CoDual{<:Index2Tuple},
conjA_dconjA::CoDual{Bool},
α_dα::CoDual{Tα},
β_dβ::CoDual{Tβ},
ba_dba::CoDual...,
) where {Tα <: Number, Tβ <: Number, TA <: Number, TC <: Number}
C, dC = arrayify(C_dC)
A, dA = arrayify(A_dA)
p = primal(p_dp)
q = primal(q_dq)
conjA = primal(conjA_dconjA)
α = primal(α_dα)
β = primal(β_dβ)
ba = primal.(ba_dba)
C_cache = copy(C)
TensorOperations.tensortrace!(C, A, p, q, conjA, α, β, ba...)
function trace_pb(::NoRData)
scale!(C, C_cache, One())
ip = invperm((linearize(p)..., q[1]..., q[2]...))
Es = map(q[1], q[2]) do i1, i2
one(
TensorOperations.tensoralloc_add(
TensorOperations.scalartype(A), A, ((i1,), (i2,)), conjA
)
)
end
E = _kron(Es, ba)
dA = tensorproduct!(
dA, dC, (trivtuple(numind(p)), ()), conjA,
E, ((), trivtuple(numind(q))), conjA,
(ip, ()),
conjA ? α : conj(α), One(), ba...
)
C_αβ = tensortrace(A, p, q, false, One(), ba...)
dα = if _needs_tangent(Tα)
Mooncake._rdata(
tensorscalar(
tensorcontract(
C_αβ, ((), trivtuple(numind(p))),
!conjA,
dC, (trivtuple(numind(p)), ()), false,
((), ()), One(), ba...
)
)
)
else
NoRData()
end
dβ = if _needs_tangent(Tβ)
Mooncake._rdata(
tensorscalar(
tensorcontract(
C, ((), trivtuple(numind(p))), true,
dC, (trivtuple(numind(p)), ()), false,
((), ()), One(), ba...
)
)
)
else
NoRData()
end
if β === Zero()
scale!(dC, β)
else
scale!(dC, conj(β))
end
return NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), dα, dβ, map(ba_ -> NoRData(), ba)...
end
return C_dC, trace_pb
end

end
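
A sketch of how such hand-written rules are typically validated, assuming `Mooncake.TestUtils.test_rule` (the concrete arguments are illustrative, not taken from this PR's test suite):

```julia
using Mooncake, Random
using TensorOperations: tensoradd!

rng = Random.Xoshiro(42)
C, A = randn(rng, 2, 2), randn(rng, 2, 2)

# Runs Mooncake's standardised correctness checks on the rule and asserts
# that the call is intercepted as a primitive rather than recursed into:
Mooncake.TestUtils.test_rule(
    rng, tensoradd!, C, A, ((1, 2), ()), false, 1.0, 1.0;
    is_primitive = true,
)
```
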
1 change: 1 addition & 0 deletions src/TensorOperations.jl
@@ -34,6 +34,7 @@ export checkcontractible, tensorcost
include("indices.jl")
include("backends.jl")
include("interface.jl")
include("utils.jl")

# Index notation via macros
#---------------------------