Last active: February 21, 2018 12:00
Save andyferris/44ce332118d582e52a554739e1b1286b to your computer and use it in GitHub Desktop.
Implementation of `broadcast` for `AbstractDict`s and `NamedTuple`s
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Build a `Dict{K,V}` whose keys are the distinct elements of `inds` and whose values are
# left uninitialized (they must be written before being read).
# The slot/key layout comes from a freshly hashed `Set{K}`. That set is private to this
# call, so its arrays are adopted directly rather than copied.
function Base.Dict{K,V}(::Uninitialized, inds) where {K, V}
    table = Set{K}(inds).dict
    nslots = length(table.keys)
    vals = Vector{V}(uninitialized, nslots)
    return Dict{K,V}(table.slots, table.keys, vals,
                     table.ndel, table.count, table.age, table.idxfloor, table.maxprobe)
end
# Specialization for a `Set{K}`: reuse its already-hashed layout instead of rehashing.
# The slot and key arrays are copied so that later mutation of `inds` and of the new
# dictionary cannot affect each other. Values are left uninitialized.
function Base.Dict{K,V}(::Uninitialized, inds::Set{K}) where {K, V}
    src = inds.dict
    slots, ks = copy(src.slots), copy(src.keys)
    return Dict{K,V}(slots, ks, Vector{V}(uninitialized, length(ks)),
                     src.ndel, src.count, src.age, src.idxfloor, src.maxprobe)
end
# Specialization for the key set of a `Dict{K}` (e.g. `keys(d)`): reuse the owning
# dictionary's hashed layout. Slots and keys are copied so the new dictionary is independent
# of the owner. The owner's values are not read; the new values are left uninitialized.
function Base.Dict{K,V}(::Uninitialized, inds::Base.KeySet{<:Dict{K}}) where {K, V}
    owner = inds.dict
    slots, ks = copy(owner.slots), copy(owner.keys)
    return Dict{K,V}(slots, ks, Vector{V}(uninitialized, length(ks)),
                     owner.ndel, owner.count, owner.age, owner.idxfloor, owner.maxprobe)
end
## redefine with inlines
# Fold the broadcast indices of every argument into one common shape, left to right.
# The only change from Base is the `@inline` annotation, presumably so that the recursion
# over the argument list is unrolled at compile time.
@inline Broadcast.combine_indices(A) = Broadcast.broadcast_indices(A)
@inline function Broadcast.combine_indices(A, B...)
    first_inds = Broadcast.broadcast_indices(A)
    rest_inds = Broadcast.combine_indices(B...)
    return Broadcast.broadcast_shape(first_inds, rest_inds)
end
##
# Named tuples broadcast in their own style. When combined with a scalar, a plain tuple or
# an array, the named-tuple style wins (the style object is a singleton, so it is returned as-is).
Broadcast.BroadcastStyle(::Type{<:NamedTuple}) = Broadcast.Style{NamedTuple}()
Broadcast.BroadcastStyle(nt::Broadcast.Style{NamedTuple}, ::Broadcast.Scalar) = nt
Broadcast.BroadcastStyle(nt::Broadcast.Style{NamedTuple}, ::Broadcast.Style{Tuple}) = nt
Broadcast.BroadcastStyle(nt::Broadcast.Style{NamedTuple}, ::Broadcast.AbstractArrayStyle) = nt
# Broadcast style for `AbstractDict` arguments. When combined with a scalar, a tuple,
# a named tuple or an array, the dictionary style wins (returned as the singleton itself).
struct DictStyle <: Broadcast.BroadcastStyle end
Broadcast.BroadcastStyle(::Type{<:AbstractDict}) = DictStyle()
Broadcast.BroadcastStyle(ds::DictStyle, ::Broadcast.Scalar) = ds
Broadcast.BroadcastStyle(ds::DictStyle, ::Broadcast.Style{Tuple}) = ds
Broadcast.BroadcastStyle(ds::DictStyle, ::Broadcast.Style{NamedTuple}) = ds
Broadcast.BroadcastStyle(ds::DictStyle, ::Broadcast.AbstractArrayStyle) = ds
# Allocate the destination of a dictionary broadcast: a `Dict` keyed by the single combined
# key collection in `inds`, with value type `ElType` and values not yet filled in.
function Broadcast.broadcast_similar(f, ::DictStyle, ::Type{ElType}, inds::Tuple{Any}, As...) where {ElType}
    keycoll = inds[1]
    return Dict{eltype(keycoll), ElType}(uninitialized, keycoll)
end
# A named tuple's broadcast "indices" are its field names, read from the type parameter.
@inline function Broadcast.broadcast_indices(::Broadcast.Style{NamedTuple}, ::NamedTuple{names}) where {names}
    return (names,)
