@@ -3,6 +3,62 @@ using ChainRules
3
3
using ChainRulesCore
4
4
import ChainRulesCore: Zero
5
5
6
+ # TODO : remove the copy pasted code and add that package
7
+ # copyed from SpecializeVarargs.jl, written by @MasonProtter
8
+ using MacroTools: MacroTools, splitdef, combinedef, @capture
9
+
10
+ macro specialize_vararg (n:: Int , fdef:: Expr )
11
+ @assert n > 0
12
+
13
+ macros = Symbol[]
14
+ while fdef. head == :macrocall && length (fdef. args) == 3
15
+ push! (macros, fdef. args[1 ])
16
+ fdef = fdef. args[3 ]
17
+ end
18
+
19
+ d = splitdef (fdef)
20
+ args = d[:args ][end ]
21
+ @assert d[:args ][end ] isa Expr && d[:args ][end ]. head == Symbol (" ..." ) && d[:args ][end ]. args[] isa Symbol
22
+ args_symbol = d[:args ][end ]. args[]
23
+
24
+ fdefs = Expr (:block )
25
+
26
+ for i in 1 : n- 1
27
+ di = deepcopy (d)
28
+ pop! (di[:args ])
29
+ args = Tuple (gensym (" arg$j " ) for j in 1 : i)
30
+ Ts = Tuple (gensym (" T$j " ) for j in 1 : i)
31
+
32
+ args_with_Ts = ((arg, T) -> :($ arg :: $T )). (args, Ts)
33
+
34
+ di[:whereparams ] = (di[:whereparams ]. .. , Ts... )
35
+
36
+ push! (di[:args ], args_with_Ts... )
37
+ pushfirst! (di[:body ]. args, :($ args_symbol = $ (Expr (:tuple , args... ))))
38
+ cfdef = combinedef (di)
39
+ mcfdef = isempty (macros) ? cfdef : foldr ((m,f) -> Expr (:macrocall , m, nothing , f), macros, init= cfdef)
40
+ push! (fdefs. args, mcfdef)
41
+ end
42
+
43
+ di = deepcopy (d)
44
+ pop! (di[:args ])
45
+ args = tuple ((gensym () for j in 1 : n). .. , :($ (gensym (" args" )). .. ))
46
+ Ts = Tuple (gensym (" T$j " ) for j in 1 : n)
47
+
48
+ args_with_Ts = (((arg, T) -> :($ arg :: $T )). (args[1 : end - 1 ], Ts). .. , args[end ])
49
+
50
+ di[:whereparams ] = (di[:whereparams ]. .. , Ts... )
51
+
52
+ push! (di[:args ], args_with_Ts... )
53
+ pushfirst! (di[:body ]. args, :($ args_symbol = $ (Expr (:tuple , args... ))))
54
+
55
+ cfdef = combinedef (di)
56
+ mcfdef = isempty (macros) ? cfdef : foldr ((m,f) -> Expr (:macrocall , m, nothing , f), macros, init= cfdef)
57
+ push! (fdefs. args, mcfdef)
58
+
59
+ esc (fdefs)
60
+ end
61
+
6
62
using Cassette: overdub, Context, nametype, similarcontext
7
63
8
64
Cassette. @context DualContext
0 commit comments