Skip to content

Commit c5afadd

Browse files
committed
fix Ising example
1 parent ea40a3a commit c5afadd

12 files changed

Lines changed: 32 additions & 288 deletions

File tree

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
docs/build
2+
LocalPreferences.toml
23
Manifest.toml
34
.vscode
45
.DS_Store

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ authors = ["<Andreas Feuerpfeil|andreas.feuerpfeil@gmail.com>"]
66
[deps]
77
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
88
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
9+
MPIHelper = "0ff8cc61-2cc6-4ec9-9e2e-eaba226bed24"
910
MPIPreferences = "3da0fdf6-3ccc-4f1b-acd9-58baa6c99267"
1011
MPSKit = "bb1c41ca-d63c-52ed-829e-0820dda26502"
1112
MPSKitModels = "ca635005-6f8c-4cd1-b51d-8491250ef2ab"
@@ -19,6 +20,7 @@ VectorInterface = "409d34a3-91d5-4945-b6ec-7529ddf182d8"
1920
[compat]
2021
LinearAlgebra = "1.12.0"
2122
MPI = "0.20.23"
23+
MPIHelper = "1.0.0"
2224
MPIPreferences = "0.1.11"
2325
MPSKit = "0.13.8"
2426
MPSKitModels = "0.4.4"

