Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,14 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"

[extensions]
TensorKitAdaptExt = "Adapt"
TensorKitCUDAExt = ["CUDA", "cuTENSOR"]
TensorKitChainRulesCoreExt = "ChainRulesCore"
TensorKitFiniteDifferencesExt = "FiniteDifferences"
TensorKitMooncakeExt = "Mooncake"

[compat]
Adapt = "4"
Expand All @@ -43,6 +45,7 @@ GPUArrays = "11.3.1"
LRUCache = "1.0.2"
LinearAlgebra = "1"
MatrixAlgebraKit = "0.6.2"
Mooncake = "0.4.183"
OhMyThreads = "0.8.0"
Printf = "1"
Random = "1"
Expand Down Expand Up @@ -70,6 +73,7 @@ Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Expand All @@ -78,4 +82,4 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"

[targets]
test = ["ArgParse", "Adapt", "Aqua", "Combinatorics", "CUDA", "cuTENSOR", "GPUArrays", "LinearAlgebra", "SafeTestsets", "TensorOperations", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "FiniteDifferences", "Zygote"]
test = ["ArgParse", "Adapt", "Aqua", "Combinatorics", "CUDA", "cuTENSOR", "GPUArrays", "LinearAlgebra", "SafeTestsets", "TensorOperations", "Test", "TestExtras", "ChainRulesCore", "ChainRulesTestUtils", "FiniteDifferences", "Zygote", "Mooncake"]
17 changes: 17 additions & 0 deletions ext/TensorKitMooncakeExt/TensorKitMooncakeExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
module TensorKitMooncakeExt

using Mooncake
using Mooncake: @zero_derivative, DefaultCtx, ReverseMode, NoRData, CoDual, arrayify, primal
using TensorKit
using TensorOperations: TensorOperations, IndexTuple, Index2Tuple, linearize
import TensorOperations as TO
using VectorInterface: One, Zero
using TupleTools


include("utility.jl")
include("tangent.jl")
include("linalg.jl")
include("tensoroperations.jl")

end
14 changes: 14 additions & 0 deletions ext/TensorKitMooncakeExt/linalg.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
Mooncake.@is_primitive DefaultCtx ReverseMode Tuple{typeof(norm), AbstractTensorMap, Real}

function Mooncake.rrule!!(::CoDual{typeof(norm)}, tΔt::CoDual{<:AbstractTensorMap}, pdp::CoDual{<:Real})
t, Δt = arrayify(tΔt)
p = primal(pdp)
p == 2 || error("currently only implemented for p = 2")
n = norm(t, p)
function norm_pullback(Δn)
x = (Δn' + Δn) / 2 / hypot(n, eps(one(n)))
add!(Δt, t, x)
return NoRData(), NoRData(), NoRData()
end
return CoDual(n, Mooncake.NoFData()), norm_pullback
end
7 changes: 7 additions & 0 deletions ext/TensorKitMooncakeExt/tangent.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
function Mooncake.arrayify(A_dA::CoDual{<:TensorMap})
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Was there a benefit to overloading arrayify versus just making a TensorKit specific function tensorify?

A = Mooncake.primal(A_dA)
dA_fw = Mooncake.tangent(A_dA)
data = dA_fw.data.data
dA = typeof(A)(data, A.space)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also does this work in the complex case, where data is now probably using Mooncake's wonderful complex Complex tangent type?

return A, dA
end
137 changes: 137 additions & 0 deletions ext/TensorKitMooncakeExt/tensoroperations.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
Mooncake.@is_primitive(
DefaultCtx,
ReverseMode,
Tuple{
typeof(TO.tensorcontract!),
AbstractTensorMap,
AbstractTensorMap, Index2Tuple, Bool,
AbstractTensorMap, Index2Tuple, Bool,
Index2Tuple,
Number, Number,
Vararg{Any},
}
)

function Mooncake.rrule!!(
::CoDual{typeof(TO.tensorcontract!)},
C_ΔC::CoDual{<:AbstractTensorMap},
A_ΔA::CoDual{<:AbstractTensorMap}, pA_ΔpA::CoDual{<:Index2Tuple}, conjA_ΔconjA::CoDual{Bool},
B_ΔB::CoDual{<:AbstractTensorMap}, pB_ΔpB::CoDual{<:Index2Tuple}, conjB_ΔconjB::CoDual{Bool},
pAB_ΔpAB::CoDual{<:Index2Tuple},
α_Δα::CoDual{<:Number}, β_Δβ::CoDual{<:Number},
ba_Δba::CoDual...,
)
# prepare arguments
(C, ΔC), (A, ΔA), (B, ΔB) = arrayify.((C_ΔC, A_ΔA, B_ΔB))
pA, pB, pAB = primal.((pA_ΔpA, pB_ΔpB, pAB_ΔpAB))
conjA, conjB = primal.((conjA_ΔconjA, conjB_ΔconjB))
α, β = primal.((α_Δα, β_Δβ))
ba = primal.(ba_Δba)

# primal call
C_cache = copy(C)
TO.tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba...)

