Skip to content

Commit 971069c

Browse files
authored
Merge pull request #782 from nmheim/nh/subtract-rule
Add missing subtract rule
2 parents b3f9f9b + f44437c commit 971069c

File tree

2 files changed

+24
-0
lines changed

2 files changed

+24
-0
lines changed

src/rulesets/Base/arraymath.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -434,6 +434,21 @@ function rrule(::typeof(-), x::AbstractArray)
434434
return -x, negation_pullback
435435
end
436436

437+
#####
438+
##### Subtraction
439+
#####
440+
441+
frule((_, Δx, Δy), ::typeof(-), x::AbstractArray, y::AbstractArray) = x - y, Δx - Δy
442+
443+
function rrule(::typeof(-), x::AbstractArray, y::AbstractArray)
444+
xproj = ProjectTo(x)
445+
yproj = ProjectTo(y)
446+
function subtract_pullback(dy_raw)
447+
dy = unthunk(dy_raw) # projs will otherwise unthunk twice
448+
return (NoTangent(), xproj(dy), yproj(-dy))
449+
end
450+
return x - y, subtract_pullback
451+
end
437452

438453
#####
439454
##### Addition (Multiarg `+`)

test/rulesets/Base/arraymath.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,4 +217,13 @@
217217
@gpu test_rrule(+, randn(4, 4), randn(4, 4), randn(4, 4))
218218
@gpu test_rrule(+, randn(3), randn(3,1), randn(3,1,1))
219219
end
220+
221+
@testset "subtraction" begin
222+
# fwd
223+
@gpu test_frule(-, randn(2), randn(2))
224+
# rev
225+
@gpu test_rrule(-, randn(4, 4), randn(4, 4))
226+
@gpu test_rrule(-, randn(4), randn(ComplexF64, 4))
227+
@gpu test_rrule(-, randn(3), randn(3, 1))
228+
end
220229
end

0 commit comments

Comments
 (0)