Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
name = "TensorKit"
uuid = "07d1fe3e-3e46-537d-9eac-e9e13d0d4cec"
authors = ["Jutho Haegeman, Lukas Devos"]
version = "0.16.3"
authors = ["Jutho Haegeman, Lukas Devos"]

[deps]
Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4"
LRUCache = "8ac3fa9e-de4c-5943-b1dc-09c6b5f20637"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MatrixAlgebraKit = "6c742aac-3347-4629-af66-fc926824e5e4"
Expand Down Expand Up @@ -41,6 +42,7 @@ CUDA = "5.9"
ChainRulesCore = "1"
ChainRulesTestUtils = "1"
Combinatorics = "1"
Dictionaries = "0.4"
FiniteDifferences = "0.12"
GPUArrays = "11.3.1"
JET = "0.9, 0.10, 0.11"
Expand Down
2 changes: 2 additions & 0 deletions src/TensorKit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ const TO = TensorOperations

using MatrixAlgebraKit

using Dictionaries: Dictionaries, Dictionary, Indices, gettoken, gettokenvalue
using LRUCache
using OhMyThreads
using ScopedValues
Expand Down Expand Up @@ -218,6 +219,7 @@ end
# Definitions and methods for tensors
#-------------------------------------
# general definitions
include("tensors/tensorstructure.jl")
include("tensors/abstracttensor.jl")
include("tensors/backends.jl")
include("tensors/blockiterator.jl")
Expand Down
21 changes: 21 additions & 0 deletions src/auxiliary/dicts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -263,3 +263,24 @@ function Base.:(==)(d1::SortedVectorDict, d2::SortedVectorDict)
end
return true
end

"""
Hashed(value, hashfunction = Base.hash, isequal = Base.isequal)

Wrapper struct to alter the `hash` and `isequal` implementations of a given value.
This is useful in the contexts of dictionaries, where you either want to customize the hashfunction,
or consider various values as equal with a different notion of equality.
"""
struct Hashed{T, Hash, Eq}
val::T
end

Hashed(val, hash = Base.hash, eq = Base.isequal) = Hashed{typeof(val), hash, eq}(val)

Base.parent(h::Hashed) = h.val

# hash overload
Base.hash(h::Hashed{T, Hash}, seed::UInt) where {T, Hash} = Hash(parent(h), seed)

# isequal overload
Base.isequal(h1::H, h2::H) where {Eq, H <: Hashed{<:Any, <:Any, Eq}} = Eq(parent(h1), parent(h2))
148 changes: 7 additions & 141 deletions src/spaces/homspace.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,7 @@ spacetype(::Type{<:HomSpace{S}}) where {S} = S

const TensorSpace{S <: ElementarySpace} = Union{S, ProductSpace{S}}
const TensorMapSpace{S <: ElementarySpace, N₁, N₂} = HomSpace{
S, ProductSpace{S, N₁},
ProductSpace{S, N₂},
S, ProductSpace{S, N₁}, ProductSpace{S, N₂},
}

numout(::Type{TensorMapSpace{S, N₁, N₂}}) where {S, N₁, N₂} = N₁
Expand All @@ -62,17 +61,12 @@ end
→(dom::VectorSpace, codom::VectorSpace) = ←(codom, dom)

function Base.show(io::IO, W::HomSpace)
if length(W.codomain) == 1
print(io, W.codomain[1])
else
print(io, W.codomain)
end
print(io, " ← ")
return if length(W.domain) == 1
print(io, W.domain[1])
else
print(io, W.domain)
end
return print(
io,
numout(W) == 1 ? codomain(W)[1] : codomain(W),
" ← ",
numin(W) == 1 ? domain(W)[1] : domain(W)
)
end

"""
Expand Down Expand Up @@ -131,12 +125,6 @@ end

dims(W::HomSpace) = (dims(codomain(W))..., dims(domain(W))...)

"""
fusiontrees(W::HomSpace)

