Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/interface/orthnull.jl
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,9 @@ left_orth_alg(alg::LeftOrthAlgorithm) = alg
left_orth_alg(alg::QRAlgorithms) = LeftOrthViaQR(alg)
left_orth_alg(alg::PolarAlgorithms) = LeftOrthViaPolar(alg)
left_orth_alg(alg::SVDAlgorithms) = LeftOrthViaSVD(alg)
left_orth_alg(alg::DiagonalAlgorithm) = LeftOrthViaQR(alg)
left_orth_alg(alg::TruncatedAlgorithm{<:SVDAlgorithms}) = LeftOrthViaSVD(alg)
left_orth_alg(alg::TruncatedAlgorithm{DiagonalAlgorithm}) = LeftOrthViaSVD(alg)

"""
right_orth_alg(alg::AbstractAlgorithm) -> RightOrthAlgorithm
Expand Down Expand Up @@ -478,7 +480,9 @@ right_orth_alg(alg::RightOrthAlgorithm) = alg
right_orth_alg(alg::LQAlgorithms) = RightOrthViaLQ(alg)
right_orth_alg(alg::PolarAlgorithms) = RightOrthViaPolar(alg)
right_orth_alg(alg::SVDAlgorithms) = RightOrthViaSVD(alg)
right_orth_alg(alg::DiagonalAlgorithm) = RightOrthViaLQ(alg)
right_orth_alg(alg::TruncatedAlgorithm{<:SVDAlgorithms}) = RightOrthViaSVD(alg)
right_orth_alg(alg::TruncatedAlgorithm{DiagonalAlgorithm}) = RightOrthViaSVD(alg)

"""
left_null_alg(alg::AbstractAlgorithm) -> LeftNullAlgorithm
Expand Down
6 changes: 3 additions & 3 deletions test/orthnull.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,16 @@ for T in (BLASFloats..., GenericFloats...), n in (37, m, 63)
if T ∈ BLASFloats
if CUDA.functional()
TestSuite.test_orthnull(CuMatrix{T}, (m, n); test_nullity = false)
n == m && TestSuite.test_orthnull(Diagonal{T, CuVector{T}}, m; test_orthnull = false)
n == m && TestSuite.test_orthnull(Diagonal{T, CuVector{T}}, m)
end
if AMDGPU.functional()
TestSuite.test_orthnull(ROCMatrix{T}, (m, n); test_nullity = false)
n == m && TestSuite.test_orthnull(Diagonal{T, ROCVector{T}}, m; test_orthnull = false)
n == m && TestSuite.test_orthnull(Diagonal{T, ROCVector{T}}, m)
end
end
if !is_buildkite
TestSuite.test_orthnull(T, (m, n))
AT = Diagonal{T, Vector{T}}
TestSuite.test_orthnull(AT, m; test_orthnull = false)
TestSuite.test_orthnull(AT, m)
end
end
2 changes: 2 additions & 0 deletions test/testsuite/TestSuite.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,11 @@ is_pivoted(alg::MatrixAlgebraKit.LQViaTransposedQR) = is_pivoted(alg.qr_alg)
isleftcomplete(V, N) = V * V' + N * N' ≈ I
isleftcomplete(V::AnyCuMatrix, N::AnyCuMatrix) = isleftcomplete(collect(V), collect(N))
isleftcomplete(V::AnyROCMatrix, N::AnyROCMatrix) = isleftcomplete(collect(V), collect(N))
isleftcomplete(V::Diagonal{T, <:ROCVector{<:T}}, N::Diagonal{T, <:ROCVector{<:T}}) where {T} = isleftcomplete(collect(V), collect(N))
isrightcomplete(Vᴴ, Nᴴ) = Vᴴ' * Vᴴ + Nᴴ' * Nᴴ ≈ I
isrightcomplete(V::AnyCuMatrix, N::AnyCuMatrix) = isrightcomplete(collect(V), collect(N))
isrightcomplete(V::AnyROCMatrix, N::AnyROCMatrix) = isrightcomplete(collect(V), collect(N))
isrightcomplete(V::Diagonal{T, <:ROCVector{<:T}}, N::Diagonal{T, <:ROCVector{<:T}}) where {T} = isrightcomplete(collect(V), collect(N))

instantiate_unitary(T, A, sz) = qr_compact(randn!(similar(A, eltype(T), sz, sz)))[1]
# AMDGPU can't generate ComplexF32 random numbers
Expand Down
8 changes: 4 additions & 4 deletions test/testsuite/orthnull.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@ _right_orth_lq!(x, CVᴴ; kwargs...) = right_orth!(x, CVᴴ; alg = :lq, kwargs..
_right_orth_polar(x; kwargs...) = right_orth(x; alg = :polar, kwargs...)
_right_orth_polar!(x, CVᴴ; kwargs...) = right_orth!(x, CVᴴ; alg = :polar, kwargs...)

function test_orthnull(T::Type, sz; test_nullity = true, test_orthnull = true, kwargs...)
function test_orthnull(T::Type, sz; test_nullity = true, kwargs...)
summary_str = testargs_summary(T, sz)
return @testset "orthnull $summary_str" begin
test_orthnull && test_left_orthnull(T, sz; kwargs...)
test_left_orthnull(T, sz; kwargs...)
test_nullity && test_left_nullity(T, sz; kwargs...)
test_orthnull && test_right_orthnull(T, sz; kwargs...)
test_right_orthnull(T, sz; kwargs...)
test_nullity && test_right_nullity(T, sz; kwargs...)
end
end
Expand Down Expand Up @@ -276,9 +276,9 @@ function test_right_orthnull(

# passing an algorithm
C, Vᴴ = @testinferred right_orth(A; alg = MatrixAlgebraKit.default_lq_algorithm(A))
Nᴴ = @testinferred right_null(A; alg = :lq, positive = true)
@test C isa typeof(A) && size(C) == (m, minmn)
@test Vᴴ isa typeof(A) && size(Vᴴ) == (minmn, n)
Nᴴ = @testinferred right_null(A; alg = :lq, positive = true)
@test eltype(Nᴴ) == eltype(A) && size(Nᴴ) == (n - minmn, n)
@test C * Vᴴ ≈ A
@test isisometric(Vᴴ; side = :right)
Expand Down