Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,14 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"

[extensions]
TensorOperationsBumperExt = "Bumper"
TensorOperationsChainRulesCoreExt = "ChainRulesCore"
TensorOperationscuTENSORExt = ["cuTENSOR", "CUDA"]
TensorOperationsMooncakeExt = "Mooncake"
TensorOperationsEnzymeExt = ["Enzyme", "ChainRulesCore"]

[compat]
Aqua = "0.6, 0.7, 0.8"
Expand All @@ -38,6 +40,8 @@ CUDA = "5"
ChainRulesCore = "1"
ChainRulesTestUtils = "1"
DynamicPolynomials = "0.5, 0.6"
Enzyme = "0.13.115"
EnzymeTestUtils = "0.2"
LRUCache = "1"
LinearAlgebra = "1.6"
Logging = "1.6"
Expand All @@ -61,11 +65,13 @@ Bumper = "8ce10254-0962-460f-a3d8-1f77fea1446e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
DynamicPolynomials = "7c1d4256-1411-5781-91ec-d7bc3513ac07"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
EnzymeTestUtils = "12d8515a-0907-448a-8884-5fe00fdf1c5a"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"

[targets]
test = ["Test", "Random", "DynamicPolynomials", "ChainRulesTestUtils", "CUDA", "cuTENSOR", "Aqua", "Logging", "Bumper", "Mooncake"]
test = ["Test", "Random", "DynamicPolynomials", "ChainRulesTestUtils", "CUDA", "cuTENSOR", "Aqua", "Logging", "Bumper", "Mooncake", "Enzyme", "EnzymeTestUtils"]
4 changes: 1 addition & 3 deletions ext/TensorOperationsChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
module TensorOperationsChainRulesCoreExt

using TensorOperations
using TensorOperations: numind, numin, numout, promote_contract, _needs_tangent
using TensorOperations: numind, numin, numout, promote_contract, _needs_tangent, trivtuple
using TensorOperations: DefaultBackend, DefaultAllocator, _kron
using ChainRulesCore
using TupleTools
using VectorInterface
using TupleTools: invperm
using LinearAlgebra

trivtuple(N) = ntuple(identity, N)

@non_differentiable TensorOperations.tensorstructure(args...)
@non_differentiable TensorOperations.tensoradd_structure(args...)
@non_differentiable TensorOperations.tensoradd_type(args...)
Expand Down
197 changes: 197 additions & 0 deletions ext/TensorOperationsEnzymeExt/TensorOperationsEnzymeExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
module TensorOperationsEnzymeExt

using TensorOperations
using TensorOperations: AbstractBackend, DefaultAllocator, CUDAAllocator, ManualAllocator
using VectorInterface
using TupleTools
using Enzyme, ChainRulesCore
using Enzyme.EnzymeCore
using Enzyme.EnzymeCore: EnzymeRules

# Import the ChainRulesCore `rrule`s of the allocation helpers as Enzyme rules.
# NOTE(review): presumably those `rrule`s live in the ChainRulesCore extension — confirm.
Enzyme.@import_rrule typeof(TensorOperations.tensorfree!) Any
Enzyme.@import_rrule typeof(TensorOperations.tensoralloc) Any

# Declare backends, allocators and index tuples as inactive types for Enzyme.
@inline EnzymeRules.inactive_type(v::Type{<:AbstractBackend}) = true
@inline EnzymeRules.inactive_type(v::Type{DefaultAllocator}) = true
@inline EnzymeRules.inactive_type(v::Type{<:CUDAAllocator}) = true
@inline EnzymeRules.inactive_type(v::Type{ManualAllocator}) = true
# `Index2Tuple` is a parametric alias, so `Type{Index2Tuple}` would only match the alias
# itself and never the type of an actual index tuple; `<:` is required to cover those.
@inline EnzymeRules.inactive_type(v::Type{<:Index2Tuple}) = true

