Skip to content

Commit

Permalink
Make return type of broadcast inferrable with heterogeneous arrays (J…
Browse files Browse the repository at this point in the history
…uliaLang#30485)

Inference is not able to detect the element type automatically, but we can do it manually
since we know promote_typejoin is used for widening.
  • Loading branch information
nalimilan authored Oct 27, 2020
1 parent 03c5eee commit 65c7bf5
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 4 deletions.
54 changes: 51 additions & 3 deletions base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ Module containing the broadcasting implementation.
module Broadcast

using .Base.Cartesian
using .Base: Indices, OneTo, tail, to_shape, isoperator, promote_typejoin,
using .Base: Indices, OneTo, tail, to_shape, isoperator, promote_typejoin, @pure,
_msk_end, unsafe_bitgetindex, bitcache_chunks, bitcache_size, dumpbitcache, unalias
import .Base: copy, copyto!, axes
export broadcast, broadcast!, BroadcastStyle, broadcast_axes, broadcastable, dotview, @__dot__, broadcast_preserving_zero_d, BroadcastFunction
Expand Down Expand Up @@ -691,8 +691,52 @@ eltypes(t::Tuple{Any}) = Tuple{_broadcast_getindex_eltype(t[1])}
eltypes(t::Tuple{Any,Any}) = Tuple{_broadcast_getindex_eltype(t[1]), _broadcast_getindex_eltype(t[2])}
eltypes(t::Tuple) = Tuple{_broadcast_getindex_eltype(t[1]), eltypes(tail(t)).types...}

function promote_typejoin_union(::Type{T}) where T
if T === Union{}
return Union{}
elseif T isa Union
return promote_typejoin(promote_typejoin_union(T.a), promote_typejoin_union(T.b))
elseif T <: Tuple
return typejoin_union_tuple(T)
else
return T
end
end

@pure function typejoin_union_tuple(T::Type)
u = Base.unwrap_unionall(T)
u isa Union && return typejoin(
typejoin_union_tuple(Base.rewrap_unionall(u.a, T)),
typejoin_union_tuple(Base.rewrap_unionall(u.b, T)))
p = (u::DataType).parameters
lr = length(p)::Int
if lr == 0
return Tuple{}
end
c = Vector{Any}(undef, lr)
for i = 1:lr
pi = p[i]
U = Core.Compiler.unwrapva(pi)
if U === Union{}
ci = Union{}
elseif U isa Union
ci = typejoin(U.a, U.b)
else
ci = U
end
if i == lr && Core.Compiler.isvarargtype(pi)
N = (Base.unwrap_unionall(pi)::DataType).parameters[2]
c[i] = Base.rewrap_unionall(Vararg{ci, N}, pi)
else
c[i] = ci
end
end
return Base.rewrap_unionall(Tuple{c...}, T)
end

# Inferred eltype of result of broadcast(f, args...)
combine_eltypes(f, args::Tuple) = Base._return_type(f, eltypes(args))
combine_eltypes(f, args::Tuple) =
promote_typejoin_union(Base._return_type(f, eltypes(args)))

## Broadcasting core

Expand Down Expand Up @@ -877,7 +921,11 @@ const NonleafHandlingStyles = Union{DefaultArrayStyle,ArrayConflict}
dest = similar(bc′, typeof(val))
@inbounds dest[I] = val
# Now handle the remaining values
return copyto_nonleaf!(dest, bc′, iter, state, 1)
# The typeassert gives inference a helping hand on the element type and dimensionality
# (work-around for #28382)
ElType′ = ElType <: Type ? Type : ElType
RT = dest isa AbstractArray ? AbstractArray{<:ElType′, ndims(dest)} : Any
return copyto_nonleaf!(dest, bc′, iter, state, 1)::RT
end

## general `copyto!` methods
Expand Down
40 changes: 39 additions & 1 deletion test/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ end
let f17314 = x -> x < 0 ? false : x
@test eltype(broadcast(f17314, 1:3)) === Int
@test eltype(broadcast(f17314, -1:1)) === Integer
@test eltype(broadcast(f17314, Int[])) == Union{Bool,Int}
@test eltype(broadcast(f17314, Int[])) === Integer
end
let io = IOBuffer()
broadcast(x->print(io,x), 1:5) # broadcast with side effects
Expand Down Expand Up @@ -950,3 +950,41 @@ p0 = copy(p)
@test map(.+, [[1,2], [3,4]], [5, 6]) == [[6,7], [9,10]]
@test repr(.!) == "Base.Broadcast.BroadcastFunction(!)"
@test eval(:(.+)) == Base.BroadcastFunction(+)

@testset "Issue #28382: inferrability of broadcast with Union eltype" begin
@test isequal([1, 2] .+ [3.0, missing], [4.0, missing])
@test_broken Core.Compiler.return_type(broadcast, Tuple{typeof(+), Vector{Int},
Vector{Union{Float64, Missing}}}) ==
Vector{<:Union{Float64, Missing}}
@test Core.Compiler.return_type(broadcast, Tuple{typeof(+), Vector{Int},
Vector{Union{Float64, Missing}}}) ==
AbstractVector{<:Union{Float64, Missing}}
@test isequal([1, 2] + [3.0, missing], [4.0, missing])
@test_broken Core.Compiler.return_type(+, Tuple{Vector{Int},
Vector{Union{Float64, Missing}}}) ==
Vector{<:Union{Float64, Missing}}
@test Core.Compiler.return_type(+, Tuple{Vector{Int},
Vector{Union{Float64, Missing}}}) ==
AbstractVector{<:Union{Float64, Missing}}
@test_broken Core.Compiler.return_type(+, Tuple{Vector{Int},
Vector{Union{Float64, Missing}}}) ==
Vector{<:Union{Float64, Missing}}
@test isequal(tuple.([1, 2], [3.0, missing]), [(1, 3.0), (2, missing)])
@test_broken Core.Compiler.return_type(broadcast, Tuple{typeof(tuple), Vector{Int},
Vector{Union{Float64, Missing}}}) ==
Vector{<:Tuple{Int, Any}}
@test Core.Compiler.return_type(broadcast, Tuple{typeof(tuple), Vector{Int},
Vector{Union{Float64, Missing}}}) ==
AbstractVector{<:Tuple{Int, Any}}
# Check that corner cases do not throw an error
@test isequal(broadcast(x -> x === 1 ? nothing : x, [1, 2, missing]),
[nothing, 2, missing])
@test isequal(broadcast(x -> x === 1 ? nothing : x, Any[1, 2, 3.0, missing]),
[nothing, 2, 3, missing])
@test broadcast((x,y)->(x==1 ? 1.0 : x, y), [1 2 3], ["a", "b", "c"]) ==
[(1.0, "a") (2, "a") (3, "a")
(1.0, "b") (2, "b") (3, "b")
(1.0, "c") (2, "c") (3, "c")]
@test typeof.([iszero, isdigit]) == [typeof(iszero), typeof(isdigit)]
@test typeof.([iszero, iszero]) == [typeof(iszero), typeof(iszero)]
end

0 comments on commit 65c7bf5

Please sign in to comment.