Return the fusiontrees corresponding to all valid fusion channels of a given `HomSpace`.
"""
fusiontrees(W::HomSpace) = fusionblockstructure(W).fusiontreelist

# Operations on HomSpaces
# -----------------------
Expand Down Expand Up @@ -290,125 +278,3 @@ function removeunit(P::HomSpace, ::Val{i}) where {i}
return codomain(P) ← removeunit(domain(P), Val(i - numout(P)))
end
end

# Block and fusion tree ranges: structure information for building tensors
#--------------------------------------------------------------------------

# sizes, strides, offset
const StridedStructure{N} = Tuple{NTuple{N, Int}, NTuple{N, Int}, Int}

struct FusionBlockStructure{I, N, F₁, F₂}
totaldim::Int
blockstructure::SectorDict{I, Tuple{Tuple{Int, Int}, UnitRange{Int}}}
fusiontreelist::Vector{Tuple{F₁, F₂}}
fusiontreestructure::Vector{StridedStructure{N}}
fusiontreeindices::FusionTreeDict{Tuple{F₁, F₂}, Int}
end

function fusionblockstructuretype(W::HomSpace)
N₁ = length(codomain(W))
N₂ = length(domain(W))
N = N₁ + N₂
I = sectortype(W)
F₁ = fusiontreetype(I, N₁)
F₂ = fusiontreetype(I, N₂)
return FusionBlockStructure{I, N, F₁, F₂}
end

@cached function fusionblockstructure(W::HomSpace)::fusionblockstructuretype(W)
codom = codomain(W)
dom = domain(W)
N₁ = length(codom)
N₂ = length(dom)
I = sectortype(W)
F₁ = fusiontreetype(I, N₁)
F₂ = fusiontreetype(I, N₂)

# output structure
blockstructure = SectorDict{I, Tuple{Tuple{Int, Int}, UnitRange{Int}}}() # size, range
fusiontreelist = Vector{Tuple{F₁, F₂}}()
fusiontreestructure = Vector{Tuple{NTuple{N₁ + N₂, Int}, NTuple{N₁ + N₂, Int}, Int}}() # size, strides, offset

# temporary data structures
splittingtrees = Vector{F₁}()
splittingstructure = Vector{Tuple{Int, Int}}()

# main computational routine
blockoffset = 0
for c in blocksectors(W)
empty!(splittingtrees)
empty!(splittingstructure)

offset₁ = 0
for f₁ in fusiontrees(codom, c)
push!(splittingtrees, f₁)
d₁ = dim(codom, f₁.uncoupled)
push!(splittingstructure, (offset₁, d₁))
offset₁ += d₁
end
blockdim₁ = offset₁
strides = (1, blockdim₁)

offset₂ = 0
for f₂ in fusiontrees(dom, c)
s₂ = f₂.uncoupled
d₂ = dim(dom, s₂)
for (f₁, (offset₁, d₁)) in zip(splittingtrees, splittingstructure)
push!(fusiontreelist, (f₁, f₂))
totaloffset = blockoffset + offset₂ * blockdim₁ + offset₁
subsz = (dims(codom, f₁.uncoupled)..., dims(dom, f₂.uncoupled)...)
@assert !any(isequal(0), subsz)
substr = _subblock_strides(subsz, (d₁, d₂), strides)
push!(fusiontreestructure, (subsz, substr, totaloffset))
end
offset₂ += d₂
end
blockdim₂ = offset₂
blocksize = (blockdim₁, blockdim₂)
blocklength = blockdim₁ * blockdim₂
blockrange = (blockoffset + 1):(blockoffset + blocklength)
blockoffset = last(blockrange)
blockstructure[c] = (blocksize, blockrange)
end

fusiontreeindices = sizehint!(
FusionTreeDict{Tuple{F₁, F₂}, Int}(), length(fusiontreelist)
)
for (i, f₁₂) in enumerate(fusiontreelist)
fusiontreeindices[f₁₂] = i
end
totaldim = blockoffset
structure = FusionBlockStructure(
totaldim, blockstructure, fusiontreelist, fusiontreestructure, fusiontreeindices
)
return structure
end

function _subblock_strides(subsz, sz, str)
sz_simplify = Strided.StridedViews._simplifydims(sz, str)
strides = Strided.StridedViews._computereshapestrides(subsz, sz_simplify...)
isnothing(strides) &&
throw(ArgumentError("unexpected error in computing subblock strides"))
return strides
end

function CacheStyle(::typeof(fusionblockstructure), W::HomSpace)
return GlobalLRUCache()
end

# Diagonal ranges
#----------------
# TODO: is this something we want to cache?
function diagonalblockstructure(W::HomSpace)
((numin(W) == numout(W) == 1) && domain(W) == codomain(W)) ||
throw(SpaceMismatch("Diagonal only support on V←V with a single space V"))
structure = SectorDict{sectortype(W), UnitRange{Int}}() # range
offset = 0
dom = domain(W)[1]
for c in blocksectors(W)
d = dim(dom, c)
structure[c] = offset .+ (1:d)
offset += d
end
return structure
end
2 changes: 1 addition & 1 deletion src/tensors/abstracttensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,7 @@ hasblock(t::AbstractTensorMap, c::Sector) = c ∈ blocksectors(t)

Return an iterator over all splitting - fusion tree pairs of a tensor.
"""
fusiontrees(t::AbstractTensorMap) = fusionblockstructure(t).fusiontreelist
fusiontrees(t::AbstractTensorMap) = fusiontrees(space(t))

fusiontreetype(t::AbstractTensorMap) = fusiontreetype(typeof(t))
function fusiontreetype(::Type{T}) where {T <: AbstractTensorMap}
Expand Down
3 changes: 1 addition & 2 deletions src/tensors/braidingtensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,7 @@ function block(b::BraidingTensor, s::Sector)
structure = fusionblockstructure(b)
base_offset = first(structure.blockstructure[s][2]) - 1

for ((f1, f2), (sz, str, off)) in
zip(structure.fusiontreelist, structure.fusiontreestructure)
for ((f1, f2), (sz, str, off)) in pairs(fusiontreestructure(space(b)))
if (f1.uncoupled != reverse(f2.uncoupled)) || !(f1.coupled == f2.coupled == s)
continue
end
Expand Down
10 changes: 4 additions & 6 deletions src/tensors/tensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -487,13 +487,11 @@ end
function subblock(
t::TensorMap{T, S, N₁, N₂}, (f₁, f₂)::Tuple{FusionTree{I, N₁}, FusionTree{I, N₂}}
) where {T, S, N₁, N₂, I <: Sector}
structure = fusionblockstructure(t)
@boundscheck begin
haskey(structure.fusiontreeindices, (f₁, f₂)) || throw(SectorMismatch())
end
fts = fusiontreestructure(space(t))
found, token = gettoken(fts, (f₁, f₂))
@boundscheck found || throw(SectorMismatch(lazy"fusion tree pair ($(f₁, f₂)) is not present"))
@inbounds begin
i = structure.fusiontreeindices[(f₁, f₂)]
sz, str, offset = structure.fusiontreestructure[i]
sz, str, offset = gettokenvalue(fts, token)
return StridedView(t.data, sz, str, offset)
end
end
Expand Down
Loading
Loading