end

# A dictionary's broadcast indices are its key collection.
function Broadcast.broadcast_indices(::DictStyle, d)
    return (Base.keys(d),)
end
# Return `winner` if the axis pairing was accepted, otherwise signal a shape mismatch.
@inline _bcs1_or_throw(accepted, winner) =
    accepted ? winner : throw(DimensionMismatch("containers could not be broadcast to a common size"))

# Combine one axis from each of two broadcast shapes. Key sets (from dictionaries) take
# precedence over everything; symbol tuples (named-tuple names) take precedence over
# anything but a key set. `_bcsm(a, b)` decides whether the pairing is compatible.
Broadcast._bcs1(a::AbstractSet, b::AbstractSet) = _bcs1_or_throw(Broadcast._bcsm(a, b), a)
Broadcast._bcs1(a, b::AbstractSet) = _bcs1_or_throw(Broadcast._bcsm(a, b), b)
Broadcast._bcs1(a::AbstractSet, b) = _bcs1_or_throw(Broadcast._bcsm(a, b), a)
@inline Broadcast._bcs1(a::Tuple{Vararg{Symbol}}, b::Tuple{Vararg{Symbol}}) =
    _bcs1_or_throw(Broadcast._bcsm(a, b), a)
Broadcast._bcs1(a::Tuple{Vararg{Symbol}}, b::AbstractSet) = _bcs1_or_throw(Broadcast._bcsm(a, b), b)
Broadcast._bcs1(a::AbstractSet, b::Tuple{Vararg{Symbol}}) = _bcs1_or_throw(Broadcast._bcsm(a, b), a)
Broadcast._bcs1(a::Tuple{Vararg{Symbol}}, b) = _bcs1_or_throw(Broadcast._bcsm(a, b), a)
Broadcast._bcs1(a, b::Tuple{Vararg{Symbol}}) = _bcs1_or_throw(Broadcast._bcsm(a, b), b)
# `_bcsm(a, b)`: may broadcast axis `b` be combined with axis `a`?
# A length-1 unit range (from e.g. `[x]` or `(x,)`) stretches over any key set or list of
# names. Otherwise, key collections must hold the same elements, in any order.
Broadcast._bcsm(a::AbstractSet, b::AbstractUnitRange) = length(b) == 1 || issetequal(a, b)
Broadcast._bcsm(a::AbstractUnitRange, b::AbstractSet) = length(a) == 1 || issetequal(a, b)
Broadcast._bcsm(a::AbstractSet, b::AbstractSet) = a == b
Broadcast._bcsm(a::AbstractSet, b::Tuple{Vararg{Symbol}}) = issetequal(a, b)
Broadcast._bcsm(a::Tuple{Vararg{Symbol}}, b::AbstractSet) = issetequal(a, b)
# Unit range vs. named-tuple names, in both argument orders: only a length-1 range is
# compatible. Without the first method, `[1] .+ (a=1, b=2)` hit Base's generic `_bcsm`,
# which (in the Base versions checked — confirm for the targeted build) tests only
# `length(b) == 1` on the *names*, and so threw even though `(a=1, b=2) .+ [1]` works.
Broadcast._bcsm(a::AbstractUnitRange, b::Tuple{Vararg{Symbol}}) = length(a) == 1
Broadcast._bcsm(a::Tuple{Vararg{Symbol}}, b::AbstractUnitRange) = length(b) == 1
@inline Broadcast._bcsm(a::Tuple{Vararg{Symbol}}, b::Tuple{Vararg{Symbol}}) = _issetequal(a, b)
# Marked `@pure` so that, for constant name tuples, the comparison can presumably be folded
# at compile time. NOTE(review): `issetequal` allocates, which `@pure` formally disallows.
Base.@pure _issetequal(a::Tuple{Vararg{Symbol}}, b::Tuple{Vararg{Symbol}}) = issetequal(a, b)
# The element type that broadcasting sees for a dictionary argument is its value type.
function Broadcast._broadcast_getindex_eltype(::DictStyle, d)
    return valtype(d)
end
# Out-of-place broadcast whose combined indices are a single key set (a dictionary result).
# Allocate the destination, then fill it in place. The `@inbounds` lets the inlined
# `broadcast!` skip its `@boundscheck` input validation; the shapes were already reconciled
# when `inds` was computed.
@inline function Broadcast.broadcast(f, s::Broadcast.BroadcastStyle, ::Type{ElType}, inds::Tuple{AbstractSet}, As...) where ElType
    out = Broadcast.broadcast_similar(f, s, ElType, inds, As...)
    return @inbounds broadcast!(f, out, As...)