function tensorcontract_pullback(::NoRData)
copy!(C, C_cache)

ΔCr = tensorcontract_pullback_ΔC!(ΔC, β)
ΔAr = tensorcontract_pullback_ΔA!(
ΔA, ΔC, A, pA, conjA, B, pB, conjB, pAB, α, ba...
)
ΔBr = tensorcontract_pullback_ΔB!(
ΔB, ΔC, A, pA, conjA, B, pB, conjB, pAB, α, ba...
)
Δαr = tensorcontract_pullback_Δα(
ΔC, A, pA, conjA, B, pB, conjB, pAB, α, ba...
)
Δβr = tensorcontract_pullback_Δβ(ΔC, C, β)

return NoRData(), ΔCr,
ΔAr, NoRData(), NoRData(),
ΔBr, NoRData(), NoRData(),
NoRData(),
Δαr, Δβr,
map(ba_ -> NoRData(), ba)...
end

return C_ΔC, tensorcontract_pullback
end

tensorcontract_pullback_ΔC!(ΔC, β) = (scale!(ΔC, conj(β)); NoRData())

function tensorcontract_pullback_ΔA!(
ΔA, ΔC, A, pA, conjA, B, pB, conjB, pAB, α, ba...
)
ipAB = invperm(linearize(pAB))
pΔC = _repartition(ipAB, TO.numout(pA))
ipA = _repartition(invperm(linearize(pA)), A)
conjΔC = conjA
conjB′ = conjA ? conjB : !conjB

tB = twist(
B,
TupleTools.vcat(
filter(x -> !isdual(space(B, x)), pB[1]),
filter(x -> isdual(space(B, x)), pB[2])
); copy = false
)

TO.tensorcontract!(
ΔA,
ΔC, pΔC, conjΔC,
tB, reverse(pB), conjB′,
ipA,
conjA ? α : conj(α), Zero(),
ba...
)

return NoRData()
end

function tensorcontract_pullback_ΔB!(
ΔB, ΔC, A, pA, conjA, B, pB, conjB, pAB, α, ba...
)
ipAB = invperm(linearize(pAB))
pΔC = _repartition(ipAB, TO.numout(pA))
ipB = _repartition(invperm(linearize(pB)), B)
conjΔC = conjB
conjA′ = conjB ? conjA : !conjA

tA = twist(
A,
TupleTools.vcat(
filter(x -> isdual(space(A, x)), pA[1]),
filter(x -> !isdual(space(A, x)), pA[2])
); copy = false
)

TO.tensorcontract!(
ΔB,
tA, reverse(pA), conjA′,
ΔC, pΔC, conjΔC,
ipB,
conjB ? α : conj(α), Zero(), ba...
)

return NoRData()
end

function tensorcontract_pullback_Δα(
ΔC, A, pA, conjA, B, pB, conjB, pAB, α, ba...
)
Tdα = Mooncake.rdata_type(Mooncake.tangent_type(typeof(α)))
Tdα === NoRData && return NoRData()

AB = TO.tensorcontract(A, pA, conjA, B, pB, conjB, pAB, One(), ba...)
Δα = inner(AB, ΔC)
return Mooncake._rdata(Δα)
end

function tensorcontract_pullback_Δβ(ΔC, C, β)
Tdβ = Mooncake.rdata_type(Mooncake.tangent_type(typeof(β)))
Tdβ === NoRData && return NoRData()

Δβ = inner(C, ΔC)
return Mooncake._rdata(Δβ)
end
28 changes: 28 additions & 0 deletions ext/TensorKitMooncakeExt/utility.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
_needs_tangent(x) = _needs_tangent(typeof(x))
_needs_tangent(::Type{<:Number}) = true
_needs_tangent(::Type{<:Integer}) = false
_needs_tangent(::Type{<:Union{One, Zero}}) = false

# IndexTuple utility
# ------------------
trivtuple(N) = ntuple(identity, N)

Base.@constprop :aggressive function _repartition(p::IndexTuple, N₁::Int)
length(p) >= N₁ ||
throw(ArgumentError("cannot repartition $(typeof(p)) to $N₁, $(length(p) - N₁)"))
return TupleTools.getindices(p, trivtuple(N₁)),
TupleTools.getindices(p, trivtuple(length(p) - N₁) .+ N₁)
end
Base.@constprop :aggressive function _repartition(p::Index2Tuple, N₁::Int)
return _repartition(linearize(p), N₁)
end
function _repartition(p::Union{IndexTuple, Index2Tuple}, ::Index2Tuple{N₁}) where {N₁}
return _repartition(p, N₁)
end
function _repartition(p::Union{IndexTuple, Index2Tuple}, t::AbstractTensorMap)
return _repartition(p, TensorKit.numout(t))
end