# Forward sweep of the reverse-mode Enzyme rule for `TensorOperations.tensorcontract!`.
#
# Snapshots the operands the reverse sweep reads, performs the in-place contraction on
# `C_dC.val`, and returns an `AugmentedReturn` whose tape is `(cache_A, cache_B, cache_C)`:
# - `cache_A` / `cache_B` are copies of `A` / `B` when the operand is not `Const` and is
#   flagged as overwritten by `config`, and `nothing` otherwise;
# - `cache_C` is always a copy of `C` taken before the contraction overwrites it.
# The returned primal is `C_dC.val` and the returned shadow is `C_dC.dval`, each only when
# `config` requests it (`nothing` otherwise).
function EnzymeRules.augmented_primal(
        config::EnzymeRules.RevConfigWidth{1},
        func::Const{typeof(TensorOperations.tensorcontract!)},
        ::Type{RT},
        C_dC::Annotation{<:AbstractArray{TC}},
        A_dA::Annotation{<:AbstractArray{TA}},
        pA_dpA::Const{<:Index2Tuple},
        conjA_dconjA::Const{Bool},
        B_dB::Annotation{<:AbstractArray{TB}},
        pB_dpB::Const{<:Index2Tuple},
        conjB_dconjB::Const{Bool},
        pAB_dpAB::Const{<:Index2Tuple},
        α_dα::Annotation{Tα},
        β_dβ::Annotation{Tβ},
        ba_dba::Const...,
    ) where {RT, Tα <: Number, Tβ <: Number, TA <: Number, TB <: Number, TC <: Number}
    # Form caches if needed. Per Enzyme's custom-rule docs `overwritten(config)` also has
    # an entry for the function itself, hence index 3 for `A` and index 6 for `B`.
    cache_A = !isa(A_dA, Const) && EnzymeRules.overwritten(config)[3] ? copy(A_dA.val) : nothing
    cache_B = !isa(B_dB, Const) && EnzymeRules.overwritten(config)[6] ? copy(B_dB.val) : nothing
    # Must be taken before the contraction below, which writes into `C`.
    cache_C = copy(C_dC.val)
    ba = map(ba_ -> getfield(ba_, :val), ba_dba)
    # `tensorcontract!` mutates `C`, so it has to run on every forward sweep, even when the
    # return value is unused; `needs_primal` only decides whether the result is handed back.
    TensorOperations.tensorcontract!(
        C_dC.val, A_dA.val, pA_dpA.val, conjA_dconjA.val,
        B_dB.val, pB_dpB.val, conjB_dconjB.val, pAB_dpAB.val,
        α_dα.val, β_dβ.val, ba...
    )
    primal = EnzymeRules.needs_primal(config) ? C_dC.val : nothing
    shadow = EnzymeRules.needs_shadow(config) ? C_dC.dval : nothing
    return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_B, cache_C))
end

# Reverse sweep of the Enzyme rule for `TensorOperations.tensorcontract!`.
#
# Unpacks the tape `(cache_A, cache_B, cache_C)`, falling back to the current `A` / `B`
# values when no copy was taped, and forwards everything to
# `TensorOperations.tensorcontract_pullback!` together with the shadows of `C`, `A`, `B`.
# Returns `nothing` for the eight array / index / flag arguments, the fourth and fifth
# results of the pullback for `α` and `β`, and `nothing` for every trailing argument.
# NOTE(review): `.dval` is read unconditionally on `C_dC`, `A_dA` and `B_dB`; confirm how
# `Const` array operands are meant to be handled here.
function EnzymeRules.reverse(
        config::EnzymeRules.RevConfigWidth{1},
        func::Const{typeof(TensorOperations.tensorcontract!)},
        ::Type{RT},
        cache,
        C_dC::Annotation{<:AbstractArray{TC}},
        A_dA::Annotation{<:AbstractArray{TA}},
        pA_dpA::Const{<:Index2Tuple},
        conjA_dconjA::Const{Bool},
        B_dB::Annotation{<:AbstractArray{TB}},
        pB_dpB::Const{<:Index2Tuple},
        conjB_dconjB::Const{Bool},
        pAB_dpAB::Const{<:Index2Tuple},
        α_dα::Annotation{Tα},
        β_dβ::Annotation{Tβ},
        ba_dba::Const...,
    ) where {RT, Tα <: Number, Tβ <: Number, TA <: Number, TB <: Number, TC <: Number}
    tape_A, tape_B, C_saved = cache
    # Prefer the taped copies; otherwise the live operands are used.
    A_primal = something(tape_A, A_dA.val)
    B_primal = something(tape_B, B_dB.val)
    extra_args = map(arg -> getfield(arg, :val), ba_dba)
    _, _, _, Δα, Δβ = TensorOperations.tensorcontract_pullback!(
        C_dC.dval, A_dA.dval, B_dB.dval,
        C_saved, A_primal, B_primal,
        α_dα.val, β_dβ.val,
        pA_dpA.val, pB_dpB.val, pAB_dpAB.val,
        conjA_dconjA.val, conjB_dconjB.val,
        extra_args...
    )
    # One slot per argument: C, A, pA, conjA, B, pB, conjB, pAB, then α, β, then the rest.
    return (
        ntuple(_ -> nothing, Val(8))...,
        Δα, Δβ,
        map(_ -> nothing, extra_args)...,
    )
