-
Notifications
You must be signed in to change notification settings - Fork 25
feat: backend switching for Mooncake #768
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 24 commits
1a389a6
08b176a
ba0c9e6
1340d92
2ce1ee2
08de6df
84f27c9
2e95299
13233e5
1e8df98
afdddd4
233c312
7a07127
f3e436d
6a0d937
e543958
2472ecc
c63c956
36da036
d2b5a8c
c389a80
b4fe0f8
ec4b75d
0f0b9fc
3c5f99e
d94f146
c982f46
749fea5
9e5ecfd
1e85f17
ff5c4e2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,268 @@ | ||
# Declare calls of a `DI.DifferentiateWith` object on a `Number`, `AbstractArray` or `Tuple`
# argument as primitives in `MinimalCtx`. Presumably this makes Mooncake use the
# hand-written `rrule!!` methods defined in this file instead of differentiating through
# the wrapped function itself -- confirm against the Mooncake `@is_primitive` docs.
@is_primitive MinimalCtx Tuple{DI.DifferentiateWith,<:Union{Number,AbstractArray,Tuple}} | |
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I find it a bit weird to have this union of |
||
|
||
# Nested containers such as vectors of vectors (e.g. [[1.0]]), nested tuples (e.g. ((1.0,),)) or mixtures of both (e.g. [(1.0,)]) are not yet supported as primal types by DI. | |
# This is because basis construction (DI.basis) has no methods for these types. | |
# For details, refer to the commented-out test cases to see where pullback creation fails. | |
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The issue is that we're testing There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Resolved by removing tuples for the time being
gdalle marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# Reverse rule for calling a `DI.DifferentiateWith` object on a scalar or tuple input.
#
# The primal output is obtained by calling the wrapped function `f` directly. The reverse
# pass is delegated to `DI.pullback(f, backend, ...)`, i.e. to the backend stored inside
# the `DifferentiateWith` object.
#
# Arguments:
# - `dw`: `CoDual` wrapping the `DI.DifferentiateWith` object (fields `f` and `backend`).
# - `x`:  `CoDual` wrapping the `Number` or `Tuple` input.
#
# Returns `(y, pullback)`, where `y` is `zero_fcodual(f(primal(x)))` and `pullback` maps the
# output cotangent to `(NoRData(), rdata_for_x)`.
#
# Throws an `ErrorException` when the input is differentiable and the primal output is
# neither a `Number`, an `Array` nor a `Tuple`.
function Mooncake.rrule!!(
    dw::CoDual{<:DI.DifferentiateWith}, x::Union{CoDual{<:Number},CoDual{<:Tuple}}
)
    primal_func = primal(dw)
    primal_x = primal(x)
    (; f, backend) = primal_func
    y = zero_fcodual(f(primal_x))
    primal_y = primal(y)

    # Shared by all differentiable pullbacks: run the backend pullback for `cotangent`
    # and return the reverse data of the (single) input tangent.
    function input_rdata(cotangent)
        tx = DI.pullback(f, backend, primal_x, (cotangent,))
        dx = rdata(only(tx))
        @assert dx isa rdata_type(tangent_type(typeof(primal_x)))
        return dx
    end

    # output is an array: the pullback argument is `NoRData`, so the cotangent is read
    # from the forward data `y.dx`
    pullback_array!!(::NoRData) = (NoRData(), input_rdata(y.dx))

    # output is a scalar, so the cotangent is the pullback argument itself
    pullback_scalar!!(dy::Number) = (NoRData(), input_rdata(dy))

    # output is a Tuple, NTuple
    pullback_tuple!!(dy::Tuple) = (NoRData(), input_rdata(dy))

    # inputs are non differentiable: nothing to propagate, whatever the output cotangent
    # is (it is not necessarily `NoRData`, e.g. an `Int` input mapped to a float output)
    pullback_nodiff!!(_) = (NoRData(), NoRData())

    pullback = if tangent_type(typeof(primal_x)) <: NoTangent
        pullback_nodiff!!
    elseif primal_y isa Number
        pullback_scalar!!
    elseif primal_y isa Array
        pullback_array!!
    elseif primal_y isa Tuple
        pullback_tuple!!
    else
        error(
            "For the function type $(typeof(primal_func)) and input type $(typeof(primal_x)), the primal type $(typeof(primal(y))) is currently not supported.",
        )
    end

    return y, pullback
end
|
||
# Reverse rule for calling a `DI.DifferentiateWith` object on an array input.
#
# The primal output is obtained by calling the wrapped function `f` directly. The reverse
# pass is delegated to `DI.pullback(f, backend, ...)`; its result is accumulated in place
# into the forward data of `x` (`x.dx`), and `NoRData()` is returned for both arguments.
#
# Arguments:
# - `dw`: `CoDual` wrapping the `DI.DifferentiateWith` object (fields `f` and `backend`).
# - `x`:  `CoDual` wrapping the `AbstractArray` input.
#
# Returns `(y, pullback)`, where `y` is `zero_fcodual(f(primal(x)))`.
#
# Throws an `ErrorException` when the input is differentiable and the primal output is
# neither a `Number`, an `AbstractArray` nor a `Tuple`.
function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:AbstractArray})
    primal_func = primal(dw)
    primal_x = primal(x)
    fdata_arg = x.dx
    (; f, backend) = primal_func
    y = zero_fcodual(f(primal_x))
    primal_y = primal(y)

    # Shared by all differentiable pullbacks: run the backend pullback for `cotangent`
    # and add the resulting input tangent to the forward data of `x`.
    function accumulate_pullback!(cotangent)
        dx = only(DI.pullback(f, backend, primal_x, (cotangent,)))
        # The element-wise sanity check needs a first element; skip it for empty inputs
        # instead of throwing a `BoundsError`.
        if !isempty(primal_x)
            @assert rdata(first(dx)) isa rdata_type(tangent_type(typeof(first(primal_x))))
        end
        fdata_arg .+= dx
        return NoRData(), NoRData()
    end

    # output is an array: the pullback argument is `NoRData`, so the cotangent is read
    # from the forward data `y.dx`
    pullback_array!!(::NoRData) = accumulate_pullback!(y.dx)

    # output is a scalar, so the cotangent is the pullback argument itself
    pullback_scalar!!(dy::Number) = accumulate_pullback!(dy)

    # output is a Tuple, NTuple
    pullback_tuple!!(dy::Tuple) = accumulate_pullback!(dy)

    # inputs are non differentiable: nothing to propagate, whatever the output cotangent
    # is (it is not necessarily `NoRData`)
    pullback_nodiff!!(_) = (NoRData(), NoRData())

    # Any array of `NoTangent` (not only a `Vector`) marks a non-differentiable input.
    pullback = if tangent_type(typeof(primal_x)) <: AbstractArray{NoTangent}
        pullback_nodiff!!
    elseif primal_y isa Number
        pullback_scalar!!
    elseif primal_y isa AbstractArray
        pullback_array!!
    elseif primal_y isa Tuple
        pullback_tuple!!
    else
        error(
            "For the function type $(typeof(primal_func)) and input type $(typeof(primal_x)), the primal type $(typeof(primal(y))) is currently not supported.",
        )
    end

    return y, pullback
end
|
||
gdalle marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# The `:diffwith` group defines no derived-rule test cases: return an empty list of test
# cases together with an empty list of objects to keep alive.
function Mooncake.generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:diffwith})
    test_cases = Any[]
    memory = Any[]
    return test_cases, memory
end
gdalle marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
# Test cases for the hand-written `rrule!!` methods of `DI.DifferentiateWith`.
#
# Each case is a tuple whose last two entries are the function under test (wrapped in
# `DI.DifferentiateWith` with either `AutoFiniteDiff` or `AutoForwardDiff`) and its input.
# The leading entries are forwarded unchanged to Mooncake's rule-testing utilities
# (presumably interface-only flag, performance flag and bounds -- confirm against the
# Mooncake test-utils documentation).
#
# Returns `(test_cases, memory)`, where `memory` is an empty `Any[]`.
#
# NOTE(review): `rng_ctor` is unused; the random inputs below come from the global RNG.
function Mooncake.generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:diffwith})
    with_finite_diff(f) = DI.DifferentiateWith(f, DI.AutoFiniteDiff())
    with_forward_diff(f) = DI.DifferentiateWith(f, DI.AutoForwardDiff())

    test_cases = Any[]

    # FiniteDiff backend, scalar functions at both floating-point precisions.
    let F = with_finite_diff
        for P in (Float64, Float32)
            append!(
                test_cases,
                Any[
                    # (false, :none, nothing, F(identity), ((1.0,),)), # (DI.basis fails for this, correct it!)
                    # (false, :none, nothing, F(identity), [[1.0]]), # (DI.basis fails for this, correct it!)
                    (false, :stability_and_allocs, nothing, F(cosh), P(0.3)),
                    (false, :stability_and_allocs, nothing, F(sinh), P(0.3)),
                    (
                        false,
                        :stability_and_allocs,
                        nothing,
                        F(Base.FastMath.exp10_fast),
                        P(0.5),
                    ),
                    (
                        false,
                        :stability_and_allocs,
                        nothing,
                        F(Base.FastMath.exp2_fast),
                        P(0.5),
                    ),
                    (
                        false,
                        :stability_and_allocs,
                        nothing,
                        F(Base.FastMath.exp_fast),
                        P(5.0),
                    ),
                    (false, :stability, nothing, F(copy), rand(Int32, 5)),
                ],
            )
        end
    end

    # FiniteDiff backend, precision-independent cases.
    let F = with_finite_diff
        append!(
            test_cases,
            Any[
                # Wrapped in `F` like every other case, so that the `DifferentiateWith`
                # rule (and not plain `copy`) is what gets tested on a matrix input.
                (false, :stability, nothing, F(copy), randn(5, 4)),
                (
                    # Check that Core._apply_iterate gets lifted to _apply_iterate_equivalent.
                    false,
                    :stability,
                    nothing,
                    F(x -> +(x...)),
                    randn(33),
                ),
                (
                    false,
                    :stability,
                    nothing,
                    (F(
                        function (x)
                            rx = Ref(x)
                            return Base.pointerref(
                                Base.bitcast(Ptr{Float64}, pointer_from_objref(rx)), 1, 1
                            )
                        end,
                    )),
                    5.0,
                ),
                # (false, :none, nothing, F(Mooncake.__vec_to_tuple), Any[(1.0,)]), # (DI.basis fails for this, correct it!)
                (
                    false,
                    :stability_and_allocs,
                    nothing,
                    F(Mooncake.IntrinsicsWrappers.ctlz_int),
                    5,
                ),
                (
                    false,
                    :stability_and_allocs,
                    nothing,
                    F(Mooncake.IntrinsicsWrappers.ctpop_int),
                    5,
                ),
                (
                    false,
                    :stability_and_allocs,
                    nothing,
                    F(Mooncake.IntrinsicsWrappers.cttz_int),
                    5,
                ),
                (
                    false,
                    :stability_and_allocs,
                    nothing,
                    F(Mooncake.IntrinsicsWrappers.abs_float),
                    5.0f0,
                ),
                (false, :stability_and_allocs, nothing, F(deepcopy), 5.0),
                (false, :stability, nothing, F(deepcopy), randn(5)),
                (false, :stability_and_allocs, nothing, F(sin), 1.1),
                (false, :stability_and_allocs, nothing, F(sin), 1.0f1),
                (false, :stability_and_allocs, nothing, F(cos), 1.1),
                (false, :stability_and_allocs, nothing, F(cos), 1.0f1),
                (false, :stability_and_allocs, nothing, F(exp), 1.1),
                (false, :stability_and_allocs, nothing, F(exp), 1.0f1),
            ],
        )
    end

    # ForwardDiff backend.
    let F = with_forward_diff
        for P in (Float64, Float32)
            append!(
                test_cases,
                Any[
                    (
                        false,
                        :stability_and_allocs,
                        nothing,
                        F(Base.FastMath.sincos),
                        P(3.0),
                    ),
                    (false, :none, nothing, F(Mooncake.__vec_to_tuple), [P(1.0)]),
                ],
            )
        end

        append!(
            test_cases,
            Any[
                (
                    false,
                    :stability_and_allocs,
                    nothing,
                    F(Mooncake.IntrinsicsWrappers.ctlz_int),
                    5,
                ),
                (
                    false,
                    :stability_and_allocs,
                    nothing,
                    F(Mooncake.IntrinsicsWrappers.ctpop_int),
                    5,
                ),
                (
                    false,
                    :stability_and_allocs,
                    nothing,
                    F(Mooncake.IntrinsicsWrappers.cttz_int),
                    5,
                ),
            ],
        )
    end

    memory = Any[]
    return test_cases, memory
end
Uh oh!
There was an error while loading. Please reload this page.