Skip to content

Commit d278a85

Browse files
authored
Import stage2 demand-driven forward mode (#95)
This imports a ripped-out version of the demand-driven AD code from CedarSim and hooks it into the ADInterpreter. Starting with this code has the advantage that it is working-ish, but the disadvantage that it doesn't really interact with the rest of Diffractor yet. Still, I think it's a reasonable point to start. I'm doing this as a separate commit, so we can keep better track of the subsequent refactoring.
1 parent 29a14f9 commit d278a85

File tree

13 files changed

+767
-141
lines changed

13 files changed

+767
-141
lines changed

Manifest.toml

Lines changed: 357 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,357 @@
1+
# This file is machine-generated - editing it directly is not advised
2+
3+
julia_version = "1.10.0-DEV"
4+
manifest_format = "2.0"
5+
project_hash = "f6209327c3bf3625f9bce3952e420a70ebd8af82"
6+
7+
[[deps.AbstractTrees]]
8+
git-tree-sha1 = "52b3b436f8f73133d7bc3a6c71ee7ed6ab2ab754"
9+
uuid = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
10+
version = "0.4.3"
11+
12+
[[deps.Adapt]]
13+
deps = ["LinearAlgebra"]
14+
git-tree-sha1 = "195c5505521008abea5aee4f96930717958eac6f"
15+
uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
16+
version = "3.4.0"
17+
18+
[[deps.ArgTools]]
19+
uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f"
20+
version = "1.1.1"
21+
22+
[[deps.Artifacts]]
23+
uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
24+
25+
[[deps.Base64]]
26+
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
27+
28+
[[deps.ChainRules]]
29+
deps = ["Adapt", "ChainRulesCore", "Compat", "Distributed", "GPUArraysCore", "IrrationalConstants", "LinearAlgebra", "Random", "RealDot", "SparseArrays", "Statistics", "StructArrays"]
30+
git-tree-sha1 = "99a39b0f807499510e2ea14b0eef8422082aa372"
31+
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
32+
version = "1.46.0"
33+
34+
[[deps.ChainRulesCore]]
35+
deps = ["Compat", "LinearAlgebra", "SparseArrays"]
36+
git-tree-sha1 = "e7ff6cadf743c098e08fca25c91103ee4303c9bb"
37+
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
38+
version = "1.15.6"
39+
40+
[[deps.ChangesOfVariables]]
41+
deps = ["ChainRulesCore", "LinearAlgebra", "Test"]
42+
git-tree-sha1 = "38f7a08f19d8810338d4f5085211c7dfa5d5bdd8"
43+
uuid = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0"
44+
version = "0.1.4"
45+
46+
[[deps.CodeTracking]]
47+
deps = ["InteractiveUtils", "UUIDs"]
48+
git-tree-sha1 = "3bf60ba2fae10e10f70d53c070424e40a820dac2"
49+
uuid = "da1fd8a2-8d9e-5ec2-8556-3022fb5608a2"
50+
version = "1.1.2"
51+
52+
[[deps.Combinatorics]]
53+
git-tree-sha1 = "08c8b6831dc00bfea825826be0bc8336fc369860"
54+
uuid = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
55+
version = "1.0.2"
56+
57+
[[deps.Compat]]
58+
deps = ["Dates", "LinearAlgebra", "UUIDs"]
59+
git-tree-sha1 = "00a2cccc7f098ff3b66806862d275ca3db9e6e5a"
60+
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
61+
version = "4.5.0"
62+
63+
[[deps.CompilerSupportLibraries_jll]]
64+
deps = ["Artifacts", "Libdl"]
65+
uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae"
66+
version = "1.0.1+0"
67+
68+
[[deps.Cthulhu]]
69+
deps = ["CodeTracking", "FoldingTrees", "InteractiveUtils", "Preferences", "REPL", "UUIDs", "Unicode"]
70+
git-tree-sha1 = "e31248559b7861339d09086e7bc5597898ae7a47"
71+
uuid = "f68482b8-f384-11e8-15f7-abe071a5a75f"
72+
version = "2.7.6"
73+
74+
[[deps.DataAPI]]
75+
git-tree-sha1 = "e8119c1a33d267e16108be441a287a6981ba1630"
76+
uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
77+
version = "1.14.0"
78+
79+
[[deps.DataStructures]]
80+
deps = ["Compat", "InteractiveUtils", "OrderedCollections"]
81+
git-tree-sha1 = "d1fff3a548102f48987a52a2e0d114fa97d730f0"
82+
uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
83+
version = "0.18.13"
84+
85+
[[deps.DataValueInterfaces]]
86+
git-tree-sha1 = "bfc1187b79289637fa0ef6d4436ebdfe6905cbd6"
87+
uuid = "e2d170a0-9d28-54be-80f0-106bbe20a464"
88+
version = "1.0.0"
89+
90+
[[deps.Dates]]
91+
deps = ["Printf"]
92+
uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"
93+
94+
[[deps.Distributed]]
95+
deps = ["Random", "Serialization", "Sockets"]
96+
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
97+
98+
[[deps.DocStringExtensions]]
99+
deps = ["LibGit2"]
100+
git-tree-sha1 = "2fb1e02f2b635d0845df5d7c167fec4dd739b00d"
101+
uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
102+
version = "0.9.3"
103+
104+
[[deps.Downloads]]
105+
deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"]
106+
uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
107+
version = "1.6.0"
108+
109+
[[deps.FileWatching]]
110+
uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee"
111+
112+
[[deps.FoldingTrees]]
113+
deps = ["AbstractTrees", "REPL"]
114+
git-tree-sha1 = "d94efd85f2fe192cdf664aa8b7c431592faed59e"
115+
uuid = "1eca21be-9b9b-4ed8-839a-6d8ae26b1781"
116+
version = "1.2.1"
117+
118+
[[deps.GPUArraysCore]]
119+
deps = ["Adapt"]
120+
git-tree-sha1 = "6872f5ec8fd1a38880f027a26739d42dcda6691f"
121+
uuid = "46192b85-c4d5-4398-a991-12ede77f4527"
122+
version = "0.1.2"
123+
124+
[[deps.InteractiveUtils]]
125+
deps = ["Markdown"]
126+
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
127+
128+
[[deps.InverseFunctions]]
129+
deps = ["Test"]
130+
git-tree-sha1 = "49510dfcb407e572524ba94aeae2fced1f3feb0f"
131+
uuid = "3587e190-3f89-42d0-90ee-14403ec27112"
132+
version = "0.1.8"
133+
134+
[[deps.IrrationalConstants]]
135+
git-tree-sha1 = "7fd44fd4ff43fc60815f8e764c0f352b83c49151"
136+
uuid = "92d709cd-6900-40b7-9082-c6be49f344b6"
137+
version = "0.1.1"
138+
139+
[[deps.IteratorInterfaceExtensions]]
140+
git-tree-sha1 = "a3f24677c21f5bbe9d2a714f95dcd58337fb2856"
141+
uuid = "82899510-4779-5014-852e-03e436cf321d"
142+
version = "1.0.0"
143+
144+
[[deps.LibCURL]]
145+
deps = ["LibCURL_jll", "MozillaCACerts_jll"]
146+
uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21"
147+
version = "0.6.3"
148+
149+
[[deps.LibCURL_jll]]
150+
deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"]
151+
uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0"
152+
version = "7.84.0+0"
153+
154+
[[deps.LibGit2]]
155+
deps = ["Base64", "NetworkOptions", "Printf", "SHA"]
156+
uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"
157+
158+
[[deps.LibSSH2_jll]]
159+
deps = ["Artifacts", "Libdl", "MbedTLS_jll"]
160+
uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8"
161+
version = "1.10.2+0"
162+
163+
[[deps.Libdl]]
164+
uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
165+
166+
[[deps.LinearAlgebra]]
167+
deps = ["Libdl", "OpenBLAS_jll", "libblastrampoline_jll"]
168+
uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
169+
170+
[[deps.LogExpFunctions]]
171+
deps = ["ChainRulesCore", "ChangesOfVariables", "DocStringExtensions", "InverseFunctions", "IrrationalConstants", "LinearAlgebra"]
172+
git-tree-sha1 = "946607f84feb96220f480e0422d3484c49c00239"
173+
uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
174+
version = "0.3.19"
175+
176+
[[deps.Logging]]
177+
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
178+
179+
[[deps.Markdown]]
180+
deps = ["Base64"]
181+
uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"
182+
183+
[[deps.MbedTLS_jll]]
184+
deps = ["Artifacts", "Libdl"]
185+
uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1"
186+
version = "2.28.0+0"
187+
188+
[[deps.Missings]]
189+
deps = ["DataAPI"]
190+
git-tree-sha1 = "f66bdc5de519e8f8ae43bdc598782d35a25b1272"
191+
uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28"
192+
version = "1.1.0"
193+
194+
[[deps.MozillaCACerts_jll]]
195+
uuid = "14a3606d-f60d-562e-9121-12d972cd8159"
196+
version = "2022.10.11"
197+
198+
[[deps.NetworkOptions]]
199+
uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908"
200+
version = "1.2.0"
201+
202+
[[deps.OffsetArrays]]
203+
deps = ["Adapt"]
204+
git-tree-sha1 = "f71d8950b724e9ff6110fc948dff5a329f901d64"
205+
uuid = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
206+
version = "1.12.8"
207+
208+
[[deps.OpenBLAS_jll]]
209+
deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"]
210+
uuid = "4536629a-c528-5b80-bd46-f80d51c5b363"
211+
version = "0.3.21+0"
212+
213+
[[deps.OrderedCollections]]
214+
git-tree-sha1 = "85f8e6578bf1f9ee0d11e7bb1b1456435479d47c"
215+
uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
216+
version = "1.4.1"
217+
218+
[[deps.Pkg]]
219+
deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"]
220+
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
221+
version = "1.8.0"
222+
223+
[[deps.Preferences]]
224+
deps = ["TOML"]
225+
git-tree-sha1 = "47e5f437cc0e7ef2ce8406ce1e7e24d44915f88d"
226+
uuid = "21216c6a-2e73-6563-6e65-726566657250"
227+
version = "1.3.0"
228+
229+
[[deps.Printf]]
230+
deps = ["Unicode"]
231+
uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7"
232+
233+
[[deps.REPL]]
234+
deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"]
235+
uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
236+
237+
[[deps.Random]]
238+
deps = ["SHA", "Serialization"]
239+
uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
240+
241+
[[deps.RealDot]]
242+
deps = ["LinearAlgebra"]
243+
git-tree-sha1 = "9f0a1b71baaf7650f4fa8a1d168c7fb6ee41f0c9"
244+
uuid = "c1ae055f-0cd5-4b69-90a6-9a35b1a98df9"
245+
version = "0.1.0"
246+
247+
[[deps.SHA]]
248+
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
249+
version = "0.7.0"
250+
251+
[[deps.Serialization]]
252+
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
253+
254+
[[deps.Sockets]]
255+
uuid = "6462fe0b-24de-5631-8697-dd941f90decc"
256+
257+
[[deps.SortingAlgorithms]]
258+
deps = ["DataStructures"]
259+
git-tree-sha1 = "a4ada03f999bd01b3a25dcaa30b2d929fe537e00"
260+
uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c"
261+
version = "1.1.0"
262+
263+
[[deps.SparseArrays]]
264+
deps = ["Libdl", "LinearAlgebra", "Random", "Serialization", "SuiteSparse_jll"]
265+
uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
266+
267+
[[deps.StaticArrays]]
268+
deps = ["LinearAlgebra", "Random", "StaticArraysCore", "Statistics"]
269+
git-tree-sha1 = "ffc098086f35909741f71ce21d03dadf0d2bfa76"
270+
uuid = "90137ffa-7385-5640-81b9-e52037218182"
271+
version = "1.5.11"
272+
273+
[[deps.StaticArraysCore]]
274+
git-tree-sha1 = "6b7ba252635a5eff6a0b0664a41ee140a1c9e72a"
275+
uuid = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
276+
version = "1.4.0"
277+
278+
[[deps.Statistics]]
279+
deps = ["LinearAlgebra", "SparseArrays"]
280+
uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
281+
version = "1.9.0"
282+
283+
[[deps.StatsAPI]]
284+
deps = ["LinearAlgebra"]
285+
git-tree-sha1 = "f9af7f195fb13589dd2e2d57fdb401717d2eb1f6"
286+
uuid = "82ae8749-77ed-4fe6-ae5f-f523153014b0"
287+
version = "1.5.0"
288+
289+
[[deps.StatsBase]]
290+
deps = ["DataAPI", "DataStructures", "LinearAlgebra", "LogExpFunctions", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"]
291+
git-tree-sha1 = "d1bf48bfcc554a3761a133fe3a9bb01488e06916"
292+
uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
293+
version = "0.33.21"
294+
295+
[[deps.StructArrays]]
296+
deps = ["Adapt", "DataAPI", "GPUArraysCore", "StaticArraysCore", "Tables"]
297+
git-tree-sha1 = "b03a3b745aa49b566f128977a7dd1be8711c5e71"
298+
uuid = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
299+
version = "0.6.14"
300+
301+
[[deps.SuiteSparse_jll]]
302+
deps = ["Artifacts", "Libdl", "Pkg", "libblastrampoline_jll"]
303+
uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c"
304+
version = "5.10.1+0"
305+
306+
[[deps.TOML]]
307+
deps = ["Dates"]
308+
uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76"
309+
version = "1.0.3"
310+
311+
[[deps.TableTraits]]
312+
deps = ["IteratorInterfaceExtensions"]
313+
git-tree-sha1 = "c06b2f539df1c6efa794486abfb6ed2022561a39"
314+
uuid = "3783bdb8-4a98-5b6b-af9a-565f29a5fe9c"
315+
version = "1.0.1"
316+
317+
[[deps.Tables]]
318+
deps = ["DataAPI", "DataValueInterfaces", "IteratorInterfaceExtensions", "LinearAlgebra", "OrderedCollections", "TableTraits", "Test"]
319+
git-tree-sha1 = "c79322d36826aa2f4fd8ecfa96ddb47b174ac78d"
320+
uuid = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
321+
version = "1.10.0"
322+
323+
[[deps.Tar]]
324+
deps = ["ArgTools", "SHA"]
325+
uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e"
326+
version = "1.10.0"
327+
328+
[[deps.Test]]
329+
deps = ["InteractiveUtils", "Logging", "Random", "Serialization"]
330+
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
331+
332+
[[deps.UUIDs]]
333+
deps = ["Random", "SHA"]
334+
uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
335+
336+
[[deps.Unicode]]
337+
uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"
338+
339+
[[deps.Zlib_jll]]
340+
deps = ["Libdl"]
341+
uuid = "83775a58-1f1d-513f-b197-d71354ab007a"
342+
version = "1.2.13+0"
343+
344+
[[deps.libblastrampoline_jll]]
345+
deps = ["Artifacts", "Libdl"]
346+
uuid = "8e850b90-86db-534c-a0d3-1478176c7d93"
347+
version = "5.2.0+0"
348+
349+
[[deps.nghttp2_jll]]
350+
deps = ["Artifacts", "Libdl"]
351+
uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d"
352+
version = "1.48.0+0"
353+
354+
[[deps.p7zip_jll]]
355+
deps = ["Artifacts", "Libdl"]
356+
uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0"
357+
version = "17.4.0+0"

src/Diffractor.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ using StructArrays
44

55
export ∂⃖, gradient
66

7+
const CC = Core.Compiler
8+
79
include("runtime.jl")
810
include("interface.jl")
911
include("utils.jl")
@@ -21,7 +23,11 @@ include("stage2/interpreter.jl")
2123
include("stage2/lattice.jl")
2224
include("stage2/abstractinterpret.jl")
2325
include("stage2/tfuncs.jl")
26+
include("stage2/forward.jl")
2427

28+
include("codegen/forward.jl")
29+
include("analysis/forward.jl")
30+
include("codegen/forward_demand.jl")
2531
include("codegen/reverse.jl")
2632

2733
include("extra_rules.jl")

0 commit comments

Comments
 (0)