Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,15 @@ authors = ["Éric Thiébaut <eric.thiebaut@univ-lyon1.fr> and contributors"]

[deps]
InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"

[weakdeps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"

[extensions]
TypeUtilsLinearAlgebraExt = "LinearAlgebra"
TypeUtilsOffsetArraysExt = "OffsetArrays"
TypeUtilsUnitfulExt = "Unitful"

Expand All @@ -28,9 +29,10 @@ julia = "1"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"

[targets]
test = ["Aqua", "Test", "OffsetArrays", "Unitful"]
test = ["Aqua", "LinearAlgebra", "Test", "OffsetArrays", "Unitful"]
42 changes: 42 additions & 0 deletions ext/TypeUtilsLinearAlgebraExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
module TypeUtilsLinearAlgebraExt

if isdefined(Base, :get_extension)
using TypeUtils, LinearAlgebra
else
using ..TypeUtils, ..LinearAlgebra
end

# Convert element type for LinearAlgebra factorizations.
# `LinearAlgebra.Factorization{T}(A)` can be used to convert element type of `A` for QR,
# LinearAlgebra.QRCompactWY, QRPivoted, LQ, Cholesky, CholeskyPivoted, LU, LDLt,
# BunchKaufman, SVD, etc.
TypeUtils.convert_eltype(::Type{T}, A::Factorization{T}) where {T} = A
TypeUtils.convert_eltype(::Type{T}, A::Factorization) where {T} = Factorization{T}(A)
if VERSION < v"1.7.0-beta1"
# For old Julia versions, the above is not sufficient for SVD.
TypeUtils.convert_eltype(::Type{T}, A::SVD{T}) where {T} = A
TypeUtils.convert_eltype(::Type{T}, A::SVD) where {T} =
SVD(TypeUtils.convert_eltype(T, A.U), TypeUtils.convert_eltype(real(T), A.S), TypeUtils.convert_eltype(T, A.Vt))
end
TypeUtils.convert_eltype(::Type{T}, A::Hessenberg{T}) where {T} = A
TypeUtils.convert_eltype(::Type{T}, A::Hessenberg) where {T} = throw(
ArgumentError(
"changing element type of Hessenberg decomposition is not supported, consider re-computing the decomposition"
)
)

# For `Adjoint` and `Transpose`, we want to preserve this structure.
for S in (:Adjoint, :Transpose)
@eval begin
TypeUtils.convert_eltype(::Type{T}, A::$S{T}) where {T} = A
TypeUtils.convert_eltype(::Type{T}, A::$S) where {T} = $S(TypeUtils.convert_eltype(T, parent(A)))
end
end

# Get and adapt precision for LinearAlgebra factorizations.
TypeUtils.get_precision(::Type{<:Factorization{T}}) where {T} = TypeUtils.get_precision(T)
TypeUtils.adapt_precision(::Type{T}, A::Factorization{T}) where {T <: TypeUtils.Precision} = A
TypeUtils.adapt_precision(::Type{T}, A::Factorization{S}) where {T <: TypeUtils.Precision, S} =
TypeUtils.convert_eltype(TypeUtils.adapt_precision(T, S), A)

end # module
3 changes: 2 additions & 1 deletion src/TypeUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ include("macros.jl")
default_precision

using Base: OneTo
using LinearAlgebra
if !isdefined(Base, :get_extension)
using Requires
end
Expand All @@ -73,6 +72,8 @@ import .LazyMaps: LazyMap, lazymap
function __init__()
@static if !isdefined(Base, :get_extension)
# Extend methods when other packages are loaded.
@require LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" include(
"../ext/TypeUtilsLinearAlgebraExt.jl")
@require Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" include(
"../ext/TypeUtilsUnitfulExt.jl")
@require OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" include(
Expand Down
23 changes: 0 additions & 23 deletions src/arrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -165,29 +165,6 @@ convert_eltype(::Type{T}, A::AbstractArray) where {T} = AbstractArray{T}(A)
convert_eltype(::Type{T}, ::Type{<:Array{<:Any,N}}) where {T,N} = Array{T,N}
convert_eltype(::Type{T}, ::Type{<:AbstractArray{<:Any,N}}) where {T,N} = AbstractArray{T,N}

# `LinearAlgebra.Factorization{T}(A)` can be used to convert element type of `A` for QR,
# LinearAlgebra.QRCompactWY, QRPivoted, LQ, Cholesky, CholeskyPivoted, LU, LDLt,
# BunchKaufman, SVD, etc.
convert_eltype(::Type{T}, A::Factorization{T}) where {T} = A
convert_eltype(::Type{T}, A::Factorization) where {T} = Factorization{T}(A)
if VERSION < v"1.7.0-beta1"
# For old Julia versions, the above is not sufficient for SVD.
convert_eltype(::Type{T}, A::SVD{T}) where {T} = A
convert_eltype(::Type{T}, A::SVD) where {T} =
SVD(convert_eltype(T, A.U), convert_eltype(real(T), A.S), convert_eltype(T, A.Vt))
end
convert_eltype(::Type{T}, A::Hessenberg{T}) where {T} = A
convert_eltype(::Type{T}, A::Hessenberg) where {T} = throw(ArgumentError(
"changing element type of Hessenberg decomposition is not supported, consider re-computing the decomposition"))

# For `Adjoint` and `Transpose`, we want to preserve this structure.
for S in (:Adjoint, :Transpose)
@eval begin
convert_eltype(::Type{T}, A::$S{T}) where {T} = A
convert_eltype(::Type{T}, A::$S) where {T} = $S(convert_eltype(T, parent(A)))
end
end

# Convert element type for numbers.
convert_eltype(::Type{T}, ::Type{<:Number}) where {T} = T

Expand Down
6 changes: 0 additions & 6 deletions src/precision.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ get_precision(::Type) = AbstractFloat # pass-through
get_precision(::Type{T}) where {T<:Precision} = T
get_precision(::Type{<:Complex{T}}) where {T} = get_precision(T)
get_precision(::Type{<:AbstractArray{T}}) where {T} = get_precision(T)
get_precision(::Type{<:Factorization{T}}) where {T} = get_precision(T)

# Second type parameter of a named tuple is a tuple of types.
get_precision(::Type{NamedTuple{S,T}}) where {S,T} = get_precision(T)
Expand Down Expand Up @@ -118,11 +117,6 @@ adapt_precision(::Type{T}, A::AbstractArray{T}) where {T<:Precision} = A
adapt_precision(::Type{T}, A::AbstractArray{S}) where {T<:Precision,S} =
convert_eltype(adapt_precision(T, S), A)

# Adapt precision of factorizations.
adapt_precision(::Type{T}, A::Factorization{T}) where {T<:Precision} = A
adapt_precision(::Type{T}, A::Factorization{S}) where {T<:Precision,S} =
convert_eltype(adapt_precision(T, S), A)

# Set precision for tuples.
adapt_precision(::Type{T}, x::Union{Tuple,NamedTuple}) where {T<:Precision} =
map(adapt_precision(T), x)
Expand Down