Skip to content

Commit 36da036

Browse files
assertion for array inputs
1 parent c63c956 commit 36da036

File tree

2 files changed

+5
-3
lines changed

2 files changed

+5
-3
lines changed

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,15 +33,15 @@ function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:Abstra
3333
# output is a vector, so we need to use the vector pullback
3434
function pullback_array!!(dy::NoRData)
3535
tx = DI.pullback(f, backend, primal_x, (fdata(y.dx),))
36-
@assert only(tx) isa rdata_type(typeof(primal_x))
36+
@assert first(only(tx)) isa rdata_type(typeof(first(primal_x)))
3737
fdata_arg .+= only(tx)
3838
return NoRData(), dy
3939
end
4040

4141
# output is a scalar, so we can use the scalar pullback
4242
function pullback_scalar!!(dy::Number)
4343
tx = DI.pullback(f, backend, primal_x, (dy,))
44-
@assert only(tx) isa rdata_type(typeof(primal_x))
44+
@assert first(only(tx)) isa rdata_type(typeof(first(primal_x)))
4545
fdata_arg .+= only(tx)
4646
return NoRData(), NoRData()
4747
end

DifferentiationInterface/test/Back/DifferentiateWith/test.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@ LOGGING = get(ENV, "CI", "false") == "false"
1313

1414
function differentiatewith_scenarios()
1515
bad_scens = # these closurified scenarios have mutation and type constraints
16-
filter(default_scenarios(; include_normal=false, include_closurified=true)) do scen
16+
filter(
17+
DIT.default_scenarios(; include_normal=false, include_closurified=true)
18+
) do scen
1719
DIT.function_place(scen) == :out
1820
end
1921
good_scens = map(bad_scens) do scen

0 commit comments

Comments
 (0)