examples/Ising.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,15 @@ H = heisenberg_XXX(symmetry, chain; J, spin);
1212
physical_space = SU2Space(1 => 1);
1313
virtual_space_inf = Rep[SU₂](1 // 2 => 16, 3 // 2 => 16, 5 // 2 => 8, 7 // 2 => 4);
1414
ψ₀_inf = InfiniteMPS([physical_space], [virtual_space_inf]);
15-
ψ_inf, envs_inf, delta_inf = find_groundstate(ψ₀_inf, 2*H; verbosity = 3);
16-
15+
ψ_inf, envs_inf, delta_inf = find_groundstate(ψ₀_inf, H; verbosity = 3);
1716

1817
using MPSKitParallel
19-
using MPSKitParallel: mpi_rank, mpi_size
2018
using MPI
2119
H_mpi = MPIOperator(H);
2220
MPI.Init()
21+
mpi_rank() = MPI.Comm_rank(MPI.COMM_WORLD)
22+
mpi_size() = MPI.Comm_size(MPI.COMM_WORLD)
2323
println("Hey, I am rank=$(mpi_rank()) out of $(mpi_size()) processes.")
2424
ψ_infmpi, envs_infmpi, delta_infmpi = find_groundstate(ψ₀_inf, H_mpi; verbosity = 3);
2525

26-
println("Hey, I am rank=$(mpi_rank()) out of $(mpi_size()) processes. abs(dot(ψ_inf, ψ_infmpi)) = $(abs(dot(ψ_inf, ψ_infmpi)))")
26+
println("Hey, I am rank=$(mpi_rank()) out of $(mpi_size()) processes. abs(dot(ψ_inf, ψ_infmpi)) = $(abs(dot(ψ_inf, ψ_infmpi)))")

src/MPIOperator/mpioperator.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
## This shallow struct is used to indicate that each LazyMIPOperator should be evaluated on each rank and the result is to be reduced across all ranks using MPI.Allreduce
2+
## This is the MPI-parallelized version of a linear operator
3+
## If one added the flexibilty of choosing the reduction, one could also parallelize over products of functions etc...
24
struct MPIOperator{O}
35
parent::O
46
function MPIOperator(parent::O) where {O}
@@ -15,7 +17,7 @@ end
1517

1618
function (Op::MPIOperator{O})(x::S) where {O,S}
1719
y_per_rank = parent(Op)(x)
18-
y = large_allreduce(y_per_rank, +, MPI.COMM_WORLD)
20+
y = MPIHelper.allreduce(y_per_rank, +, MPI.COMM_WORLD)
1921
return y
2022
end
2123

src/MPSKitParallel.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,18 @@ using MacroTools
1616
using LinearAlgebra
1717
using VectorInterface
1818

19+
using MPIHelper
20+
1921
import LinearAlgebra: norm
2022
import VectorInterface: scale
2123
import MPSKit: environments, AbstractMPSEnvironments, InfiniteEnvironments
2224
import MPSKit: C_hamiltonian, AC_hamiltonian, AC2_hamiltonian, C_projection, AC_projection, AC2_projection
2325
import MPSKit: exact_diagonalization
2426

27+
using MPSKit: IterativeSolver, VUMPSState, AbstractMPS, Multiline, eachsite, fixedpoint, regauge!, left_orth, left_orth!, transfer_leftenv!, transfer_rightenv!, svd_trunc!
28+
using MPSKit: AC, AC2, _transpose_front, _transpose_tail, _mul_tail, AC_hamiltonian, AC2_hamiltonian
29+
using MPSKit.DynamicTols: updatetol
30+
using Base.Threads: @spawn, @sync
2531

2632
include("includes.jl")
2733

src/algorithms/groundstate/idmrg.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@ function MPSKit._localupdate_sweep_idmrg!(ψ::AbstractMPS, H::MPIOperator, envs,
66
_, ψ.AC[pos] = fixedpoint(h, ψ.AC[pos], :SR, alg_eigsolve)
77
if pos == length(ψ)
88
# AC needed in next sweep
9-
ψ.AL[pos], ψ.C[pos] = mpi_left_orth(ψ.AC[pos])
9+
ψ.AL[pos], ψ.C[pos] = mpi_execute_on_root_and_bcast(left_orth,ψ.AC[pos])
1010
else
11-
ψ.AL[pos], ψ.C[pos] = mpi_left_orth!(ψ.AC[pos])
11+
ψ.AL[pos], ψ.C[pos] = mpi_execute_on_root_and_bcast(left_orth!,ψ.AC[pos])
1212
end
1313
transfer_leftenv!(envs, ψ, H, ψ, pos + 1)
1414
end
@@ -33,7 +33,7 @@ function MPSKit._localupdate_sweep_idmrg!(ψ::AbstractMPS, H::MPIOperator, envs,
3333
h_ac2 = AC2_hamiltonian(pos, ψ, H, ψ, envs)
3434
_, ac2′ = fixedpoint(h_ac2, ac2, :SR, alg_eigsolve)
3535

36-
al, c, ar = mpi_svd_trunc!(ac2′; trunc = alg.trscheme, alg = alg.alg_svd)
36+
al, c, ar = mpi_execute_on_root_and_bcast(svd_trunc!, ac2′; trunc = alg.trscheme, alg = alg.alg_svd)
3737
normalize!(c)
3838

3939
ψ.AL[pos] = al
@@ -52,7 +52,7 @@ function MPSKit._localupdate_sweep_idmrg!(ψ::AbstractMPS, H::MPIOperator, envs,
5252
h_ac2 = AC2_hamiltonian(0, ψ, H, ψ, envs)
5353
_, ac2′ = fixedpoint(h_ac2, ac2, :SR, alg_eigsolve)
5454

55-
al, c, ar = mpi_svd_trunc!(ac2′; trunc = alg.trscheme, alg = alg.alg_svd)
55+
al, c, ar = mpi_execute_on_root_and_bcast(svd_trunc!, ac2′; trunc = alg.trscheme, alg = alg.alg_svd)
5656
normalize!(c)
5757

5858
ψ.AL[end] = al
@@ -75,7 +75,7 @@ function MPSKit._localupdate_sweep_idmrg!(ψ::AbstractMPS, H::MPIOperator, envs,
7575
h_ac2 = AC2_hamiltonian(pos, ψ, H, ψ, envs)
7676
_, ac2′ = fixedpoint(h_ac2, ac2, :SR, alg_eigsolve)
7777

78-
al, c, ar = mpi_svd_trunc!(ac2′; trunc = alg.trscheme, alg = alg.alg_svd)
78+
al, c, ar = mpi_execute_on_root_and_bcast(svd_trunc!, ac2′; trunc = alg.trscheme, alg = alg.alg_svd)
7979
normalize!(c)
8080

8181
ψ.AL[pos] = al
@@ -94,7 +94,7 @@ function MPSKit._localupdate_sweep_idmrg!(ψ::AbstractMPS, H::MPIOperator, envs,
9494
ac2 = AC2(ψ, 0; kind = :ACAR)
9595
h_ac2 = AC2_hamiltonian(0, ψ, H, ψ, envs)
9696
_, ac2′ = fixedpoint(h_ac2, ac2, :SR, alg_eigsolve)
97-
al, c, ar = mpi_svd_trunc!(ac2′; trunc = alg.trscheme, alg = alg.alg_svd)
97+
al, c, ar = mpi_execute_on_root_and_bcast(svd_trunc!, ac2′; trunc = alg.trscheme, alg = alg.alg_svd)
9898
normalize!(c)
9999

100100
ψ.AL[end] = al

src/algorithms/groundstate/vumps.jl

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11

22
function MPSKit.localupdate_step!(
3-
it::IterativeSolver{<:VUMPS{S, MPIOperator{O}, E}}, state, scheduler = MPSKit.Defaults.scheduler[]
3+
it::IterativeSolver{<:VUMPS}, state::VUMPSState{S, MPIOperator{O}, E}, scheduler = MPSKit.Defaults.scheduler[]
44
) where {S, O, E}
55
alg_eigsolve = updatetol(it.alg_eigsolve, state.iter, state.ϵ)
66
alg_orth = MPSKit.Defaults.alg_qr()
@@ -32,7 +32,7 @@ function MPSKit._localupdate_vumps_step!(
3232
_, AC = fixedpoint(Hac, AC₀, which, alg_eigsolve)
3333
Hc = C_hamiltonian(site, mps, operator, mps, envs)
3434
_, C = fixedpoint(Hc, C₀, which, alg_eigsolve)
35-
return mpi_regauge!(AC, C; alg = alg_orth)
35+
return mpi_execute_on_root_and_bcast(regauge!, AC, C; alg = alg_orth)
3636
end
3737

3838
local AC, C
@@ -46,27 +46,28 @@ function MPSKit._localupdate_vumps_step!(
4646
_, C = fixedpoint(Hc, C₀, which, alg_eigsolve)
4747
end
4848
end
49-
return mpi_regauge!(AC, C; alg = alg_orth)
49+
return mpi_execute_on_root_and_bcast(regauge!, AC, C; alg = alg_orth)
5050
end
5151

52-
function MPSKit.gauge_step!(it::IterativeSolver{<:VUMPS{S, MPIOperator{O}, E}}, state, ACs::AbstractVector) where {S, O, E}
52+
function MPSKit.gauge_step!(it::IterativeSolver{<:VUMPS}, state::VUMPSState{S, MPIOperator{O}, E}, ACs::AbstractVector) where {S, O, E}
5353
alg_gauge = updatetol(it.alg_gauge, state.iter, state.ϵ)
54+
println("Gauging!")
5455
if mpi_is_root()
5556
psi = InfiniteMPS(ACs, state.mps.C[end]; alg_gauge.tol, alg_gauge.maxiter)
5657
else
5758
psi = nothing
5859
end
59-
psi = large_bcast(psi, 0, MPI.COMM_WORLD)
60+
psi = MPIHelper.bcast(psi, MPI.COMM_WORLD)
6061
return psi
6162
end
6263

63-
function MPSKit.gauge_step!(it::IterativeSolver{<:VUMPS{S, MPIOperator{O}, E}}, state, ACs::AbstractMatrix) where {S, O, E}
64+
function MPSKit.gauge_step!(it::IterativeSolver{<:VUMPS}, state::VUMPSState{S, MPIOperator{O}, E}, ACs::AbstractMatrix) where {S, O, E}
6465
alg_gauge = updatetol(it.alg_gauge, state.iter, state.ϵ)
6566
if mpi_is_root()
6667
psi = MultilineMPS(ACs, @view(state.mps.C[:, end]); alg_gauge.tol, alg_gauge.maxiter)
6768
else
6869
psi = nothing
6970
end
70-
psi = large_bcast(psi, 0, MPI.COMM_WORLD)
71+
psi = MPIHelper.bcast(psi, 0, MPI.COMM_WORLD)
7172
return psi
7273
end

src/includes.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
11
include("utility/forward.jl")
2-
include("multiprocessing/mpi/helper.jl")
3-
include("multiprocessing/mpi/mpi_buffers.jl")
42

53
include("MPIOperator/mpioperator.jl")
64
include("MPIOperator/derivatives.jl")
@@ -11,4 +9,5 @@ include("algorithms/expval.jl")
119
include("algorithms/ED.jl")
1210
include("algorithms/grassmann.jl")
1311

14-
include("SharedMPS/sharedmps.jl")
12+
include("algorithms/groundstate/vumps.jl")
13+
include("algorithms/groundstate/idmrg.jl")

src/multiprocessing/mpi/helper.jl

Lines changed: 0 additions & 48 deletions
This file was deleted.

0 commit comments

Comments
 (0)