|
117 | 117 | @test all(dx .≈ 1.0)
|
118 | 118 | end
|
119 | 119 |
|
| 120 | + |
| 121 | +function setadd(out, x, y) |
| 122 | + out .= x .+ y |
| 123 | + nothing |
| 124 | +end |
| 125 | + |
| 126 | +@testset "Forward setadd" begin |
| 127 | + out = CuArray([0.0, 0.0, 0.0, 0.0]) |
| 128 | + dout = CuArray([0.0, 0.0, 0.0, 0.0]) |
| 129 | + x = CuArray([1.0, 2.0, 3.0, 4.0]) |
| 130 | + dx = CuArray([100., 300.0, 500.0, 700.0]) |
| 131 | + y = CuArray([5.0, 6.0, 7.0, 8.0]) |
| 132 | + dy = CuArray([500., 600.0, 700.0, 800.0]) |
| 133 | + res = Enzyme.autodiff(Forward, setadd, Duplicated(out, dout), Duplicated(x, dx), Duplicated(y, dy)) |
| 134 | + @test all(dout .≈ dx .+ dy) |
| 135 | +end |
| 136 | + |
| 137 | +@testset "setadd sum" begin |
| 138 | + out = CuArray([0.0, 0.0, 0.0, 0.0]) |
| 139 | + dout = CuArray([1.0, 1.0, 1.0, 1.0]) |
| 140 | + x = CuArray([1.0, 2.0, 3.0, 4.0]) |
| 141 | + dx = CuArray([0., 0.0, 0.0, 0.0]) |
| 142 | + y = CuArray([5.0, 6.0, 7.0, 8.0]) |
| 143 | + dy = CuArray([0., 0.0, 0.0, 0.0]) |
| 144 | + res = Enzyme.autodiff(Reverse, setadd, Duplicated(out, dout), Duplicated(x, dx), Duplicated(y, dy)) |
| 145 | + @test all(dx .≈ 1) |
| 146 | + @test all(dy .≈ 1) |
| 147 | +end |
120 | 148 | # TODO once reverse kernels are in
|
121 | 149 | # function togpu(x)
|
122 | 150 | # x = CuArray(x)
|
|
0 commit comments