Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 20 additions & 2 deletions src/structarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -452,18 +452,36 @@ function Base.sizehint!(s::StructArray, i::Integer)
return s
end

_cateltype(::Type{T}, newcols::Tup) where {T<:Tup} = eltypes(newcols)
_cateltype(::Type{T}, newcols::Tup) where {T} = T

function _reducecat_structarray(op, A::AbstractVector{<:StructArray})
isempty(A) && return Base.mapreduce_empty(eltype, promote_type, eltype(A))
cols = map(components, A)
firstcols = first(cols)
all(col -> keys(col) == keys(firstcols), cols) || throw(ArgumentError("StructArray columns must have matching keys."))
newcols = map(key -> reduce(op, map(Base.Fix2(getindex, key), cols)), keys(firstcols))
typedcols = strip_params(typeof(firstcols))(newcols)
T = _cateltype(mapreduce(eltype, promote_type, A), typedcols)
return StructArray{T}(typedcols)
end

for op in [:cat, :hcat, :vcat]
curried_op = Symbol(:curried, op)
@eval begin
function Base.$op(arg::StructArray, others::StructArray...; kwargs...)
$curried_op(A...) = $op(A...; kwargs...)
args = (arg, others...)
T = mapreduce(eltype, promote_type, args)
StructArray{T}(map($curried_op, map(components, args)...))
newcols = map($curried_op, map(components, args)...)
T = _cateltype(mapreduce(eltype, promote_type, args), newcols)
StructArray{T}(newcols)
end
end
end

Base.reduce(::typeof(vcat), A::AbstractVector{<:StructArray}) = _reducecat_structarray(vcat, A)
Base.reduce(::typeof(hcat), A::AbstractVector{<:StructArray}) = _reducecat_structarray(hcat, A)

Base.copy(s::StructArray{T}) where {T} = StructArray{T}(map(copy, components(s)))

for type in (
Expand Down
48 changes: 48 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -693,6 +693,54 @@ end
@test @inferred(vcat(t3)) == t3
@inferred vcat(t3, t3)
@inferred vcat(t3, collect(t3))
a = StructArray(y = Union{Missing, Int}[missing])
b = StructArray(y = [3])
c = StructArray(y = Union{Missing, Int}[4])
vcatted = vcat(a, b, c)
@test eltype(vcatted) === NamedTuple{(:y,), Tuple{Union{Missing, Int}}}
reduced_vcat = reduce(vcat, [a, b, c])
@test eltype(reduced_vcat) === eltype(vcatted)
@test isequal(reduced_vcat, vcatted)
@test reduced_vcat.y isa Vector{Union{Missing, Int}}
hcatted = hcat(reshape(a, 1, 1), reshape(b, 1, 1), reshape(c, 1, 1))
@test eltype(hcatted) === NamedTuple{(:y,), Tuple{Union{Missing, Int}}}
reduced_hcat = reduce(hcat, [reshape(a, 1, 1), reshape(b, 1, 1), reshape(c, 1, 1)])
@test eltype(reduced_hcat) === eltype(hcatted)
@test isequal(reduced_hcat, hcatted)
@test reduced_hcat.y isa Matrix{Union{Missing, Int}}

struct CatTestType{A, B}
a::A
b::B
end
custom_a = StructArray{CatTestType{Int, Missing}}((a = [1], b = Missing[missing]))
custom_b = StructArray{CatTestType{Int, Int}}((a = [2], b = [3]))
custom_vcat = vcat(custom_a, custom_b, custom_a)
@test custom_vcat == CatTestType{Int}[CatTestType(1, missing), CatTestType(2, 3), CatTestType(1, missing)]
@test custom_vcat.b isa Vector{Union{Missing, Int}}
reduced_custom_vcat = reduce(vcat, [custom_a, custom_b, custom_a])
@test isequal(reduced_custom_vcat, custom_vcat)
@test eltype(reduced_custom_vcat) === eltype(custom_vcat) === CatTestType{Int}
@test reduced_custom_vcat.b isa Vector{Union{Missing, Int}}

# error behavior is consistent between reduce(vcat) and vcat(), and is generally reasonable
mismatched_names_a = StructArray(a = [1], b = [2])
mismatched_names_b = StructArray(x = [3], y = [4])
@test_throws ArgumentError vcat(mismatched_names_a, mismatched_names_b)
@test_throws ArgumentError reduce(vcat, [mismatched_names_a, mismatched_names_b])
mixed_rowtype_a = StructArray(re = [1.0], im = [2.0])
mixed_rowtype_b = StructArray(ComplexF64[3 + 4im])
@test_throws ArgumentError vcat(mixed_rowtype_a, mixed_rowtype_b)
@test_throws ArgumentError reduce(vcat, [mixed_rowtype_a, mixed_rowtype_b])
different_names_a = StructArray(a = [1])
different_names_b = StructArray(x = [2], y = [3], z = [4])
@test_throws ArgumentError vcat(different_names_a, different_names_b)
@test_throws ArgumentError reduce(vcat, [different_names_a, different_names_b])
different_lengths_a = StructArray(([1], [2], [3]))
different_lengths_b = StructArray(([4], [5]))
@test_throws ArgumentError reduce(vcat, [different_lengths_a, different_lengths_b])
@test_throws ArgumentError reduce(hcat, [reshape(different_lengths_a, 1, 1), reshape(different_lengths_b, 1, 1)])

# Check that `cat(dims=1)` doesn't commit type piracy (#254)
# We only test that this works, the return value is immaterial
@test cat(dims=1) == vcat()
Expand Down
Loading