Skip to content

Commit e543958

Browse files
Update differentiate_with.jl
1 parent 6a0d937 commit e543958

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,14 @@ function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:Number
99
# output is a vector, so we need to use the vector pullback
1010
function pullback_array!!(dy::NoRData)
1111
tx = DI.pullback(f, backend, primal_x, (fdata(y.dx),))
12-
@assert only(tx) isa rdata_type(typeof(x))
12+
@assert only(tx) isa rdata_type(typeof(primal_x))
1313
return NoRData(), only(tx)
1414
end
1515

1616
# output is a scalar, so we can use the scalar pullback
1717
function pullback_scalar!!(dy::Number)
1818
tx = DI.pullback(f, backend, primal_x, (dy,))
19-
@assert only(tx) isa rdata_type(typeof(x))
19+
@assert only(tx) isa rdata_type(typeof(primal_x))
2020
return NoRData(), only(tx)
2121
end
2222

0 commit comments

Comments
 (0)