Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "TensorAlgebra"
uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
version = "0.7.19"
version = "0.7.20"
authors = ["ITensor developers <support@itensor.org> and contributors"]

[workspace]
Expand Down
70 changes: 70 additions & 0 deletions src/lazyarrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,13 @@ lazy_function(::typeof(*)) = *ₗ
# Translate an eager elementwise function into its lazy counterpart, so a
# linear broadcast expression can be rebuilt from the lazy wrapper operations.
function lazy_function(::typeof(/))
    return /ₗ
end
function lazy_function(::typeof(\))
    return \ₗ
end
function lazy_function(::typeof(conj))
    return conjed
end
# `identity` needs no lazy wrapper: it already returns its argument unchanged.
function lazy_function(::typeof(identity))
    return identity
end
# Partial applications of scalar multiplication/division keep the bound scalar
# (`f.x`) and swap in the lazy operator in the same argument position.
function lazy_function(f::Base.Fix1{typeof(*), <:Number})
    return Base.Fix1(*ₗ, f.x)
end
function lazy_function(f::Base.Fix2{typeof(*), <:Number})
    return Base.Fix2(*ₗ, f.x)
end
function lazy_function(f::Base.Fix2{typeof(/), <:Number})
    return Base.Fix2(/ₗ, f.x)
end

# Decide whether broadcasting `f` over `args` is a linear operation.
# The fallback is conservative: anything unrecognized is assumed nonlinear.
function broadcast_is_linear(f, args...)
    return false
end
# `identity`, elementwise `+`, and unary `-` are linear in their array
# (or nested broadcast) arguments.
function broadcast_is_linear(::typeof(identity), ::Base.AbstractArrayOrBroadcasted)
    return true
end
function broadcast_is_linear(::typeof(+), ::Base.AbstractArrayOrBroadcasted...)
    return true
end
function broadcast_is_linear(::typeof(-), ::Base.AbstractArrayOrBroadcasted)
    return true
