Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 8 additions & 10 deletions src/implementations/eig.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@ copy_input(::Union{typeof(eig_trunc), typeof(eig_trunc_no_error)}, A) = copy_inp
copy_input(::typeof(eig_full), A::Diagonal) = copy(A)

function check_input(::typeof(eig_full!), A::AbstractMatrix, DV, ::AbstractAlgorithm)
m, n = size(A)
m == n || throw(DimensionMismatch("square input matrix expected"))
m = LinearAlgebra.checksquare(A)
D, V = DV
@assert D isa Diagonal && V isa AbstractMatrix
@check_size(D, (m, m))
Expand All @@ -20,17 +19,16 @@ function check_input(::typeof(eig_full!), A::AbstractMatrix, DV, ::AbstractAlgor
return nothing
end
function check_input(::typeof(eig_vals!), A::AbstractMatrix, D, ::AbstractAlgorithm)
m, n = size(A)
m == n || throw(DimensionMismatch("square input matrix expected"))
m = LinearAlgebra.checksquare(A)
@assert D isa AbstractVector
@check_size(D, (n,))
@check_size(D, (m,))
@check_scalar(D, A, complex)
return nothing
end

function check_input(::typeof(eig_full!), A::AbstractMatrix, DV, ::DiagonalAlgorithm)
m, n = size(A)
((m == n) && isdiag(A)) || throw(DimensionMismatch("diagonal input matrix expected"))
m = LinearAlgebra.checksquare(A)
isdiag(A) || throw(DimensionMismatch("diagonal input matrix expected"))
D, V = DV
@assert D isa Diagonal && V isa AbstractMatrix
@check_size(D, (m, m))
Expand All @@ -40,10 +38,10 @@ function check_input(::typeof(eig_full!), A::AbstractMatrix, DV, ::DiagonalAlgor
return nothing
end
function check_input(::typeof(eig_vals!), A::AbstractMatrix, D, ::DiagonalAlgorithm)
m, n = size(A)
((m == n) && isdiag(A)) || throw(DimensionMismatch("diagonal input matrix expected"))
m = LinearAlgebra.checksquare(A)
isdiag(A) || throw(DimensionMismatch("diagonal input matrix expected"))
@assert D isa AbstractVector
@check_size(D, (n,))
@check_size(D, (m,))
@check_scalar(D, A, complex)
return nothing
end
Expand Down
3 changes: 1 addition & 2 deletions src/implementations/eigh.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@ copy_input(::typeof(eigh_full), A::Diagonal) = copy(A)
check_hermitian(A, ::AbstractAlgorithm) = check_hermitian(A)
check_hermitian(A, alg::Algorithm) = check_hermitian(A; atol = get(alg.kwargs, :hermitian_tol, default_hermitian_tol(A)))
function check_hermitian(A; atol::Real = default_hermitian_tol(A), rtol::Real = 0)
m, n = size(A)
m == n || throw(DimensionMismatch("square input matrix expected"))
LinearAlgebra.checksquare(A)
ishermitian(A; atol, rtol) ||
throw(DomainError(A, "Hermitian matrix was expected. Use `project_hermitian` to project onto the nearest hermitian matrix."))
return nothing
Expand Down
19 changes: 7 additions & 12 deletions src/implementations/gen_eig.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,9 @@ end
copy_input(::typeof(gen_eig_vals), A, B) = copy_input(gen_eig_full, A, B)

function check_input(::typeof(gen_eig_full!), A::AbstractMatrix, B::AbstractMatrix, WV, ::AbstractAlgorithm)
ma, na = size(A)
mb, nb = size(B)
ma == na || throw(DimensionMismatch("square input matrix A expected"))
mb == nb || throw(DimensionMismatch("square input matrix B expected"))
ma == mb || throw(DimensionMismatch("first dimension of input matrices expected to match"))
na == nb || throw(DimensionMismatch("second dimension of input matrices expected to match"))
ma = LinearAlgebra.checksquare(A)
mb = LinearAlgebra.checksquare(B)
ma == mb || throw(DimensionMismatch(lazy"Expected matching input sizes, dimensions are $ma and $mb"))
W, V = WV
@assert W isa Diagonal && V isa AbstractMatrix
@check_size(W, (ma, ma))
Expand All @@ -23,13 +20,11 @@ function check_input(::typeof(gen_eig_full!), A::AbstractMatrix, B::AbstractMatr
return nothing
end
function check_input(::typeof(gen_eig_vals!), A::AbstractMatrix, B::AbstractMatrix, W, ::AbstractAlgorithm)
ma, na = size(A)
mb, nb = size(B)
ma == na || throw(DimensionMismatch("square input matrix A expected"))
mb == nb || throw(DimensionMismatch("square input matrix B expected"))
ma == mb || throw(DimensionMismatch("dimension of input matrices expected to match"))
ma = LinearAlgebra.checksquare(A)
mb = LinearAlgebra.checksquare(B)
ma == mb || throw(DimensionMismatch(lazy"Expected matching input sizes, dimensions are $ma and $mb"))
@assert W isa AbstractVector
@check_size(W, (na,))
@check_size(W, (ma,))
@check_scalar(W, A, complex)
@check_scalar(W, B, complex)
return nothing
Expand Down
10 changes: 4 additions & 6 deletions src/implementations/schur.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,21 @@ copy_input(::typeof(schur_vals), A) = copy_input(eig_vals, A)

# check input
function check_input(::typeof(schur_full!), A::AbstractMatrix, TZv, ::AbstractAlgorithm)
m, n = size(A)
m == n || throw(DimensionMismatch("square input matrix expected"))
m = LinearAlgebra.checksquare(A)
T, Z, vals = TZv
@assert T isa AbstractMatrix && Z isa AbstractMatrix && vals isa AbstractVector
@check_size(T, (m, m))
@check_scalar(T, A)
@check_size(Z, (m, m))
@check_scalar(Z, A)
@check_size(vals, (n,))
@check_size(vals, (m,))
@check_scalar(vals, A, complex)
return nothing
end
function check_input(::typeof(schur_vals!), A::AbstractMatrix, vals, ::AbstractAlgorithm)
m, n = size(A)
m == n || throw(DimensionMismatch("square input matrix expected"))
m = LinearAlgebra.checksquare(A)
@assert vals isa AbstractVector
@check_size(vals, (n,))
@check_size(vals, (m,))
@check_scalar(vals, A, complex)
return nothing
end
Expand Down