Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/src/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ When releasing a new version, move the "Unreleased" changes to a new version sec

### Fixed

- Eigenvalue decompositions of diagonal inputs are sorted and have the same type as non-diagonal inputs ([#151](https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/151)

## [0.6.2](https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/compare/v0.6.1...v0.6.2) - 2026-01-08

### Added
Expand Down
42 changes: 28 additions & 14 deletions src/implementations/eig.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,23 +30,21 @@ end

function check_input(::typeof(eig_full!), A::AbstractMatrix, DV, ::DiagonalAlgorithm)
m, n = size(A)
@assert m == n && isdiag(A)
((m == n) && isdiag(A)) || throw(DimensionMismatch("diagonal input matrix expected"))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we print the actual dimensions here with a lazy string?

D, V = DV
@assert D isa Diagonal && V isa Diagonal
@assert D isa Diagonal && V isa AbstractMatrix
@check_size(D, (m, m))
@check_scalar(D, A, complex)
@check_size(V, (m, m))
# Diagonal doesn't need to promote to complex scalartype since we know it is diagonalizable
@check_scalar(D, A)
@check_scalar(V, A)
@check_scalar(V, A, complex)
return nothing
end
function check_input(::typeof(eig_vals!), A::AbstractMatrix, D, ::DiagonalAlgorithm)
m, n = size(A)
@assert m == n && isdiag(A)
((m == n) && isdiag(A)) || throw(DimensionMismatch("diagonal input matrix expected"))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above

@assert D isa AbstractVector
@check_size(D, (n,))
# Diagonal doesn't need to promote to complex scalartype since we know it is diagonalizable
@check_scalar(D, A)
@check_scalar(D, A, complex)
return nothing
end

Expand All @@ -70,10 +68,14 @@ function initialize_output(::Union{typeof(eig_trunc!), typeof(eig_trunc_no_error
end

function initialize_output(::typeof(eig_full!), A::Diagonal, ::DiagonalAlgorithm)
return A, similar(A)
T = eltype(A)
Tc = complex(T)
D = T <: Complex ? A : Diagonal(similar(A, Tc, size(A, 1)))
return D, similar(A, Tc, size(A))
end
function initialize_output(::typeof(eig_vals!), A::Diagonal, ::DiagonalAlgorithm)
return diagview(A)
T = eltype(A)
return T <: Complex ? diagview(A) : similar(A, complex(T), size(A, 1))
end

# Implementation
Expand Down Expand Up @@ -129,17 +131,29 @@ end

# Diagonal logic
# --------------
function eig_full!(A::Diagonal, (D, V)::Tuple{Diagonal, Diagonal}, alg::DiagonalAlgorithm)
check_input(eig_full!, A, (D, V), alg)
D === A || copy!(D, A)
one!(V)
eig_sortby(x::T) where {T <: Number} = T <: Complex ? (real(x), imag(x)) : x
function eig_full!(A::Diagonal, DV, alg::DiagonalAlgorithm)
check_input(eig_full!, A, DV, alg)
D, V = DV
diagA = diagview(A)
I = sortperm(diagA; by = eig_sortby)
if D === A
permute!(diagA, I)
else
diagview(D) .= view(diagA, I)
end
zero!(V)
n = size(A, 1)
I .+= (0:(n - 1)) .* n
V[I] .= Ref(one(eltype(V)))
return D, V
end

function eig_vals!(A::Diagonal, D::AbstractVector, alg::DiagonalAlgorithm)
check_input(eig_vals!, A, D, alg)
Ad = diagview(A)
D === Ad || copy!(D, Ad)
sort!(D; by = eig_sortby)
return D
end

Expand Down
4 changes: 2 additions & 2 deletions test/testsuite/eig.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ function test_eig_full(
return @testset "eig_full! $summary_str" begin
A = instantiate_matrix(T, sz)
Ac = deepcopy(A)
Tc = isa(A, Diagonal) ? eltype(T) : complex(eltype(T))
Tc = complex(eltype(T))
D, V = @testinferred eig_full(A)
@test eltype(D) == eltype(V) == Tc
@test A * V ≈ V * D
Expand All @@ -51,7 +51,7 @@ function test_eig_full_algs(
return @testset "eig_full! algorithm $alg $summary_str" for alg in algs
A = instantiate_matrix(T, sz)
Ac = deepcopy(A)
Tc = isa(A, Diagonal) ? eltype(T) : complex(eltype(T))
Tc = complex(eltype(T))
D, V = @testinferred eig_full(A; alg)
@test eltype(D) == eltype(V) == Tc
@test A * V ≈ V * D
Expand Down