Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "TensorOperations"
uuid = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2"
version = "5.3.2"
version = "5.4"
authors = ["Lukas Devos <lukas.devos@ugent.be>", "Maarten Van Damme <maartenvd1994@gmail.com>", "Jutho Haegeman <jutho.haegeman@ugent.be>"]

[deps]
Expand Down Expand Up @@ -49,7 +49,7 @@ StridedViews = "0.3, 0.4"
Test = "1"
TupleTools = "1.6"
VectorInterface = "0.4.1,0.5"
cuTENSOR = ">=2.1.1"
cuTENSOR = "2.1.1"
julia = "1.8"

[extras]
Expand Down
88 changes: 15 additions & 73 deletions ext/TensorOperationsChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,28 +64,6 @@ _needs_tangent(::Type{<:Union{One, Zero}}) = false

# The current `rrule` design makes sure that the implementation for custom types does
# not need to support the backend or allocator arguments
# function ChainRulesCore.rrule(::typeof(TensorOperations.tensoradd!),
# C,
# A, pA::Index2Tuple, conjA::Bool,
# α::Number, β::Number,
# backend, allocator)
# val, pb = _rrule_tensoradd!(C, A, pA, conjA, α, β, (backend, allocator))
# return val, ΔC -> (pb(ΔC)..., NoTangent(), NoTangent())
# end
# function ChainRulesCore.rrule(::typeof(TensorOperations.tensoradd!),
# C,
# A, pA::Index2Tuple, conjA::Bool,
# α::Number, β::Number,
# backend)
# val, pb = _rrule_tensoradd!(C, A, pA, conjA, α, β, (backend,))
# return val, ΔC -> (pb(ΔC)..., NoTangent())
# end
# function ChainRulesCore.rrule(::typeof(TensorOperations.tensoradd!),
# C,
# A, pA::Index2Tuple, conjA::Bool,
# α::Number, β::Number)
# return _rrule_tensoradd!(C, A, pA, conjA, α, β, ())
# end
function ChainRulesCore.rrule(
::typeof(TensorOperations.tensoradd!),
C,
Expand All @@ -105,7 +83,11 @@ function _rrule_tensoradd!(C, A, pA, conjA, α, β, ba)

function pullback(ΔC′)
ΔC = unthunk(ΔC′)
dC = @thunk projectC(scale(ΔC, conj(β)))
dC = if β === Zero()
ZeroTangent()
else
@thunk projectC(scale(ΔC, conj(β)))
end
dA = @thunk let
ipA = invperm(linearize(pA))
_dA = zerovector(A, VectorInterface.promote_add(ΔC, α))
Expand Down Expand Up @@ -148,35 +130,6 @@ function _rrule_tensoradd!(C, A, pA, conjA, α, β, ba)
return C′, pullback
end

# function ChainRulesCore.rrule(::typeof(TensorOperations.tensorcontract!),
# C,
# A, pA::Index2Tuple, conjA::Bool,
# B, pB::Index2Tuple, conjB::Bool,
# pAB::Index2Tuple,
# α::Number, β::Number,
# backend, allocator)
# val, pb = _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β,
# (backend, allocator))
# return val, ΔC -> (pb(ΔC)..., NoTangent(), NoTangent())
# end
# function ChainRulesCore.rrule(::typeof(TensorOperations.tensorcontract!),
# C,
# A, pA::Index2Tuple, conjA::Bool,
# B, pB::Index2Tuple, conjB::Bool,
# pAB::Index2Tuple,
# α::Number, β::Number,
# backend)
# val, pb = _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, (backend,))
# return val, ΔC -> (pb(ΔC)..., NoTangent())
# end
# function ChainRulesCore.rrule(::typeof(TensorOperations.tensorcontract!),
# C,
# A, pA::Index2Tuple, conjA::Bool,
# B, pB::Index2Tuple, conjB::Bool,
# pAB::Index2Tuple,
# α::Number, β::Number)
# return _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ())
# end
function ChainRulesCore.rrule(
::typeof(TensorOperations.tensorcontract!),
C,
Expand Down Expand Up @@ -204,7 +157,11 @@ function _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba)
TupleTools.getindices(ipAB, trivtuple(numout(pA))),
TupleTools.getindices(ipAB, numout(pA) .+ trivtuple(numin(pB))),
)
dC = @thunk projectC(scale(ΔC, conj(β)))
dC = if β === Zero()
ZeroTangent()
else
@thunk projectC(scale(ΔC, conj(β)))
end
dA = @thunk let
ipA = (invperm(linearize(pA)), ())
conjΔC = conjA
Expand Down Expand Up @@ -273,25 +230,6 @@ function _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba)
return C′, pullback
end

# function ChainRulesCore.rrule(::typeof(tensortrace!), C,
# A, p::Index2Tuple, q::Index2Tuple, conjA::Bool,
# α::Number, β::Number,
# backend, allocator)
# val, pb = _rrule_tensortrace!(C, A, p, q, conjA, α, β, (backend, allocator))
# return val, ΔC -> (pb(ΔC)..., NoTangent(), NoTangent())
# end
# function ChainRulesCore.rrule(::typeof(tensortrace!), C,
# A, p::Index2Tuple, q::Index2Tuple, conjA::Bool,
# α::Number, β::Number,
# backend)
# val, pb = _rrule_tensortrace!(C, A, p, q, conjA, α, β, (backend,))
# return val, ΔC -> (pb(ΔC)..., NoTangent())
# end
# function ChainRulesCore.rrule(::typeof(tensortrace!), C,
# A, p::Index2Tuple, q::Index2Tuple, conjA::Bool,
# α::Number, β::Number)
# return _rrule_tensortrace!(C, A, p, q, conjA, α, β, ())
# end
function ChainRulesCore.rrule(
::typeof(tensortrace!), C,
A, p::Index2Tuple, q::Index2Tuple, conjA::Bool,
Expand All @@ -310,7 +248,11 @@ function _rrule_tensortrace!(C, A, p, q, conjA, α, β, ba)

function pullback(ΔC′)
ΔC = unthunk(ΔC′)
dC = @thunk projectC(scale(ΔC, conj(β)))
dC = if β === Zero()
ZeroTangent()
else
@thunk projectC(scale(ΔC, conj(β)))
end
dA = @thunk let
ip = invperm((linearize(p)..., q[1]..., q[2]...))
Es = map(q[1], q[2]) do i1, i2
Expand Down
Loading