end

# Forward sweep of the reverse-mode Enzyme rule for `tensoradd!`.
#
# Snapshots the operands the reverse sweep reads, performs the in-place addition on
# `C_dC.val`, and returns an `AugmentedReturn` whose tape is `(cache_A, cache_C)`:
# - `cache_A` is a copy of `A` when `A` is not `Const` and is flagged as overwritten by
#   `config`, and `nothing` otherwise;
# - `cache_C` is always a copy of `C` taken before the addition overwrites it.
# The returned primal is `C_dC.val` and the returned shadow is `C_dC.dval`, each only when
# `config` requests it (`nothing` otherwise).
function EnzymeRules.augmented_primal(
        config::EnzymeRules.RevConfigWidth{1},
        ::Annotation{typeof(tensoradd!)},
        ::Type{RT},
        C_dC::Annotation{<:AbstractArray{TC}},
        A_dA::Annotation{<:AbstractArray{TA}},
        pA_dpA::Const{<:Index2Tuple},
        conjA_dconjA::Const{Bool},
        α_dα::Annotation{Tα},
        β_dβ::Annotation{Tβ},
        ba_dba::Const...,
    ) where {RT, Tα <: Number, Tβ <: Number, TA <: Number, TC <: Number}
    # Form caches if needed. Per Enzyme's custom-rule docs `overwritten(config)` also has
    # an entry for the function itself, hence index 3 for `A`.
    cache_A = !isa(A_dA, Const) && EnzymeRules.overwritten(config)[3] ? copy(A_dA.val) : nothing
    # Must be taken before the addition below, which writes into `C`.
    cache_C = copy(C_dC.val)
    ba = map(ba_ -> getfield(ba_, :val), ba_dba)
    α = α_dα.val
    β = β_dβ.val
    conjA = conjA_dconjA.val
    # `tensoradd!` mutates `C`, so it has to run on every forward sweep, even when the
    # return value is unused; `needs_primal` only decides whether the result is handed back.
    TensorOperations.tensoradd!(C_dC.val, A_dA.val, pA_dpA.val, conjA, α, β, ba...)
    primal = EnzymeRules.needs_primal(config) ? C_dC.val : nothing
    shadow = EnzymeRules.needs_shadow(config) ? C_dC.dval : nothing
    return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_C))
end

# Reverse sweep of the Enzyme rule for `tensoradd!`.
#
# Unpacks the tape `(cache_A, cache_C)`, falling back to the current value of `A` when no
# copy was taped, and forwards everything to `TensorOperations.tensoradd_pullback!`
# together with the shadows of `C` and `A`.
# Returns `nothing` for the four array / index / flag arguments, the third and fourth
# results of the pullback for `α` and `β`, and `nothing` for every trailing argument.
# NOTE(review): `.dval` is read unconditionally on `C_dC` and `A_dA`; confirm how `Const`
# array operands are meant to be handled here.
function EnzymeRules.reverse(
        config::EnzymeRules.RevConfigWidth{1},
        ::Annotation{typeof(tensoradd!)},
        ::Type{RT},
        cache,
        C_dC::Annotation{<:AbstractArray{TC}},
        A_dA::Annotation{<:AbstractArray{TA}},
        pA_dpA::Const{<:Index2Tuple},
        conjA_dconjA::Const{Bool},
        α_dα::Annotation{Tα},
        β_dβ::Annotation{Tβ},
        ba_dba::Const...,
    ) where {RT, Tα <: Number, Tβ <: Number, TA <: Number, TC <: Number}
    tape_A, C_saved = cache
    # Prefer the taped copy; otherwise the live operand is used.
    A_primal = something(tape_A, A_dA.val)
    extra_args = map(arg -> getfield(arg, :val), ba_dba)
    _, _, Δα, Δβ = TensorOperations.tensoradd_pullback!(
        C_dC.dval, A_dA.dval,
        C_saved, A_primal,
        α_dα.val, β_dβ.val,
        pA_dpA.val, conjA_dconjA.val,
        extra_args...
    )
    # One slot per argument: C, A, pA, conjA, then α, β, then the rest.
    return (
        ntuple(_ -> nothing, Val(4))...,
        Δα, Δβ,
        map(_ -> nothing, extra_args)...,
    )
