A simple restricted/expanded implementation of DataFrames.transform, using Dagger, for generic Tables.jl tables.
Created
June 26, 2021 23:55
-
-
Save ericphanson/03c9905a24b5b40f0108edb381ae2da7 to your computer and use it in GitHub Desktop.
TransformDagger
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# This file is machine-generated - editing it directly is not advised | |
[[ArgTools]] | |
uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" | |
[[Artifacts]] | |
uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" | |
[[Base64]] | |
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" | |
[[ColorTypes]] | |
deps = ["FixedPointNumbers", "Random"] | |
git-tree-sha1 = "024fe24d83e4a5bf5fc80501a314ce0d1aa35597" | |
uuid = "3da002f7-5984-5a60-b8a6-cbb66c0b333f" | |
version = "0.11.0" | |
[[Colors]] | |
deps = ["ColorTypes", "FixedPointNumbers", "Reexport"] | |
git-tree-sha1 = "417b0ed7b8b838aa6ca0a87aadf1bb9eb111ce40" | |
uuid = "5ae59095-9a9b-59fe-a467-6f913c188581" | |
version = "0.12.8" | |
[[Compat]] | |
deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"] | |
git-tree-sha1 = "dc7dedc2c2aa9faf59a55c622760a25cbefbe941" | |
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" | |
version = "3.31.0" | |
[[Crayons]] | |
git-tree-sha1 = "3f71217b538d7aaee0b69ab47d9b7724ca8afa0d" | |
uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f" | |
version = "4.0.4" | |
[[Dagger]] | |
deps = ["Colors", "Distributed", "LinearAlgebra", "MemPool", "Profile", "Random", "Requires", "Serialization", "SharedArrays", "SparseArrays", "Statistics", "StatsBase"] | |
git-tree-sha1 = "8d59bf882d9c8a1e5eb64207ee830e4efdfdc940" | |
uuid = "d58978e5-989f-55fb-8d15-ea34adc7bf54" | |
version = "0.11.3" | |
[[DataAPI]] | |
git-tree-sha1 = "ee400abb2298bd13bfc3df1c412ed228061a2385" | |
uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" | |
version = "1.7.0" | |
[[DataFrames]] | |
deps = ["Compat", "DataAPI", "Future", "InvertedIndices", "IteratorInterfaceExtensions", "LinearAlgebra", "Markdown", "Missings", "PooledArrays", "PrettyTables", "Printf", "REPL", "Reexport", "SortingAlgorithms", "Statistics", "TableTraits", "Tables", "Unicode"] | |
git-tree-sha1 = "66ee4fe515a9294a8836ef18eea7239c6ac3db5e" | |
uuid = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" | |
version = "1.1.1" | |
[[DataStructures]] | |
deps = ["Compat", "InteractiveUtils", "OrderedCollections"] | |
git-tree-sha1 = "4437b64df1e0adccc3e5d1adbc3ac741095e4677" | |
uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" | |
version = "0.18.9" | |
[[DataValueInterfaces]] | |
git-tree-sha1 = "bfc1187b79289637fa0ef6d4436ebdfe6905cbd6" | |
uuid = "e2d170a0-9d28-54be-80f0-106bbe20a464" | |
version = "1.0.0" | |
[[Dates]] | |
deps = ["Printf"] | |
uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" | |
[[DelimitedFiles]] | |
deps = ["Mmap"] | |
uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab" | |
[[Distributed]] | |
deps = ["Random", "Serialization", "Sockets"] | |
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" | |
[[Downloads]] | |
deps = ["ArgTools", "LibCURL", "NetworkOptions"] | |
uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" | |
[[ExprTools]] | |
git-tree-sha1 = "10407a39b87f29d47ebaca8edbc75d7c302ff93e" | |
uuid = "e2ba6199-217a-4e67-a87a-7c52f15ade04" | |
version = "0.1.3" | |
[[FixedPointNumbers]] | |
deps = ["Statistics"] | |
git-tree-sha1 = "335bfdceacc84c5cdf16aadc768aa5ddfc5383cc" | |
uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93" | |
version = "0.8.4" | |
[[Formatting]] | |
deps = ["Printf"] | |
git-tree-sha1 = "8339d61043228fdd3eb658d86c926cb282ae72a8" | |
uuid = "59287772-0a20-5a39-b81b-1366585eb4c0" | |
version = "0.4.2" | |
[[Future]] | |
deps = ["Random"] | |
uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820" | |
[[InteractiveUtils]] | |
deps = ["Markdown"] | |
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" | |
[[InvertedIndices]] | |
deps = ["Test"] | |
git-tree-sha1 = "15732c475062348b0165684ffe28e85ea8396afc" | |
uuid = "41ab1584-1d38-5bbf-9106-f11c6c58b48f" | |
version = "1.0.0" | |
[[IteratorInterfaceExtensions]] | |
git-tree-sha1 = "a3f24677c21f5bbe9d2a714f95dcd58337fb2856" | |
uuid = "82899510-4779-5014-852e-03e436cf321d" | |
version = "1.0.0" | |
[[LibCURL]] | |
deps = ["LibCURL_jll", "MozillaCACerts_jll"] | |
uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21" | |
[[LibCURL_jll]] | |
deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"] | |
uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0" | |
[[LibGit2]] | |
deps = ["Base64", "NetworkOptions", "Printf", "SHA"] | |
uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" | |
[[LibSSH2_jll]] | |
deps = ["Artifacts", "Libdl", "MbedTLS_jll"] | |
uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8" | |
[[Libdl]] | |
uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" | |
[[LinearAlgebra]] | |
deps = ["Libdl"] | |
uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" | |
[[Logging]] | |
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" | |
[[Markdown]] | |
deps = ["Base64"] | |
uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" | |
[[MbedTLS_jll]] | |
deps = ["Artifacts", "Libdl"] | |
uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" | |
[[MemPool]] | |
deps = ["DataStructures", "Distributed", "Mmap", "Random", "Serialization", "Sockets"] | |
git-tree-sha1 = "cb17c1dff8d9c89065c55ac4b0222b93d147e983" | |
uuid = "f9f48841-c794-520a-933b-121f7ba6ed94" | |
version = "0.3.4" | |
[[Missings]] | |
deps = ["DataAPI"] | |
git-tree-sha1 = "4ea90bd5d3985ae1f9a908bd4500ae88921c5ce7" | |
uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28" | |
version = "1.0.0" | |
[[Mmap]] | |
uuid = "a63ad114-7e13-5084-954f-fe012c677804" | |
[[MozillaCACerts_jll]] | |
uuid = "14a3606d-f60d-562e-9121-12d972cd8159" | |
[[NetworkOptions]] | |
uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" | |
[[OrderedCollections]] | |
git-tree-sha1 = "85f8e6578bf1f9ee0d11e7bb1b1456435479d47c" | |
uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" | |
version = "1.4.1" | |
[[Pkg]] | |
deps = ["Artifacts", "Dates", "Downloads", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] | |
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" | |
[[PooledArrays]] | |
deps = ["DataAPI", "Future"] | |
git-tree-sha1 = "cde4ce9d6f33219465b55162811d8de8139c0414" | |
uuid = "2dfb63ee-cc39-5dd5-95bd-886bf059d720" | |
version = "1.2.1" | |
[[PrettyTables]] | |
deps = ["Crayons", "Formatting", "Markdown", "Reexport", "Tables"] | |
git-tree-sha1 = "0d1245a357cc61c8cd61934c07447aa569ff22e6" | |
uuid = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" | |
version = "1.1.0" | |
[[Printf]] | |
deps = ["Unicode"] | |
uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" | |
[[Profile]] | |
deps = ["Printf"] | |
uuid = "9abbd945-dff8-562f-b5e8-e1ebf5ef1b79" | |
[[REPL]] | |
deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"] | |
uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" | |
[[Random]] | |
deps = ["Serialization"] | |
uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" | |
[[Reexport]] | |
git-tree-sha1 = "5f6c21241f0f655da3952fd60aa18477cf96c220" | |
uuid = "189a3867-3050-52da-a836-e630ba90ab69" | |
version = "1.1.0" | |
[[Requires]] | |
deps = ["UUIDs"] | |
git-tree-sha1 = "4036a3bd08ac7e968e27c203d45f5fff15020621" | |
uuid = "ae029012-a4dd-5104-9daa-d747884805df" | |
version = "1.1.3" | |
[[SHA]] | |
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" | |
[[Serialization]] | |
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" | |
[[SharedArrays]] | |
deps = ["Distributed", "Mmap", "Random", "Serialization"] | |
uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383" | |
[[Sockets]] | |
uuid = "6462fe0b-24de-5631-8697-dd941f90decc" | |
[[SortingAlgorithms]] | |
deps = ["DataStructures"] | |
git-tree-sha1 = "2ec1962eba973f383239da22e75218565c390a96" | |
uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c" | |
version = "1.0.0" | |
[[SparseArrays]] | |
deps = ["LinearAlgebra", "Random"] | |
uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" | |
[[Statistics]] | |
deps = ["LinearAlgebra", "SparseArrays"] | |
uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" | |
[[StatsAPI]] | |
git-tree-sha1 = "1958272568dc176a1d881acb797beb909c785510" | |
uuid = "82ae8749-77ed-4fe6-ae5f-f523153014b0" | |
version = "1.0.0" | |
[[StatsBase]] | |
deps = ["DataAPI", "DataStructures", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"] | |
git-tree-sha1 = "2f6792d523d7448bbe2fec99eca9218f06cc746d" | |
uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" | |
version = "0.33.8" | |
[[TOML]] | |
deps = ["Dates"] | |
uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76" | |
[[TableTraits]] | |
deps = ["IteratorInterfaceExtensions"] | |
git-tree-sha1 = "c06b2f539df1c6efa794486abfb6ed2022561a39" | |
uuid = "3783bdb8-4a98-5b6b-af9a-565f29a5fe9c" | |
version = "1.0.1" | |
[[Tables]] | |
deps = ["DataAPI", "DataValueInterfaces", "IteratorInterfaceExtensions", "LinearAlgebra", "TableTraits", "Test"] | |
git-tree-sha1 = "8ed4a3ea724dac32670b062be3ef1c1de6773ae8" | |
uuid = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" | |
version = "1.4.4" | |
[[Tar]] | |
deps = ["ArgTools", "SHA"] | |
uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" | |
[[Test]] | |
deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] | |
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" | |
[[TimerOutputs]] | |
deps = ["ExprTools", "Printf"] | |
git-tree-sha1 = "9f494bc54b4c31404a9eff449235836615929de1" | |
uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" | |
version = "0.5.10" | |
[[TrackingTimers]] | |
deps = ["Distributed", "PrettyTables", "Printf", "Tables"] | |
git-tree-sha1 = "16e1d1b40436284f6a0f3b965101da3f8c807564" | |
uuid = "88ba133c-8695-4d62-9a5c-bcf16b6d2e1a" | |
version = "0.1.2" | |
[[UUIDs]] | |
deps = ["Random", "SHA"] | |
uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" | |
[[Unicode]] | |
uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" | |
[[Zlib_jll]] | |
deps = ["Libdl"] | |
uuid = "83775a58-1f1d-513f-b197-d71354ab007a" | |
[[nghttp2_jll]] | |
deps = ["Artifacts", "Libdl"] | |
uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d" | |
[[p7zip_jll]] | |
deps = ["Artifacts", "Libdl"] | |
uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
[deps] | |
Dagger = "d58978e5-989f-55fb-8d15-ea34adc7bf54" | |
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" | |
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" | |
PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" | |
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" | |
TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" | |
TrackingTimers = "88ba133c-8695-4d62-9a5c-bcf16b6d2e1a" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using Distributed | |
nprocs() >= 2 || addprocs(2) | |
@everywhere using LinearAlgebra | |
@everywhere using TrackingTimers | |
@everywhere begin
    # Wrapper marking a function for element-by-element application
    # (idea borrowed from DataFrames.jl).
    struct ByRow{F}
        f::F
    end

    # Call the wrapped function once per shared index of `args`, passing the
    # i-th element of every argument, and gather the results in an array.
    function (b::ByRow)(args...)
        return [b.f(map(col -> col[i], args)...) for i in eachindex(args...)]
    end

    # Return a named tuple with `a * b` stored as `x` and `a + b` stored as `y`.
    hi(a, b) = (x=a * b, y=a + b)

    # Singular values of the product of `x` with the adjoint of `y`.
    function svdfn(x, y)
        return svdvals(x * y')
    end

    # Allocation-heavy workload: build four `n × n × n` random arrays (the first
    # three are discarded immediately), sum the last one over its third
    # dimension and add `a` to the diagonal of the resulting matrix.
    # The elapsed time and allocations are printed by `@time`.
    function garbo(a)
        r = @time begin
            n = length(a)
            for _ in 1:3
                rand(n, n, n)
            end
            cube = rand(n, n, n)
            flat = dropdims(sum(cube; dims=3); dims=3)
            diag(flat) .+ a
        end
        return r
    end

    # Restrict BLAS to one thread on every process.
    BLAS.set_num_threads(1)
end
# Label for a `ByRow` wrapper: the label of the wrapped function inside `ByRow(...)`.
stringify(f::ByRow) = "ByRow($(stringify(f.f)))"
include("transform.jl") | |
# Demo input: a two-column table of 1000 random values each.
test_table = (a=randn(1000), b=rand(1000))

# Run a chain of dependent transformations and time the whole call.
result, log, t = @time transform(
    test_table,
    :a => (x -> 2x) => :c,
    (:a, :c) => svdfn => :svd,
    (:svd, :a) => (+) => :sum,
    (:sum, :svd) => ByRow(hi) => [:x, :y],
    [:y, :a] => (+) => :z,
    (:a, :b) => svdfn => :svd_ab,
    :b => garbo => :g,
);

# Write the recorded execution plan to `logs.gv`.
open("logs.gv", "w") do io
    return Dagger.show_plan(io, Dagger.get_logs!(log))
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using Dagger, Tables, OrderedCollections | |
using PrettyTables, TrackingTimers | |
# Split a nested pair `input_cols => f => output_cols` into the tuple
# `(input_cols, f, output_cols)`.
function decompose_pairs(p::Pair{<:Any,<:Pair})
    input, inner = p
    f, output = inner
    return (input, f, output)
end
# Fallback label for any object: its compact `show` representation.
function stringify(f)
    return sprint(show, f; context=:compact => true)
end
# Rebuild the transformation `p` with its function wrapped by the tracking
# timer `t`, labelled "input ↦ function ↦ output".
function instrument(t::TrackingTimer, p::Pair{<:Any,<:Pair})
    input, f, output = decompose_pairs(p)
    label = "$(input) ↦ $(stringify(f)) ↦ $(output)"
    timed = t(f, label)
    return Pair(input, Pair(timed, output))
end
# Normalise a column specification so it can be iterated: a lone `Symbol`
# becomes a 1-tuple, anything else is passed through untouched.
wrap(input) = input
wrap(input::Symbol) = (input,)
# Return a function that calls `f` and passes its result through `Tables.columns`.
function columnify(f)
    return function (args...)
        return Tables.columns(f(args...))
    end
end
"""
    transform(table, ps...)

Apply the transformations `ps` to `table`, building every column as a Dagger thunk and
collecting the results at the end.

# Arguments
- `table`: any Tables.jl table; it is converted with `Tables.columns`.
- `ps...`: transformations written as `input_cols => f => output_cols`. `input_cols` is a
  `Symbol` or a collection of `Symbol`s naming columns of `table` or columns produced by
  an earlier transformation. If `output_cols` is a `Symbol`, the value returned by `f`
  becomes that column. Otherwise `output_cols` is iterated as a collection of column
  names: the value returned by `f` is passed through `Tables.columns` and each named
  column is extracted from it.

# Returns
A tuple `(result, log, t)`:
- `result`: an `OrderedDict{Symbol,AbstractVector}` of the collected columns, keyed by
  column name in order of first insertion (existing columns first).
- `log`: the `Dagger.LocalEventLog` attached to the context used for the computation.
- `t`: the `TrackingTimer` used to instrument every transformation function.
"""
function transform(table, ps...)
    t = TrackingTimer()
    ctx = Context()
    event_log = Dagger.LocalEventLog()
    ctx.log_sink = event_log
    tab = Tables.columns(table)
    delayed_cols = OrderedDict{Symbol,Thunk}()

    # Pre-populate with existing columns
    for c in Tables.columnnames(tab)
        delayed_cols[c] = delayed(Tables.getcolumn)(tab, c)
    end

    # Add in new columns from transformations
    for p in ps
        input, f, output = decompose_pairs(instrument(t, p))
        cols = [delayed_cols[i] for i in wrap(input)]
        if output isa Symbol
            delayed_cols[output] = delayed(f; cache=true)(cols...)
        else
            # Any collection of names (including a single-element one such as `[:x]`)
            # means `f` returns a table whose columns are picked out by name.
            table_thunk = delayed(columnify(f); cache=true)(cols...)
            for col in output
                # not sure if `getcolumn` should be cached here... it should be very cheap
                delayed_cols[col] = delayed(t(Tables.getcolumn); cache=true)(
                    table_thunk, col
                )
            end
        end
    end

    # Collect results
    collected = OrderedDict{Symbol,AbstractVector}()
    for (k, v) in delayed_cols
        collected[k] = collect(ctx, v)
    end
    @info "Timing information" t
    return collected, event_log, t
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Results: