Skip to content

Commit 264aa8b

Browse files
committed
Hook up Eras to forward code gen
1 parent c0c77c8 commit 264aa8b

File tree

3 files changed

+9
-5
lines changed

3 files changed

+9
-5
lines changed

src/stage1/forward.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,7 @@ end
169169
#E eras mode, this controls if we should Error if it isn't Taylor. This should be a Bool
170170
struct ∂☆internal{N, E}; end
171171
struct ∂☆recurse{N, E}; end
172+
∂☆recurse{N}() where N = ∂☆recurse{N,false}
172173
struct ∂☆shuffle{N}; end
173174

174175
function shuffle_base(r)

src/stage1/recurse_fwd.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -72,9 +72,9 @@ function ∂☆builtin((f_bundle, args...))
7272
throw(DomainError(f, "No `ChainRulesCore.frule` found for the built-in function `$sig`"))
7373
end
7474

75-
function fwd_transform(ci::CodeInfo, args...)
75+
function fwd_transform(ci, mi, nargs, N, E)
7676
newci = copy(ci)
77-
fwd_transform!(newci, args...)
77+
fwd_transform!(newci, mi, nargs, N, E)
7878
return newci
7979
end
8080

@@ -214,7 +214,7 @@ function fwd_transform!(ci::CodeInfo, mi::MethodInstance, nargs::Int, N::Int, E)
214214
end
215215

216216
function perform_fwd_transform(world::UInt, source::LineNumberNode,
217-
@nospecialize(ff::Type{∂☆recurse{N}}), @nospecialize(args)) where {N}
217+
@nospecialize(ff::Type{∂☆recurse{N,E}}), @nospecialize(args)) where {N,E}
218218
if all(x->x <: ZeroBundle, args)
219219
return generate_lambda_ex(world, source,
220220
Core.svec(:ff, :args), Core.svec(), :(∂☆passthrough(args)))
@@ -237,7 +237,7 @@ function perform_fwd_transform(world::UInt, source::LineNumberNode,
237237
mi = Core.Compiler.specialize_method(match)
238238
ci = Core.Compiler.retrieve_code_info(mi, world)
239239

240-
return fwd_transform(ci, mi, length(args)-1, N)
240+
return fwd_transform(ci, mi, length(args)-1, N, E)
241241
end
242242

243243
@eval function (ff::∂☆recurse)(args...)

src/stage1/termination.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,10 @@ which(Tuple{∂⃖{N}, ∂⃖{1}, Vararg{Any}} where {N}).recursion_relation = f
4444
isa(Base.unwrap_unionall(new_sig.parameters[1].parameters[1]), Int)
4545
end
4646

47-
for (;method) in Base._methods_by_ftype(Tuple{Diffractor.∂☆recurse{N}, Vararg{Any}} where {N}, nothing, -1, get_world_counter())
47+
for (;method) in [
48+
Base._methods_by_ftype(Tuple{Diffractor.∂☆recurse{N}, Vararg{Any}} where {N}, nothing, -1, get_world_counter());
49+
Base._methods_by_ftype(Tuple{Diffractor.∂☆recurse{N, E}, Vararg{Any}} where {N, E}, nothing, -1, get_world_counter());
50+
]
4851
method.recursion_relation = function (method1, method2, parent_sig, new_sig)
4952
# Recursion from a higher to a lower order is always allowed
5053
parent_order = parent_sig.parameters[1].parameters[1]

0 commit comments

Comments
 (0)