end
# Out-of-place broadcast over named tuples. The combined indices are the field names, and
# the result is a named tuple with those names whose values are `f` applied field by field.
@inline function Broadcast.broadcast(f, s::Broadcast.Style{NamedTuple}, ::Type{ElType}, inds::Tuple{Tuple{Vararg{Symbol}}}, As...) where ElType
    names = inds[1]
    vals = _broadcast(f, s, names, As...)
    return NamedTuple{names}(vals)
end
# For each name in `inds`, in order, call `f` on every input's element for that name
# (via `_getindex`, which reuses scalars and length-1 containers). Return the results as
# a tuple; an empty name tuple gives `()`.
@inline function _broadcast(f, s::Broadcast.Style{NamedTuple}, inds::Tuple{Vararg{Symbol}}, As...)
    return map(name -> f(map(a -> _getindex(a, name), As)...), inds)
end
# Throw unless container `x` has exactly the key set `out_inds` (order is irrelevant).
function _check_keys_match(out_inds::AbstractSet, x)
    if issetequal(out_inds, keys(x))
        return nothing
    end
    throw(ErrorException("Output broadcast indices $out_inds do not match input $(keys(x))"))
end

# Validate one broadcast input against a destination indexed by the key set `out_inds`:
#   * dictionaries and named tuples must have exactly those keys;
#   * vectors and tuples must have them too, unless they hold a single element;
#   * zero-dimensional arrays and all other values are accepted as scalars;
#   * arrays of two or more dimensions are rejected.
Broadcast.check_broadcast_indices(out_inds::AbstractSet, d::AbstractDict) = _check_keys_match(out_inds, d)
Broadcast.check_broadcast_indices(out_inds::AbstractSet, d::NamedTuple) = _check_keys_match(out_inds, d)
function Broadcast.check_broadcast_indices(out_inds::AbstractSet, d::AbstractVector)
    length(d) === 1 && return nothing
    return _check_keys_match(out_inds, d)
end
function Broadcast.check_broadcast_indices(out_inds::AbstractSet, t::Tuple)
    length(t) === 1 && return nothing
    return _check_keys_match(out_inds, t)
end
Broadcast.check_broadcast_indices(::AbstractSet, ::Any) = nothing
Broadcast.check_broadcast_indices(::AbstractSet, ::AbstractArray{<:Any,0}) = nothing
Broadcast.check_broadcast_indices(::AbstractSet, ::AbstractArray) =
    throw(ErrorException("Broadcasting between dictionaries and multidimensional arrays is not supported"))
# Broadcast into a dictionary when every input is a scalar (e.g. `d .= 0`).
# `f(As...)` is evaluated once per existing key and stored under that key, so only the
# values change.
@inline function Broadcast.broadcast!(f, dest::AbstractDict, ::Broadcast.Scalar, As::Vararg{Any, N}) where N
    @inbounds for k in keys(dest)
        dest[k] = f(As...)
    end
    return dest
end
# General broadcast into a dictionary.
# 1. Check every input against the destination's keys. This is a `@boundscheck` block, so
#    it is skipped when the call is inlined under `@inbounds`.
# 2. For every key `k`, overwrite `dest[k]` with `f` applied to each input's element for `k`.
@inline function Broadcast.broadcast!(f, dest::AbstractDict, ::Broadcast.BroadcastStyle, As::Vararg{Any, N}) where N
    dest_keys = keys(dest)
    @boundscheck foreach(x -> Broadcast.check_broadcast_indices(dest_keys, x), As)
    @inbounds for k in dest_keys
        elems = map(x -> _getindex(x, k), As)
        dest[k] = f(elems...)
    end
    return dest
end
# Broadcast into an existing vector when the inputs include a dictionary (e.g. `v .= d`).
# The vector's indices act as dictionary keys: `dest[i]` receives `f` applied to each
# input's element for key `i`.
#
# The inputs are validated against `keys(dest)` first, inside `@boundscheck` so that
# `@inbounds` callers can elide it. `_getindex` reads vectors and tuples with `@inbounds`,
# so without this check a shorter vector or tuple input would be read out of bounds.
# The check also rejects mismatched dictionary keys and arrays of two or more dimensions,
# which were previously treated as scalars.
@inline function Broadcast.broadcast!(f, dest::AbstractVector, ::DictStyle, As::Vararg{Any, N}) where N
    @boundscheck begin
        # `check_broadcast_indices` expects the destination keys as an `AbstractSet`.
        dest_keyset = Set(keys(dest))
        foreach(a -> Broadcast.check_broadcast_indices(dest_keyset, a), As)
    end
    @inbounds for i in keys(dest)
        dest[i] = f(map(a -> _getindex(a, i), As)...)
    end
    return dest