end

# Forward sweep of the reverse-mode Enzyme rule for `tensortrace!`.
#
# Snapshots the operands the reverse sweep reads, performs the in-place trace on
# `C_dC.val`, and returns an `AugmentedReturn` whose tape is `(cache_A, cache_C)`:
# - `cache_A` is a copy of `A` when `A` is not `Const` and is flagged as overwritten by
#   `config`, and `nothing` otherwise;
# - `cache_C` is always a copy of `C` taken before the trace overwrites it.
# The returned primal is `C_dC.val` and the returned shadow is `C_dC.dval`, each only when
# `config` requests it (`nothing` otherwise).
function EnzymeRules.augmented_primal(
        config::EnzymeRules.RevConfigWidth{1},
        ::Annotation{typeof(tensortrace!)},
        ::Type{RT},
        C_dC::Annotation{<:AbstractArray{TC}},
        A_dA::Annotation{<:AbstractArray{TA}},
        p_dp::Const{<:Index2Tuple},
        q_dq::Const{<:Index2Tuple},
        conjA_dconjA::Const{Bool},
        α_dα::Annotation{Tα},
        β_dβ::Annotation{Tβ},
        ba_dba::Const...,
    ) where {RT, Tα <: Number, Tβ <: Number, TA <: Number, TC <: Number}
    # Form caches if needed. Per Enzyme's custom-rule docs `overwritten(config)` also has
    # an entry for the function itself, hence index 3 for `A`.
    cache_A = !isa(A_dA, Const) && EnzymeRules.overwritten(config)[3] ? copy(A_dA.val) : nothing
    # Must be taken before the trace below, which writes into `C`.
    cache_C = copy(C_dC.val)
    ba = map(ba_ -> getfield(ba_, :val), ba_dba)
    α = α_dα.val
    β = β_dβ.val
    conjA = conjA_dconjA.val
    # `tensortrace!` mutates `C`, so it has to run on every forward sweep, even when the
    # return value is unused; `needs_primal` only decides whether the result is handed back.
    TensorOperations.tensortrace!(C_dC.val, A_dA.val, p_dp.val, q_dq.val, conjA, α, β, ba...)
    primal = EnzymeRules.needs_primal(config) ? C_dC.val : nothing
    shadow = EnzymeRules.needs_shadow(config) ? C_dC.dval : nothing
    return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_C))
end

# Reverse sweep of the Enzyme rule for `tensortrace!`.
#
# Unpacks the tape `(cache_A, cache_C)`, falling back to the current value of `A` when no
# copy was taped, and forwards everything to `TensorOperations.tensortrace_pullback!`
# together with the shadows of `C` and `A`.
# Returns `nothing` for the five array / index / flag arguments, the third and fourth
# results of the pullback for `α` and `β`, and `nothing` for every trailing argument.
# NOTE(review): `.dval` is read unconditionally on `C_dC` and `A_dA`; confirm how `Const`
# array operands are meant to be handled here.
function EnzymeRules.reverse(
        config::EnzymeRules.RevConfigWidth{1},
        ::Annotation{typeof(tensortrace!)},
        ::Type{RT},
        cache,
        C_dC::Annotation{<:AbstractArray{TC}},
        A_dA::Annotation{<:AbstractArray{TA}},
        p_dp::Const{<:Index2Tuple},
        q_dq::Const{<:Index2Tuple},
        conjA_dconjA::Const{Bool},
        α_dα::Annotation{Tα},
        β_dβ::Annotation{Tβ},
        ba_dba::Const...,
    ) where {RT, Tα <: Number, Tβ <: Number, TA <: Number, TC <: Number}
    tape_A, C_saved = cache
    # Prefer the taped copy; otherwise the live operand is used.
    A_primal = something(tape_A, A_dA.val)
    extra_args = map(arg -> getfield(arg, :val), ba_dba)
    _, _, Δα, Δβ = TensorOperations.tensortrace_pullback!(
        C_dC.dval, A_dA.dval,
        C_saved, A_primal,
        α_dα.val, β_dβ.val,
        p_dp.val, q_dq.val, conjA_dconjA.val,
        extra_args...
    )
    # One slot per argument: C, A, p, q, conjA, then α, β, then the rest.
    return (
        ntuple(_ -> nothing, Val(5))...,
        Δα, Δβ,
        map(_ -> nothing, extra_args)...,
    )
end

end
Loading
Loading