Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
b4ae4d7
Working BP Commit
JoeyT1994 Oct 2, 2025
d77d063
BP Code
JoeyT1994 Oct 23, 2025
b80e36e
Express BP in terms of `SweepIterator` interface
jack-dunham Oct 28, 2025
fe44b80
Add method for `setmessages!` that allows messages from one cache to …
jack-dunham Oct 31, 2025
3ce0898
Network is now passed to `forest_cover_edge_sequence` directly.
jack-dunham Nov 10, 2025
f6e4fd0
test file formatting
jack-dunham Nov 25, 2025
63840a9
Add `DataGraphsPartitionedGraphsExt` glue for `TensorNetwork` type
jack-dunham Nov 25, 2025
ba22ab5
Make abstract tensor network interface more generic.
jack-dunham Nov 25, 2025
49b0870
BP Caching overhauls
jack-dunham Nov 25, 2025
db46c04
Remove dead deps
jack-dunham Nov 25, 2025
400e373
Fix merge
jack-dunham Nov 25, 2025
b9aafe8
Fix type inference in TensorNetwork construction
jack-dunham Nov 25, 2025
4090e61
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 25, 2025
be0750e
Remove `ITensorBase` dep
jack-dunham Nov 25, 2025
b971b89
`forest_cover_edge_sequence` now constructs a temporary `NamedGraph` …
jack-dunham Dec 1, 2025
9ebf031
[LazyNamedDimsArrays] Fix `parenttype` method
jack-dunham Jan 6, 2026
16fe303
BP Cache now uses new `DataGraphs`interface
jack-dunham Jan 6, 2026
24a4335
Adjust `default_message` to take a `message` type as its first argument
jack-dunham Jan 6, 2026
c43884e
Remove unnecessary code and fix ambiguities in `AbstractTensorNetwork`
jack-dunham Jan 6, 2026
dd6f645
`TensorNetwork` type now uses new DataGraphs interface
jack-dunham Jan 6, 2026
7bb579c
Sweeping algorithms based on AlgorithmsInterface.jl (#30)
mtfishman Dec 19, 2025
032447a
Upgrade to NamedDimsArrays.jl v0.11 (#38)
mtfishman Dec 23, 2025
b256d79
[LazyNamedDimsArrays] New `symnameddims` method that pulls out indice…
jack-dunham Jan 9, 2026
b2da9d8
The function `region_scalar` should now return a scalar, rather than …
jack-dunham Jan 9, 2026
8506e26
Fix double counting in `edge_scalars` function
jack-dunham Jan 9, 2026
938180a
Minor code formatting
jack-dunham Jan 9, 2026
4461967
Expressed belief propagation in terms of AlgorithmsInterface
jack-dunham Jan 9, 2026
d68860a
Fixes to TensorNetwork construction from tensor list
jack-dunham Jan 9, 2026
2f5c783
Minor simplifications to `contract_network` interface.
jack-dunham Jan 9, 2026
bc32491
Merge branch 'main' into bp
jack-dunham Jan 9, 2026
36d168d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 9, 2026
88c70fd
Fix merge issue
jack-dunham Jan 9, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/ITensorNetworksNext.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,8 @@ include("contract_network.jl")
include("sweeping/utils.jl")
include("sweeping/eigenproblem.jl")

include("beliefpropagation/abstractbeliefpropagationcache.jl")
include("beliefpropagation/beliefpropagationcache.jl")
include("beliefpropagation/beliefpropagationproblem.jl")

end
2 changes: 1 addition & 1 deletion src/LazyNamedDimsArrays/lazynameddimsarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ using WrappedUnions: @wrapped
union::Union{A, Mul{LazyNamedDimsArray{T, A}}}
end

parenttype(::Type{LazyNamedDimsArray{<:Any, A}}) where {A} = A
parenttype(::Type{LazyNamedDimsArray{T, A}}) where {T, A} = A
parenttype(::Type{LazyNamedDimsArray{T}}) where {T} = AbstractNamedDimsArray{T}
parenttype(::Type{LazyNamedDimsArray}) = AbstractNamedDimsArray

Expand Down
3 changes: 3 additions & 0 deletions src/LazyNamedDimsArrays/symbolicnameddimsarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ const SymbolicNamedDimsArray{T, N, Parent <: SymbolicArray{T, N}, DimNames} =
function symnameddims(name, dims)
return lazy(nameddims(SymbolicArray(name, denamed.(dims)), dims))
end
function symnameddims(name, ndarray::AbstractNamedDimsArray)
return symnameddims(name, Tuple(inds(ndarray)))
end
symnameddims(name) = symnameddims(name, ())
using AbstractTrees: AbstractTrees
function AbstractTrees.printnode(io::IO, a::SymbolicNamedDimsArray)
Expand Down
103 changes: 49 additions & 54 deletions src/abstracttensornetwork.jl
Original file line number Diff line number Diff line change
@@ -1,27 +1,31 @@
using Adapt: Adapt, adapt, adapt_structure
using BackendSelection: @Algorithm_str, Algorithm
using DataGraphs: DataGraphs, AbstractDataGraph, edge_data, underlying_graph,
underlying_graph_type, vertex_data
underlying_graph_type, vertex_data, set_vertex_data!
using Dictionaries: Dictionary
using Graphs: Graphs, AbstractEdge, AbstractGraph, Graph, add_edge!, add_vertex!,
bfs_tree, center, dst, edges, edgetype, ne, neighbors, nv, rem_edge!, src, vertices
using LinearAlgebra: LinearAlgebra, factorize
using MacroTools: @capture
using NamedDimsArrays: dimnames, inds
using NamedGraphs: NamedGraphs, NamedGraph, not_implemented, steiner_tree
using NamedGraphs.GraphsExtensions: ⊔, directed_graph, incident_edges, rem_edges!,
rename_vertices, vertextype
using NamedGraphs.OrdinalIndexing: OrdinalSuffixedInteger
using NamedGraphs.GraphsExtensions:
⊔,
directed_graph,
incident_edges,
rem_edges!,
rename_vertices,
vertextype
using SplitApplyCombine: flatten
using NamedGraphs.SimilarType: similar_type

abstract type AbstractTensorNetwork{V, VD} <: AbstractDataGraph{V, VD, Nothing} end

function Graphs.rem_edge!(tn::AbstractTensorNetwork, e)
rem_edge!(underlying_graph(tn), e)
return tn
end
# Need to be careful about removing edges from tensor networks in case there is a bond
Graphs.rem_edge!(::AbstractTensorNetwork, edge) = not_implemented()

# TODO: Define a generic fallback for `AbstractDataGraph`?
DataGraphs.edge_data_eltype(::Type{<:AbstractTensorNetwork}) = error("No edge data")
DataGraphs.edge_data_eltype(::Type{<:AbstractTensorNetwork}) = not_implemented()

# Graphs.jl overloads
function Graphs.weights(graph::AbstractTensorNetwork)
Expand All @@ -36,7 +40,7 @@ function Graphs.weights(graph::AbstractTensorNetwork)
end

# Copy
Base.copy(tn::AbstractTensorNetwork) = error("Not implemented")
Base.copy(::AbstractTensorNetwork) = not_implemented()

# Iteration
Base.iterate(tn::AbstractTensorNetwork, args...) = iterate(vertex_data(tn), args...)
Expand All @@ -49,20 +53,7 @@ Base.eltype(tn::AbstractTensorNetwork) = eltype(vertex_data(tn))
# Overload if needed
Graphs.is_directed(::Type{<:AbstractTensorNetwork}) = false

# Derived interface, may need to be overloaded
function DataGraphs.underlying_graph_type(G::Type{<:AbstractTensorNetwork})
return underlying_graph_type(data_graph_type(G))
end

# AbstractDataGraphs overloads
function DataGraphs.vertex_data(graph::AbstractTensorNetwork, args...)
return error("Not implemented")
end
function DataGraphs.edge_data(graph::AbstractTensorNetwork, args...)
return error("Not implemented")
end

DataGraphs.underlying_graph(tn::AbstractTensorNetwork) = error("Not implemented")
DataGraphs.underlying_graph(::AbstractTensorNetwork) = not_implemented()
function NamedGraphs.vertex_positions(tn::AbstractTensorNetwork)
return NamedGraphs.vertex_positions(underlying_graph(tn))
end
Expand All @@ -81,49 +72,46 @@ function Adapt.adapt_structure(to, tn::AbstractTensorNetwork)
return map_vertex_data_preserve_graph(adapt(to), tn)
end

function linkinds(tn::AbstractTensorNetwork, edge::Pair)
return linkinds(tn, edgetype(tn)(edge))
end
function linkinds(tn::AbstractTensorNetwork, edge::AbstractEdge)
return inds(tn[src(edge)]) ∩ inds(tn[dst(edge)])
end
function linkaxes(tn::AbstractTensorNetwork, edge::Pair)
linkinds(tn::AbstractGraph, edge::Pair) = linkinds(tn, edgetype(tn)(edge))
linkinds(tn::AbstractGraph, edge::AbstractEdge) = inds(tn[src(edge)]) ∩ inds(tn[dst(edge)])

function linkaxes(tn::AbstractGraph, edge::Pair)
return linkaxes(tn, edgetype(tn)(edge))
end
function linkaxes(tn::AbstractTensorNetwork, edge::AbstractEdge)
function linkaxes(tn::AbstractGraph, edge::AbstractEdge)
return axes(tn[src(edge)]) ∩ axes(tn[dst(edge)])
end
function linknames(tn::AbstractTensorNetwork, edge::Pair)
function linknames(tn::AbstractGraph, edge::Pair)
return linknames(tn, edgetype(tn)(edge))
end
function linknames(tn::AbstractTensorNetwork, edge::AbstractEdge)
function linknames(tn::AbstractGraph, edge::AbstractEdge)
return dimnames(tn[src(edge)]) ∩ dimnames(tn[dst(edge)])
end

function siteinds(tn::AbstractTensorNetwork, v)
function siteinds(tn::AbstractGraph, v)
s = inds(tn[v])
for v′ in neighbors(tn, v)
s = setdiff(s, inds(tn[v′]))
end
return s
end
function siteaxes(tn::AbstractTensorNetwork, edge::AbstractEdge)
function siteaxes(tn::AbstractGraph, edge::AbstractEdge)
s = axes(tn[src(edge)]) ∩ axes(tn[dst(edge)])
for v′ in neighbors(tn, v)
s = setdiff(s, axes(tn[v′]))
end
return s
end
function sitenames(tn::AbstractTensorNetwork, edge::AbstractEdge)
function sitenames(tn::AbstractGraph, edge::AbstractEdge)
s = dimnames(tn[src(edge)]) ∩ dimnames(tn[dst(edge)])
for v′ in neighbors(tn, v)
s = setdiff(s, dimnames(tn[v′]))
end
return s
end

function setindex_preserve_graph!(tn::AbstractTensorNetwork, value, vertex)
vertex_data(tn)[vertex] = value
function setindex_preserve_graph!(tn::AbstractGraph, value, vertex)
set_vertex_data!(tn, value, vertex)
return tn
end

Expand Down Expand Up @@ -153,15 +141,15 @@ end

# Update the graph of the TensorNetwork `tn` to include
# edges that should exist based on the tensor connectivity.
function add_missing_edges!(tn::AbstractTensorNetwork)
function add_missing_edges!(tn::AbstractGraph)
foreach(v -> add_missing_edges!(tn, v), vertices(tn))
return tn
end

# Update the graph of the TensorNetwork `tn` to include
# edges that should be incident to the vertex `v`
# based on the tensor connectivity.
function add_missing_edges!(tn::AbstractTensorNetwork, v)
function add_missing_edges!(tn::AbstractGraph, v)
for v′ in vertices(tn)
if v ≠ v′
e = v => v′
Expand All @@ -175,13 +163,13 @@ end

# Fix the edges of the TensorNetwork `tn` to match
# the tensor connectivity.
function fix_edges!(tn::AbstractTensorNetwork)
function fix_edges!(tn::AbstractGraph)
foreach(v -> fix_edges!(tn, v), vertices(tn))
return tn
end
# Fix the edges of the TensorNetwork `tn` to match
# the tensor connectivity at vertex `v`.
function fix_edges!(tn::AbstractTensorNetwork, v)
function fix_edges!(tn::AbstractGraph, v)
rem_edges!(tn, incident_edges(tn, v))
add_missing_edges!(tn, v)
return tn
Expand Down Expand Up @@ -215,28 +203,20 @@ function Base.setindex!(tn::AbstractTensorNetwork, value, v)
fix_edges!(tn, v)
return tn
end
using NamedGraphs.OrdinalIndexing: OrdinalSuffixedInteger
# Fix ambiguity error.
function Base.setindex!(graph::AbstractTensorNetwork, value, vertex::OrdinalSuffixedInteger)
graph[vertices(graph)[vertex]] = value
return graph
end
# Fix ambiguity error.
function Base.setindex!(tn::AbstractTensorNetwork, value, edge::AbstractEdge)
return error("No edge data.")
end
# Fix ambiguity error.
function Base.setindex!(tn::AbstractTensorNetwork, value, edge::Pair)
return error("No edge data.")
end
using NamedGraphs.OrdinalIndexing: OrdinalSuffixedInteger
Base.setindex!(tn::AbstractTensorNetwork, value, edge::AbstractEdge) = not_implemented()
Base.setindex!(tn::AbstractTensorNetwork, value, edge::Pair) = not_implemented()
# Fix ambiguity error.
function Base.setindex!(
tn::AbstractTensorNetwork,
value,
edge::Pair{<:OrdinalSuffixedInteger, <:OrdinalSuffixedInteger},
)
return error("No edge data.")
return not_implemented()
end

function Base.show(io::IO, mime::MIME"text/plain", graph::AbstractTensorNetwork)
Expand All @@ -255,3 +235,18 @@ function Base.show(io::IO, mime::MIME"text/plain", graph::AbstractTensorNetwork)
end

Base.show(io::IO, graph::AbstractTensorNetwork) = show(io, MIME"text/plain"(), graph)

function Graphs.induced_subgraph(graph::AbstractTensorNetwork{V}, subvertices::Vector{V}) where {V}
return tensornetwork_induced_subgraph(graph, subvertices)
end

function tensornetwork_induced_subgraph(graph, subvertices)
underlying_subgraph, vlist = Graphs.induced_subgraph(underlying_graph(graph), subvertices)
subgraph = similar_type(graph)(underlying_subgraph)
for v in vertices(subgraph)
if isassigned(graph, v)
set!(vertex_data(subgraph), v, graph[v])
end
end
return subgraph, vlist
end
139 changes: 139 additions & 0 deletions src/beliefpropagation/abstractbeliefpropagationcache.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
using Graphs: AbstractGraph, AbstractEdge
using DataGraphs: AbstractDataGraph, edge_data, vertex_data, edge_data_eltype
using NamedGraphs.GraphsExtensions: boundary_edges
using NamedGraphs.PartitionedGraphs: QuotientView, QuotientEdge, parent

messages(bp_cache::AbstractGraph) = edge_data(bp_cache)
messages(bp_cache::AbstractGraph, edges) = [message(bp_cache, e) for e in edges]

function message(bp_cache::AbstractGraph, edge::AbstractEdge)
ms = messages(bp_cache)
return get!(ms, edge, default_message(bp_cache, edge))
end

deletemessage!(bp_cache::AbstractGraph, edge) = not_implemented()
function deletemessage!(bp_cache::AbstractDataGraph, edge)
ms = messages(bp_cache)
delete!(ms, edge)
return bp_cache
end

function deletemessages!(bp_cache::AbstractGraph, edges = edges(bp_cache))
for e in edges
deletemessage!(bp_cache, e)
end
return bp_cache
end

setmessage!(bp_cache::AbstractGraph, edge, message) = not_implemented()
function setmessage!(bp_cache::AbstractDataGraph, edge, message)
setindex!(bp_cache, message, edge)
return bp_cache
end
function setmessage!(bp_cache::QuotientView, edge, message)
setmessages!(parent(bp_cache), QuotientEdge(edge), message)
return bp_cache
end

function setmessages!(bp_cache::AbstractGraph, edge::QuotientEdge, message)
for e in edges(bp_cache, edge)
setmessage!(parent(bp_cache), e, message[e])
end
return bp_cache
end
function setmessages!(bpc_dst::AbstractGraph, bpc_src::AbstractGraph, edges)
for e in edges
setmessage!(bpc_dst, e, message(bpc_src, e))
end
return bpc_dst
end

factors(bpc::AbstractGraph) = vertex_data(bpc)
factors(bpc::AbstractGraph, vertices::Vector) = [factor(bpc, v) for v in vertices]
factors(bpc::AbstractGraph{V}, vertex::V) where {V} = factors(bpc, V[vertex])

factor(bpc::AbstractGraph, vertex) = factors(bpc)[vertex]

setfactor!(bpc::AbstractGraph, vertex, factor) = not_implemented()
function setfactor!(bpc::AbstractDataGraph, vertex, factor)
fs = factors(bpc)
setindex!(fs, vertex, factor)
return bpc
end

function region_scalar(bp_cache::AbstractGraph, edge::AbstractEdge)
return (message(bp_cache, edge) * message(bp_cache, reverse(edge)))[]
end

function region_scalar(bp_cache::AbstractGraph, vertex)

messages = incoming_messages(bp_cache, vertex)
state = factors(bp_cache, vertex)

return (reduce(*, messages) * reduce(*, state))[]
end

message_type(bpc::AbstractGraph) = message_type(typeof(bpc))
message_type(G::Type{<:AbstractGraph}) = eltype(Base.promote_op(messages, G))
message_type(type::Type{<:AbstractDataGraph}) = edge_data_eltype(type)

function vertex_scalars(bp_cache::AbstractGraph, vertices = vertices(bp_cache))
return map(v -> region_scalar(bp_cache, v), vertices)
end

function edge_scalars(bp_cache::AbstractGraph, edges = edges(undirected_graph(underlying_graph(bp_cache))))
return map(e -> region_scalar(bp_cache, e), edges)
end

function scalar_factors_quotient(bp_cache::AbstractGraph)
return vertex_scalars(bp_cache), edge_scalars(bp_cache)
end

function incoming_messages(bp_cache::AbstractGraph, vertices; ignore_edges = [])
b_edges = boundary_edges(bp_cache, [vertices;]; dir = :in)
b_edges = !isempty(ignore_edges) ? setdiff(b_edges, ignore_edges) : b_edges
return messages(bp_cache, b_edges)
end

default_messages(::AbstractGraph) = not_implemented()

#Adapt interface for changing device
map_messages(f, bp_cache, es = edges(bp_cache)) = map_messages!(f, copy(bp_cache), es)
function map_messages!(f, bp_cache, es = edges(bp_cache))
for e in es
setmessage!(bp_cache, e, f(message(bp_cache, e)))
end
return bp_cache
end

map_factors(f, bp_cache, vs = vertices(bp_cache)) = map_factors!(f, copy(bp_cache), vs)
function map_factors!(f, bp_cache, vs = vertices(bp_cache))
for v in vs
setfactor!(bp_cache, v, f(factor(bp_cache, v)))
end
return bp_cache
end

adapt_messages(to, bp_cache, es = edges(bp_cache)) = map_messages(adapt(to), bp_cache, es)
adapt_factors(to, bp_cache, vs = vertices(bp_cache)) = map_factors(adapt(to), bp_cache, vs)

abstract type AbstractBeliefPropagationCache{V, ED} <: AbstractDataGraph{V, Nothing, ED} end

function free_energy(bp_cache::AbstractBeliefPropagationCache)

numerator_terms, denominator_terms = scalar_factors_quotient(bp_cache)

if any(t -> real(t) < 0, numerator_terms)
numerator_terms = complex.(numerator_terms)
end
if any(t -> real(t) < 0, denominator_terms)
denominator_terms = complex.(denominator_terms)
end

if any(iszero, denominator_terms)
return -Inf
end

return sum(log.(numerator_terms)) - sum(log.((denominator_terms)))
end
partitionfunction(bp_cache::AbstractBeliefPropagationCache) = exp(free_energy(bp_cache))
Loading
Loading