end
function broadcast_is_linear(
Expand All @@ -50,13 +55,41 @@ function broadcast_is_linear(
end
# Multiplying two scalars inside a broadcast is linear in each factor.
broadcast_is_linear(::typeof(*), ::Number, ::Number) = true
# Elementwise `conj` is accepted by the lazy pipeline (note: over complex
# scalars conjugation is antilinear, but it is supported here regardless).
broadcast_is_linear(::typeof(conj), ::Base.AbstractArrayOrBroadcasted) = true
# Scalar-bound partial applications of `*` and `/` act as scalar rescalings
# of the array argument, which are linear maps.
broadcast_is_linear(::Base.Fix1{typeof(*), <:Number}, ::Base.AbstractArrayOrBroadcasted) = true
broadcast_is_linear(::Base.Fix2{typeof(*), <:Number}, ::Base.AbstractArrayOrBroadcasted) = true
broadcast_is_linear(::Base.Fix2{typeof(/), <:Number}, ::Base.AbstractArrayOrBroadcasted) = true
# Anything that is not a `Broadcasted` tree is a leaf, and leaves are
# considered linear by definition.
is_linear(x) = true
# A broadcast tree is linear when its root operation is linear in its
# arguments and every argument subtree is itself linear.
function is_linear(bc::BC.Broadcasted)
    broadcast_is_linear(bc.f, bc.args...) || return false
    return all(is_linear, bc.args)
end

# Rebuild a (pre-validated) linear broadcast tree as nested calls to the
# lazy operations. Leaves pass through unchanged.
to_linear(x) = x
function to_linear(bc::BC.Broadcasted)
    lazy_args = map(to_linear, bc.args)
    return lazy_function(bc.f)(lazy_args...)
end
# Raise a descriptive `ArgumentError` when a broadcast cannot be lowered to
# the lazy linear operations.
function broadcast_error(style, f)
    msg = "Only linear broadcast operations are supported for `$style`, got `$f`."
    throw(ArgumentError(msg))
end
# Build a `Broadcasted` from `f` and `args`, verify the whole tree is linear,
# and lower it to the lazy wrapper operations. Throws `ArgumentError` for
# anything that cannot be proven linear.
function broadcasted_linear(style::BC.BroadcastStyle, f, args...)
    bc = BC.Broadcasted(style, f, args)
    if !is_linear(bc)
        broadcast_error(style, f)
    end
    return to_linear(bc)
end
# Convenience method: derive the broadcast style from the arguments.
function broadcasted_linear(f, args...)
    style = BC.combine_styles(args...)
    return broadcasted_linear(style, f, args...)
end
# TODO: Use `Broadcast.broadcastable` interface for this?
to_broadcasted(x) = x
function to_broadcasted(a::AbstractArray)
Expand Down Expand Up @@ -136,6 +169,7 @@ similar_scaled(a::AbstractArray) = similar(unscaled(a))
# `similar` for a lazily scaled array allocates from the unscaled parent, so
# the result is an ordinary (non-lazy) array.
function similar_scaled(a::AbstractArray, elt::Type)
    return similar(unscaled(a), elt)
end
function similar_scaled(a::AbstractArray, ax)
    return similar(unscaled(a), ax)
end
function similar_scaled(a::AbstractArray, elt::Type, ax)
    return similar(unscaled(a), elt, ax)
end
# Entry access applies the scalar coefficient to the parent's entry on the fly.
getindex_scaled(a::AbstractArray, I...) = coeff(a) * unscaled(a)[I...]
# NOTE(review): presumably `add!(dest, src, true, false)` overwrites `dest`
# with `src` — confirm against the `add!` contract.
copyto!_scaled(dest::AbstractArray, src::AbstractArray) = add!(dest, src, true, false)
# Display falls through to the shared lazy-array printing.
show_scaled(io::IO, a::AbstractArray) = show_lazy(io, a)
show_scaled(io::IO, mime::MIME"text/plain", a::AbstractArray) = show_lazy(io, mime, a)
Expand Down Expand Up @@ -227,6 +261,9 @@ macro scaledarray_base(ScaledArray, AbstractArray = :AbstractArray)
function Base.similar(a::$ScaledArray, elt::Type, ax::Dims)
return $TensorAlgebra.similar_scaled(a, elt, ax)
end
Base.@propagate_inbounds function Base.getindex(a::$ScaledArray, I...)
return $TensorAlgebra.getindex_scaled(a, I...)
end
function Base.copyto!(dest::$AbstractArray, src::$ScaledArray)
return $TensorAlgebra.copyto!_scaled(dest, src)
end
Expand Down Expand Up @@ -372,6 +409,7 @@ size_conj(a::AbstractArray) = size(conjed(a))
# `similar` for a lazily conjugated array allocates from the un-conjugated
# parent, dropping the lazy wrapper.
function similar_conj(a::AbstractArray, elt::Type)
    return similar(conjed(a), elt)
end
function similar_conj(a::AbstractArray, elt::Type, ax)
    return similar(conjed(a), elt, ax)
end
function similar_conj(a::AbstractArray, ax)
    return similar(conjed(a), ax)
end
# Entry access conjugates the parent's entry on the fly.
getindex_conj(a::AbstractArray, I...) = conj(conjed(a)[I...])
# NOTE(review): presumably `add!(dest, src, true, false)` overwrites `dest`
# with `src` — confirm against the `add!` contract.
copyto!_conj(dest::AbstractArray, src::AbstractArray) = add!(dest, src, true, false)
# Display falls through to the shared lazy-array printing.
show_conj(io::IO, a::AbstractArray) = show_lazy(io, a)
show_conj(io::IO, mime::MIME"text/plain", a::AbstractArray) = show_lazy(io, mime, a)
Expand Down Expand Up @@ -424,6 +462,9 @@ macro conjarray_base(ConjArray, AbstractArray = :AbstractArray)
function Base.similar(a::$ConjArray, elt::Type, ax::Dims)
return $TensorAlgebra.similar_conj(a, elt, ax)
end
Base.@propagate_inbounds function Base.getindex(a::$ConjArray, I...)
return $TensorAlgebra.getindex_conj(a, I...)
end
function Base.copyto!(dest::$AbstractArray, src::$ConjArray)
return $TensorAlgebra.copyto!_conj(dest, src)
end
Expand Down Expand Up @@ -525,6 +566,7 @@ similar_add(a::AbstractArray, elt::Type) = similar(BC.Broadcasted(+, addends(a))
# `similar` for a lazy sum: let broadcast machinery combine the addends to
# pick a suitable output type for the given element type and axes.
function similar_add(a::AbstractArray, elt::Type, ax)
    return similar(BC.Broadcasted(+, addends(a)), elt, ax)
end
# Entry access sums the corresponding entries of all addends.
function getindex_add(a::AbstractArray, I...)
    return sum(x -> x[I...], addends(a))
end
# NOTE(review): presumably `add!(dest, src, true, false)` overwrites `dest`
# with `src` — confirm against the `add!` contract.
copyto!_add(dest::AbstractArray, src::AbstractArray) = add!(dest, src, true, false)
# Display falls through to the shared lazy-array printing.
show_add(io::IO, a::AbstractArray) = show_lazy(io, a)
show_add(io::IO, mime::MIME"text/plain", a::AbstractArray) = show_lazy(io, mime, a)
Expand Down Expand Up @@ -611,6 +653,9 @@ macro addarray_base(AddArray, AbstractArray = :AbstractArray)
function Base.similar(a::$AddArray, elt::Type, ax)
return $TensorAlgebra.similar_add(a, elt, ax)
end
Base.@propagate_inbounds function Base.getindex(a::$AddArray, I...)
return $TensorAlgebra.getindex_add(a, I...)
end
function Base.copyto!(dest::$AbstractArray, src::$AddArray)
return $TensorAlgebra.copyto!_add(dest, src)
end
Expand Down Expand Up @@ -741,6 +786,20 @@ similar_mul(a::AbstractArray, elt::Type) = similar(a, elt, axes(a))
# TODO: Make use of both arguments to determine the output, maybe
# using `LinearAlgebra.matprod_dest(factors(a)..., elt)`?
similar_mul(a::AbstractArray, elt::Type, ax) = similar(last(factors(a)), elt, ax)
# Scalar entry of a lazy product: contract row `i` of `a1` with column `j`
# of `a2` via views, without materializing the full product. `transpose`
# (not `adjoint`) is used so no conjugation is introduced.
function mul_getindex(a1::AbstractMatrix, a2::AbstractMatrix, i::Int, j::Int)
    row = @view a1[i, :]
    col = @view a2[:, j]
    return transpose(row) * col
end
# Matrix–vector product entry: row `i` of `a1` against the vector `a2`.
function mul_getindex(a1::AbstractMatrix, a2::AbstractVector, i::Int)
    row = @view a1[i, :]
    return transpose(row) * a2
end
# Vector–matrix product entry: the vector `a1` against column `j` of `a2`.
function mul_getindex(a1::AbstractVector, a2::AbstractMatrix, j::Int)
    col = @view a2[:, j]
    return transpose(a1) * col
end
# Linear (single-integer) indexing: convert to a Cartesian index over the
# array's axes and forward to the multi-index method.
function getindex_mul(a::AbstractArray, i::Int)
    cart = Tuple(CartesianIndices(axes(a))[i])
    return getindex_mul(a, cart...)
end
# Multi-index access dispatches on the stored factors of the lazy product.
function getindex_mul(a::AbstractArray, I::Vararg{Int})
    return mul_getindex(factors(a)..., I...)
end
# Materializing a lazy product copies through `add!`; display uses the shared
# lazy-array printing. NOTE(review): presumably `add!(dest, src, true, false)`
# overwrites `dest` with `src` — confirm against the `add!` contract.
copyto!_mul(dest::AbstractArray, src::AbstractArray) = add!(dest, src, true, false)
show_mul(io::IO, a::AbstractArray) = show_lazy(io, a)
show_mul(io::IO, mime::MIME"text/plain", a::AbstractArray) = show_lazy(io, mime, a)
Expand Down Expand Up @@ -798,6 +857,11 @@ macro mularray_type(MulArray, AbstractArray = :AbstractArray)
)
end

# Eagerly materialize a two-dimensional `PermutedDimsArray`. The trivial
# permutation simply copies the parent; the swap permutation is realized
# through `transpose`, presumably so wrapped lazy types can dispatch on their
# own `copy(::Transpose)` method.
# NOTE(review): for non-`Number` element types `transpose` recurses into the
# elements, unlike `permutedims` — confirm parents always hold scalars.
function copy_permuteddims(a::PermutedDimsArray{<:Any, 2, perm}) where {perm}
    b = parent(a)
    perm == (1, 2) && return copy(b)
    return copy(transpose(b))
end

macro mularray_base(MulArray, AbstractArray = :AbstractArray)
return esc(
quote
Expand All @@ -819,6 +883,9 @@ macro mularray_base(MulArray, AbstractArray = :AbstractArray)
function Base.similar(a::$MulArray, elt::Type, ax::Dims)
return $TensorAlgebra.similar_mul(a, elt, ax)
end
Base.@propagate_inbounds function Base.getindex(a::$MulArray, I...)
return $TensorAlgebra.getindex_mul(a, I...)
end
function Base.copyto!(dest::$AbstractArray, src::$MulArray)
return $TensorAlgebra.copyto!_mul(dest, src)
end
Expand Down Expand Up @@ -881,6 +948,9 @@ macro mularray_terminterface(MulArray, AbstractArray = :AbstractArray)
$TensorAlgebra.iscall(a::$MulArray) = $TensorAlgebra.iscall_mul(a)
$TensorAlgebra.operation(a::$MulArray) = $TensorAlgebra.operation_mul(a)
$TensorAlgebra.arguments(a::$MulArray) = $TensorAlgebra.arguments_mul(a)
function Base.copy(a::PermutedDimsArray{<:Any, 2, <:Any, <:Any, $MulArray})
return $TensorAlgebra.copy_permuteddims(a)
end
end
)
end
Expand Down
32 changes: 29 additions & 3 deletions test/test_lazy.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import FunctionImplementations as FI
using TensorAlgebra: TensorAlgebra as TA, *ₗ, +ₗ, conjed
using Test: @test, @test_broken, @testset
using Base.Broadcast: Broadcast as BC
using TensorAlgebra: TensorAlgebra as TA, *ₗ, +ₗ, /ₗ, conjed
using Test: @test, @test_broken, @test_throws, @testset

@testset "lazy arrays" begin
@testset "lazy array operations" begin
Expand Down Expand Up @@ -92,6 +93,31 @@ using Test: @test, @test_broken, @testset

x = FI.permuteddims(a *ₗ b, perm)
@test x ≡ PermutedDimsArray(a *ₗ b, perm)
@test_broken copy(x) ≈ permutedims(a * b, perm)
@test copy(x) ≈ permutedims(a * b, perm)
end
# Lowering of linear broadcast expressions: `broadcasted_linear` should map
# supported elementwise calls onto the lazy wrappers (`*ₗ`, `/ₗ`, `conjed`),
# both with and without an explicit broadcast style, and reject anything it
# cannot prove linear.
@testset "linear broadcast lowering" begin
a = randn(ComplexF64, 2, 2)
style = BC.DefaultArrayStyle{2}()

@test TA.broadcasted_linear(identity, a) ≡ a
@test TA.broadcasted_linear(Base.Fix1(*, 2), a) ≡ 2 *ₗ a
@test TA.broadcasted_linear(Base.Fix2(*, 2), a) ≡ a *ₗ 2
@test TA.broadcasted_linear(Base.Fix2(/, 2), a) ≡ a /ₗ 2
@test TA.broadcasted_linear(style, identity, a) ≡ a
@test TA.broadcasted_linear(style, Base.Fix1(*, 2), a) ≡ 2 *ₗ a
@test TA.broadcasted_linear(style, Base.Fix2(*, 2), a) ≡ a *ₗ 2
@test TA.broadcasted_linear(style, Base.Fix2(/, 2), a) ≡ a /ₗ 2
@test TA.broadcasted_linear(style, conj, a) ≡ conjed(a)
# Nonlinear functions must raise rather than be lowered incorrectly.
@test_throws ArgumentError TA.broadcasted_linear(style, exp, a)
end
# Scalar `getindex` on the lazy wrappers should agree with the eagerly
# computed result (exactly where no floating-point reassociation occurs,
# approximately for the matrix product).
@testset "scalar getindex" begin
a = randn(ComplexF64, 2, 2)
b = randn(ComplexF64, 2, 2)

@test (2 *ₗ a)[1, 2] == 2 * a[1, 2]
@test conjed(a)[2, 1] == conj(a[2, 1])
@test (a +ₗ b)[2, 2] == a[2, 2] + b[2, 2]
@test (a *ₗ b)[1, 2] ≈ (a * b)[1, 2]
# Linear (single-integer) indexing path.
@test (a *ₗ b)[3] ≈ (a * b)[3]
end
end
Loading