# Ignore derivatives
# ------------------
@zero_derivative DefaultCtx Tuple{typeof(TensorKit.fusionblockstructure), Any}
File renamed without changes.
117 changes: 117 additions & 0 deletions test/autodiff/mooncake.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
using Test, TestExtras
using TensorKit
using TensorOperations
using Mooncake
using Random

mode = Mooncake.ReverseMode
rng = Random.default_rng()
is_primitive = false

function randindextuple(N::Int, k::Int = rand(0:N))
@assert 0 ≤ k ≤ N
_p = randperm(N)
return (tuple(_p[1:k]...), tuple(_p[(k + 1):end]...))
end

const _repartition = @static if isdefined(Base, :get_extension)
Base.get_extension(TensorKit, :TensorKitMooncakeExt)._repartition
else
TensorKit.TensorKitMooncakeExt._repartition
end

spacelist = (
(ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'),
(
Vect[Z2Irrep](0 => 1, 1 => 1),
Vect[Z2Irrep](0 => 1, 1 => 2)',
Vect[Z2Irrep](0 => 2, 1 => 2)',
Vect[Z2Irrep](0 => 2, 1 => 3),
Vect[Z2Irrep](0 => 2, 1 => 2),
),
(
Vect[FermionParity](0 => 1, 1 => 1),
Vect[FermionParity](0 => 1, 1 => 2)',
Vect[FermionParity](0 => 2, 1 => 1)',
Vect[FermionParity](0 => 2, 1 => 3),
Vect[FermionParity](0 => 2, 1 => 2),
),
(
Vect[U1Irrep](0 => 2, 1 => 1, -1 => 1),
Vect[U1Irrep](0 => 2, 1 => 1, -1 => 1),
Vect[U1Irrep](0 => 2, 1 => 2, -1 => 1)',
Vect[U1Irrep](0 => 1, 1 => 1, -1 => 2),
Vect[U1Irrep](0 => 1, 1 => 2, -1 => 1)',
),
(
Vect[SU2Irrep](0 => 2, 1 // 2 => 1),
Vect[SU2Irrep](0 => 1, 1 => 1),
Vect[SU2Irrep](1 // 2 => 1, 1 => 1)',
Vect[SU2Irrep](1 // 2 => 2),
Vect[SU2Irrep](0 => 1, 1 // 2 => 1, 3 // 2 => 1)',
),
(
Vect[FibonacciAnyon](:I => 2, :τ => 1),
Vect[FibonacciAnyon](:I => 1, :τ => 2)',
Vect[FibonacciAnyon](:I => 2, :τ => 2)',
Vect[FibonacciAnyon](:I => 2, :τ => 3),
Vect[FibonacciAnyon](:I => 2, :τ => 2),
),
)

for V in spacelist
I = sectortype(eltype(V))
Istr = TensorKit.type_repr(I)

symmetricbraiding = BraidingStyle(sectortype(eltype(V))) isa SymmetricBraiding
println("---------------------------------------")
println("Mooncake with symmetry: $Istr")
println("---------------------------------------")
eltypes = (Float64,) # no complex support yet
symmetricbraiding && @timedtestset "TensorOperations with scalartype $T" for T in eltypes
atol = precision(T)
rtol = precision(T)

@timedtestset "tensorcontract!" begin
for _ in 1:5
d = 0
local V1, V2, V3
# retry a couple times to make sure there are at least some nonzero elements
for _ in 1:10
k1 = rand(0:3)
k2 = rand(0:2)
k3 = rand(0:2)
V1 = prod(v -> rand(Bool) ? v' : v, rand(V, k1); init = one(V[1]))
V2 = prod(v -> rand(Bool) ? v' : v, rand(V, k2); init = one(V[1]))
V3 = prod(v -> rand(Bool) ? v' : v, rand(V, k3); init = one(V[1]))
d = min(dim(V1 ← V2), dim(V1' ← V2), dim(V2 ← V3), dim(V2' ← V3))
d > 0 && break
end
ipA = randindextuple(length(V1) + length(V2))
pA = _repartition(invperm(linearize(ipA)), length(V1))
ipB = randindextuple(length(V2) + length(V3))
pB = _repartition(invperm(linearize(ipB)), length(V2))
pAB = randindextuple(length(V1) + length(V3))

α = randn(T)
β = randn(T)
V2_conj = prod(conj, V2; init = one(V[1]))

for conjA in (false, true), conjB in (false, true)
A = randn(T, permute(V1 ← (conjA ? V2_conj : V2), ipA))
B = randn(T, permute((conjB ? V2_conj : V2) ← V3, ipB))
C = randn!(
TensorOperations.tensoralloc_contract(
T, A, pA, conjA, B, pB, conjB, pAB, Val(false)
)
)
Mooncake.TestUtils.test_rule(
rng, tensorcontract!, C, A, pA, conjA, B, pB, conjB, pAB, α, β;
atol, rtol, mode, is_primitive
)

end
end
end
end
end
Loading