Skip to content

Commit 40fd8dc

Browse files
committed
use ProjectTo
1 parent 98f893f commit 40fd8dc

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

src/rulesets/Base/arraymath.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -441,7 +441,12 @@ end
441441
frule((_, Δx, Δy), ::typeof(-), x::AbstractArray, y::AbstractArray) = x - y, Δx - Δy
442442

443443
function rrule(::typeof(-), x::AbstractArray, y::AbstractArray)
444-
subtract_pullback(dy) = (NoTangent(), dy, -dy)
444+
xproj = ProjectTo(x)
445+
yproj = ProjectTo(y)
446+
function subtract_pullback(dy_raw)
447+
dy = unthunk(dy_raw) # projs will otherwise unthunk twice
448+
(NoTangent(), xproj(dy), yproj(-dy))
449+
end
445450
return x - y, subtract_pullback
446451
end
447452

0 commit comments

Comments
 (0)