Skip to content

Commit 29a14f9

Browse files
authored
Very WIP: Stage2 revival (#78)
* Very WIP: Stage2 revival Needs some corresponding changes in Cthulhu WIP WIP WIP WIP fix inferred type and return handling of rrule calls (#83) add custom Cthulhu toggle to show ad-transformed code (#85) add inlining policy for ADInterpreter (#84) fixup * Fix issues in untyped transform
1 parent bc4773a commit 29a14f9

File tree

8 files changed

+1047
-678
lines changed

8 files changed

+1047
-678
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@ version = "0.1.1"
77
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
88
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
99
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
10+
Cthulhu = "f68482b8-f384-11e8-15f7-abe071a5a75f"
1011
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
12+
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
1113
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1214
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
1315
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"

scratch/stage2.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
using Cthulhu
2+
using Diffractor
3+
using Diffractor: ADInterpreter
4+
using Diffractor: var"'", ∂⃖
5+
6+
function foo(x)
7+
sin(x)
8+
end
9+
bar(x) = foo(x)
10+
11+
diffsin(x) = bar'(x)
12+
13+
diffsin(1.0)
14+
15+
function do_the_thing()
16+
interp = ADInterpreter()
17+
mi = Cthulhu.get_specialization(Tuple{map(Core.Typeof, (diffsin, 2.0))...})
18+
Cthulhu.do_typeinf!(interp, mi)
19+
(interp, mi)
20+
end
21+
(interp, mi) = do_the_thing();
22+
Diffractor.codegen(interp, Diffractor.ADCursor(0, mi))(1.0)

src/Diffractor.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,13 @@ include("stage1/recurse_fwd.jl")
1717
include("stage1/mixed.jl")
1818
include("stage1/broadcast.jl")
1919

20+
include("stage2/interpreter.jl")
21+
include("stage2/lattice.jl")
22+
include("stage2/abstractinterpret.jl")
23+
include("stage2/tfuncs.jl")
24+
25+
include("codegen/reverse.jl")
26+
2027
include("extra_rules.jl")
2128

2229
include("higher_fwd_rules.jl")

0 commit comments

Comments
 (0)