end
# Return the element of broadcast input `a` that belongs to output key `i`.
# - Reused for every key: any value without a more specific method (treated as a scalar),
#   zero-dimensional arrays, 1-element tuples and length-1 vectors.
# - Indexed by `i`: dictionaries, named tuples, other tuples and other vectors.
# Lookups use `@inbounds`, so for vectors and tuples `i` must already be known valid.
@inline _getindex(a, i) = a
@inline _getindex(a::AbstractDict, i) = @inbounds a[i]
@inline _getindex(a::AbstractVector, i) =
    length(a) === 1 ? (@inbounds first(a)) : (@inbounds a[i])
@inline _getindex(a::AbstractArray{<:Any, 0}, i) = @inbounds a[]
@inline _getindex(a::Tuple, i) = @inbounds a[i]
@inline _getindex(a::Tuple{Any}, i) = @inbounds a[1]
@inline _getindex(a::NamedTuple, i) = @inbounds a[i]
using Test | |
@testset "Broadcast dictionaries" begin
    d = Dict(1 => 10, 2 => 20)

    # Dictionary with a scalar: value types follow the scalar's promotion
    @test (d .* 2)::Dict{Int, Int} == Dict(1 => 20, 2 => 40)
    @test (d .* 2.0)::Dict{Int, Float64} == Dict(1 => 20.0, 2 => 40.0)
    x = 2
    @test (d .* x)::Dict{Int, Int} == Dict(1 => 20, 2 => 40)

    # Dictionary with other containers (matched by key, or reused if length 1)
    @test (d .* d)::Dict{Int, Int} == Dict(1 => 100, 2 => 400)
    @test (d .+ [1, 2])::Dict{Int, Int} == Dict(1 => 11, 2 => 22)
    @test (d .+ (1, 2))::Dict{Int, Int} == Dict(1 => 11, 2 => 22)
    @test (d .+ [1])::Dict{Int, Int} == Dict(1 => 11, 2 => 21)
    @test (d .+ (1,))::Dict{Int, Int} == Dict(1 => 11, 2 => 21)
    @test (d .+ fill(1))::Dict{Int, Int} == Dict(1 => 11, 2 => 21) # zero-dimensional array
    @test Dict(:a=>1, :b=>2) .+ (a=1, b=2) == Dict(:a=>2, :b=>4)

    # Incompatible shapes are rejected
    @test_throws DimensionMismatch d .+ [1, 2, 3]
    @test_throws DimensionMismatch d .+ Dict(1 => 1, 3 => 3)

    # Mutating `broadcast!`
    d2 = copy(d)
    d2 .= 0
    @test d2 == Dict(1 => 0, 2 => 0)
    d2 .= [1]
    @test d2 == Dict(1 => 1, 2 => 1)
    d2 .= (2,)
    @test d2 == Dict(1 => 2, 2 => 2)
    d2 .= Dict(1 => 3, 2 => 4)
    @test d2 == Dict(1 => 3, 2 => 4)
    d2 .= [5, 6]
    @test d2 == Dict(1 => 5, 2 => 6)
    d2 .= (7, 8)
    @test d2 == Dict(1 => 7, 2 => 8)
    d4 = copy(d)
    @test_throws ErrorException (d4 .= Dict(1 => 3, 3 => 4))
    a = [0, 0]
    a .= d
    @test a == [10, 20]
    d3 = Dict(:a=>0, :b=>0)
    d3 .= (a=1, b=2)
    @test d3 == Dict(:a=>1, :b=>2)
end
@testset "Broadcast named tuples" begin
    # Named tuple with scalars and length-1 containers
    @test (a=1, b=2) .+ 1 === (a=2, b=3)
    @test (a=1, b=2) .+ [1] === (a=2, b=3)
    @test (a=1, b=2) .+ fill(1) === (a=2, b=3)
    @test (a=1, b=2) .+ (1,) === (a=2, b=3)

    # Named tuples are matched by name, not position; the first operand fixes the order
    @test (a=1, b=2) .+ (a=1, b=2) === (a=2, b=4)
    @test (a=1, b=2) .+ (b=2, a=1) === (a=2, b=4)

    # Mismatched names or lengths are rejected
    @test_throws DimensionMismatch (a=1, b=2) .+ (a=1, c=2)
    @test_throws DimensionMismatch (a=1, b=2) .+ [1, 2, 3]
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.
The biggest remaining issue I'm aware of is the speed of named-tuple broadcasting, which depends on the compiler constant-propagating the field names.