Skip to content

Commit bbe77dd

Browse files
committed
Add the specialize_vararg macro by Mason Protter
1 parent d5bee93 commit bbe77dd

File tree

2 files changed

+57
-0
lines changed

2 files changed

+57
-0
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ Cassette = "7057c7e9-c182-5462-911a-8362d720325c"
88
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
99
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
1010
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
11+
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
1112
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1213
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1314

src/dual_context.jl

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,62 @@ using ChainRules
33
using ChainRulesCore
44
import ChainRulesCore: Zero
55

6+
# TODO: remove the copy pasted code and add that package
7+
# copyed from SpecializeVarargs.jl, written by @MasonProtter
8+
using MacroTools: MacroTools, splitdef, combinedef, @capture
9+
10+
macro specialize_vararg(n::Int, fdef::Expr)
11+
@assert n > 0
12+
13+
macros = Symbol[]
14+
while fdef.head == :macrocall && length(fdef.args) == 3
15+
push!(macros, fdef.args[1])
16+
fdef = fdef.args[3]
17+
end
18+
19+
d = splitdef(fdef)
20+
args = d[:args][end]
21+
@assert d[:args][end] isa Expr && d[:args][end].head == Symbol("...") && d[:args][end].args[] isa Symbol
22+
args_symbol = d[:args][end].args[]
23+
24+
fdefs = Expr(:block)
25+
26+
for i in 1:n-1
27+
di = deepcopy(d)
28+
pop!(di[:args])
29+
args = Tuple(gensym("arg$j") for j in 1:i)
30+
Ts = Tuple(gensym("T$j") for j in 1:i)
31+
32+
args_with_Ts = ((arg, T) -> :($arg :: $T)).(args, Ts)
33+
34+
di[:whereparams] = (di[:whereparams]..., Ts...)
35+
36+
push!(di[:args], args_with_Ts...)
37+
pushfirst!(di[:body].args, :($args_symbol = $(Expr(:tuple, args...))))
38+
cfdef = combinedef(di)
39+
mcfdef = isempty(macros) ? cfdef : foldr((m,f) -> Expr(:macrocall, m, nothing, f), macros, init=cfdef)
40+
push!(fdefs.args, mcfdef)
41+
end
42+
43+
di = deepcopy(d)
44+
pop!(di[:args])
45+
args = tuple((gensym() for j in 1:n)..., :($(gensym("args"))...))
46+
Ts = Tuple(gensym("T$j") for j in 1:n)
47+
48+
args_with_Ts = (((arg, T) -> :($arg :: $T)).(args[1:end-1], Ts)..., args[end])
49+
50+
di[:whereparams] = (di[:whereparams]..., Ts...)
51+
52+
push!(di[:args], args_with_Ts...)
53+
pushfirst!(di[:body].args, :($args_symbol = $(Expr(:tuple, args...))))
54+
55+
cfdef = combinedef(di)
56+
mcfdef = isempty(macros) ? cfdef : foldr((m,f) -> Expr(:macrocall, m, nothing, f), macros, init=cfdef)
57+
push!(fdefs.args, mcfdef)
58+
59+
esc(fdefs)
60+
end
61+
662
using Cassette: overdub, Context, nametype, similarcontext
763

864
Cassette.@context DualContext

0 commit comments

Comments
 (0)