Skip to content

Commit 0a121ea

Browse files
authored
Implement higher order truncation (#106)
* Implement higher order truncation * Bump deps and allow bad inference for the moment
1 parent 0241872 commit 0a121ea

File tree

4 files changed

+71
-46
lines changed

4 files changed

+71
-46
lines changed

Manifest.toml

Lines changed: 39 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,18 @@
22

33
julia_version = "1.10.0-DEV"
44
manifest_format = "2.0"
5-
project_hash = "f6209327c3bf3625f9bce3952e420a70ebd8af82"
5+
project_hash = "dd9f7fdb915ecd02d7d92d5adfe6109efe7a527e"
66

77
[[deps.AbstractTrees]]
8-
git-tree-sha1 = "52b3b436f8f73133d7bc3a6c71ee7ed6ab2ab754"
8+
git-tree-sha1 = "faa260e4cb5aba097a73fab382dd4b5819d8ec8c"
99
uuid = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
10-
version = "0.4.3"
10+
version = "0.4.4"
1111

1212
[[deps.Adapt]]
1313
deps = ["LinearAlgebra"]
14-
git-tree-sha1 = "195c5505521008abea5aee4f96930717958eac6f"
14+
git-tree-sha1 = "0310e08cb19f5da31d08341c6120c047598f5b9c"
1515
uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
16-
version = "3.4.0"
16+
version = "3.5.0"
1717

1818
[[deps.ArgTools]]
1919
uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f"
@@ -27,27 +27,21 @@ uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
2727

2828
[[deps.ChainRules]]
2929
deps = ["Adapt", "ChainRulesCore", "Compat", "Distributed", "GPUArraysCore", "IrrationalConstants", "LinearAlgebra", "Random", "RealDot", "SparseArrays", "Statistics", "StructArrays"]
30-
git-tree-sha1 = "99a39b0f807499510e2ea14b0eef8422082aa372"
30+
git-tree-sha1 = "c46adabdd0348f0ee8de91142cfc4a72a613ac0a"
3131
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
32-
version = "1.46.0"
32+
version = "1.46.1"
3333

3434
[[deps.ChainRulesCore]]
3535
deps = ["Compat", "LinearAlgebra", "SparseArrays"]
36-
git-tree-sha1 = "e7ff6cadf743c098e08fca25c91103ee4303c9bb"
36+
git-tree-sha1 = "c6d890a52d2c4d55d326439580c3b8d0875a77d9"
3737
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
38-
version = "1.15.6"
39-
40-
[[deps.ChangesOfVariables]]
41-
deps = ["ChainRulesCore", "LinearAlgebra", "Test"]
42-
git-tree-sha1 = "38f7a08f19d8810338d4f5085211c7dfa5d5bdd8"
43-
uuid = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0"
44-
version = "0.1.4"
38+
version = "1.15.7"
4539

4640
[[deps.CodeTracking]]
4741
deps = ["InteractiveUtils", "UUIDs"]
48-
git-tree-sha1 = "3bf60ba2fae10e10f70d53c070424e40a820dac2"
42+
git-tree-sha1 = "0e5c14c3bb8a61b3d53b2c0620570c332c8d0663"
4943
uuid = "da1fd8a2-8d9e-5ec2-8556-3022fb5608a2"
50-
version = "1.1.2"
44+
version = "1.2.0"
5145

5246
[[deps.Combinatorics]]
5347
git-tree-sha1 = "08c8b6831dc00bfea825826be0bc8336fc369860"
@@ -56,20 +50,20 @@ version = "1.0.2"
5650

5751
[[deps.Compat]]
5852
deps = ["Dates", "LinearAlgebra", "UUIDs"]
59-
git-tree-sha1 = "00a2cccc7f098ff3b66806862d275ca3db9e6e5a"
53+
git-tree-sha1 = "61fdd77467a5c3ad071ef8277ac6bd6af7dd4c04"
6054
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
61-
version = "4.5.0"
55+
version = "4.6.0"
6256

6357
[[deps.CompilerSupportLibraries_jll]]
6458
deps = ["Artifacts", "Libdl"]
6559
uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae"
66-
version = "1.0.1+0"
60+
version = "1.0.2+0"
6761

6862
[[deps.Cthulhu]]
6963
deps = ["CodeTracking", "FoldingTrees", "InteractiveUtils", "Preferences", "REPL", "UUIDs", "Unicode"]
70-
git-tree-sha1 = "e31248559b7861339d09086e7bc5597898ae7a47"
64+
git-tree-sha1 = "6275f27473e7d3e91ab4892b6128ab68b3b8098f"
7165
uuid = "f68482b8-f384-11e8-15f7-abe071a5a75f"
72-
version = "2.7.6"
66+
version = "2.7.8"
7367

7468
[[deps.DataAPI]]
7569
git-tree-sha1 = "e8119c1a33d267e16108be441a287a6981ba1630"
@@ -117,20 +111,14 @@ version = "1.2.1"
117111

118112
[[deps.GPUArraysCore]]
119113
deps = ["Adapt"]
120-
git-tree-sha1 = "6872f5ec8fd1a38880f027a26739d42dcda6691f"
114+
git-tree-sha1 = "57f7cde02d7a53c9d1d28443b9f11ac5fbe7ebc9"
121115
uuid = "46192b85-c4d5-4398-a991-12ede77f4527"
122-
version = "0.1.2"
116+
version = "0.1.3"
123117

124118
[[deps.InteractiveUtils]]
125119
deps = ["Markdown"]
126120
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
127121

128-
[[deps.InverseFunctions]]
129-
deps = ["Test"]
130-
git-tree-sha1 = "49510dfcb407e572524ba94aeae2fced1f3feb0f"
131-
uuid = "3587e190-3f89-42d0-90ee-14403ec27112"
132-
version = "0.1.8"
133-
134122
[[deps.IrrationalConstants]]
135123
git-tree-sha1 = "7fd44fd4ff43fc60815f8e764c0f352b83c49151"
136124
uuid = "92d709cd-6900-40b7-9082-c6be49f344b6"
@@ -168,10 +156,20 @@ deps = ["Libdl", "OpenBLAS_jll", "libblastrampoline_jll"]
168156
uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
169157

170158
[[deps.LogExpFunctions]]
171-
deps = ["ChainRulesCore", "ChangesOfVariables", "DocStringExtensions", "InverseFunctions", "IrrationalConstants", "LinearAlgebra"]
172-
git-tree-sha1 = "946607f84feb96220f480e0422d3484c49c00239"
159+
deps = ["DocStringExtensions", "IrrationalConstants", "LinearAlgebra"]
160+
git-tree-sha1 = "680e733c3a0a9cea9e935c8c2184aea6a63fa0b5"
173161
uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
174-
version = "0.3.19"
162+
version = "0.3.21"
163+
164+
[deps.LogExpFunctions.extensions]
165+
ChainRulesCoreExt = "ChainRulesCore"
166+
ChangesOfVariablesExt = "ChangesOfVariables"
167+
InverseFunctionsExt = "InverseFunctions"
168+
169+
[deps.LogExpFunctions.weakdeps]
170+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
171+
ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0"
172+
InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112"
175173

176174
[[deps.Logging]]
177175
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
@@ -201,9 +199,9 @@ version = "1.2.0"
201199

202200
[[deps.OffsetArrays]]
203201
deps = ["Adapt"]
204-
git-tree-sha1 = "f71d8950b724e9ff6110fc948dff5a329f901d64"
202+
git-tree-sha1 = "82d7c9e310fe55aa54996e6f7f94674e2a38fcb4"
205203
uuid = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
206-
version = "1.12.8"
204+
version = "1.12.9"
207205

208206
[[deps.OpenBLAS_jll]]
209207
deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"]
@@ -218,7 +216,7 @@ version = "1.4.1"
218216
[[deps.Pkg]]
219217
deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"]
220218
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
221-
version = "1.8.0"
219+
version = "1.10.0"
222220

223221
[[deps.Preferences]]
224222
deps = ["TOML"]
@@ -263,12 +261,13 @@ version = "1.1.0"
263261
[[deps.SparseArrays]]
264262
deps = ["Libdl", "LinearAlgebra", "Random", "Serialization", "SuiteSparse_jll"]
265263
uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
264+
version = "1.10.0"
266265

267266
[[deps.StaticArrays]]
268267
deps = ["LinearAlgebra", "Random", "StaticArraysCore", "Statistics"]
269-
git-tree-sha1 = "ffc098086f35909741f71ce21d03dadf0d2bfa76"
268+
git-tree-sha1 = "129703d62117c374c4f2db6d13a027741c46eafd"
270269
uuid = "90137ffa-7385-5640-81b9-e52037218182"
271-
version = "1.5.11"
270+
version = "1.5.13"
272271

273272
[[deps.StaticArraysCore]]
274273
git-tree-sha1 = "6b7ba252635a5eff6a0b0664a41ee140a1c9e72a"
@@ -301,7 +300,7 @@ version = "0.6.14"
301300
[[deps.SuiteSparse_jll]]
302301
deps = ["Artifacts", "Libdl", "Pkg", "libblastrampoline_jll"]
303302
uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c"
304-
version = "5.10.1+0"
303+
version = "5.10.1+6"
305304

306305
[[deps.TOML]]
307306
deps = ["Dates"]
@@ -344,7 +343,7 @@ version = "1.2.13+0"
344343
[[deps.libblastrampoline_jll]]
345344
deps = ["Artifacts", "Libdl"]
346345
uuid = "8e850b90-86db-534c-a0d3-1478176c7d93"
347-
version = "5.2.0+0"
346+
version = "5.4.0+0"
348347

349348
[[deps.nghttp2_jll]]
350349
deps = ["Artifacts", "Libdl"]

src/codegen/forward_demand.jl

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
using Core.Compiler: IRInterpretationState, construct_postdomtree, PiNode,
2-
is_known_call, argextype, postdominates, userefs
2+
is_known_call, argextype, postdominates, userefs, PhiCNode, UpsilonNode
33

44
#=
55
function forward_diff!(ir::IRCode, interp, irsv::IRInterpretationState, to_diff::Vector{Pair{SSAValue, Int}}; custom_diff! = (args...)->nothing, diff_cache=Dict{SSAValue, SSAValue}())
@@ -205,9 +205,11 @@ function forward_diff_no_inf!(ir::IRCode, interp, mi::MethodInstance, world, to_
205205
if argorder != order
206206
@assert order < argorder
207207
return get!(truncation_map, arg=>order) do
208-
# TODO: Other orders
209-
@assert order == 0
210-
insert_node!(ir, arg, NewInstruction(Expr(:call, primal, arg), Any), #=attach_after=#true)
208+
if order == 0
209+
insert_node!(ir, arg, NewInstruction(Expr(:call, primal, arg), Any), #=attach_after=#true)
210+
else
211+
insert_node!(ir, arg, NewInstruction(Expr(:call, truncate, arg, Val{order}()), Any), #=attach_after=#true)
212+
end
211213
end
212214
end
213215
return arg
@@ -247,13 +249,25 @@ function forward_diff_no_inf!(ir::IRCode, interp, mi::MethodInstance, world, to_
247249
elseif isexpr(stmt, :call)
248250
inst[:inst] = Expr(:call, ∂☆{order}(), map(arg->maparg(arg, SSAValue(ssa), order), stmt.args)...)
249251
inst[:type] = Any
250-
else
252+
elseif isa(stmt, PiNode)
253+
# TODO: New PiNode that discriminates based on primal?
254+
inst[:inst] = maparg(stmt.val, SSAValue(ssa), order)
255+
inst[:type] = Any
256+
elseif isa(stmt, GlobalRef)
257+
inst[:inst] = maparg(stmt, SSAValue(ssa), order)
258+
inst[:type] = Any
259+
elseif isa(stmt, Expr) || isa(stmt, PhiNode) || isa(stmt, PhiCNode) ||
260+
isa(stmt, UpsilonNode) || isa(stmt, GotoIfNot) || isa(stmt, QuoteNode) || isa(stmt, Argument)
251261
urs = userefs(stmt)
252262
for ur in urs
253263
ur[] = maparg(ur[], SSAValue(ssa), order)
254264
end
255265
inst[:inst] = urs[]
256266
inst[:type] = Any
267+
else
268+
val = ZeroBundle{order}(inst[:inst])
269+
inst[:inst] = val
270+
inst[:type] = Const(val)
257271
end
258272
end
259273
end

src/tangent.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,18 @@ function Base.getindex(tb::TaylorBundle, tti::CanonicalTangentIndex)
230230
tb.tangent.coeffs[count_ones(tti.i)]
231231
end
232232

233+
function truncate(tt::TaylorTangent, order::Val{N}) where {N}
234+
TaylorTangent(tt.coeffs[1:N])
235+
end
236+
237+
function truncate(ut::UniformTangent, order::Val)
238+
ut
239+
end
240+
241+
function truncate(tb::TangentBundle, order::Val)
242+
_TangentBundle(order, tb.primal, truncate(tb.tangent, order))
243+
end
244+
233245
const UniformBundle{N, B, U} = TangentBundle{N, B, UniformTangent{U}}
234246
UniformBundle{N, B, U}(primal::B, partial::U) where {N,B,U} = _TangentBundle(Val{N}(), primal, UniformTangent{U}(partial))
235247
UniformBundle{N, B, U}(primal::B) where {N,B,U} = _TangentBundle(Val{N}(), primal, UniformTangent{U}(U.instance))

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ isa_control_flow(::Type{T}, x) where {T} = isa(x, T) ? x : T(x)
9393
let var"'" = Diffractor.PrimeDerivativeBack
9494
# Integration tests
9595
@test @inferred(sin'(1.0)) == cos(1.0)
96-
@test @inferred(sin''(1.0)) == -sin(1.0)
96+
@test sin''(1.0) == -sin(1.0)
9797
@test sin'''(1.0) == -cos(1.0)
9898
@test sin''''(1.0) == sin(1.0)
9999
@test sin'''''(1.0) == cos(1.0)

0 commit comments

Comments
 (0)