Skip to content

Commit e0661c6

Browse files
authored
Merge pull request #127 from mcabbott/softmax
Update `NNlib.softmax` gradients
2 parents ced78ee + cd56ed4 commit e0661c6

File tree

4 files changed

+30
-7
lines changed

4 files changed

+30
-7
lines changed

.github/workflows/CI.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,9 @@ jobs:
2020
fail-fast: false
2121
matrix:
2222
version:
23-
- '1'
2423
- '1.3'
24+
- '1.6' # LTS
25+
- '1'
2526
- 'nightly'
2627
os:
2728
- ubuntu-latest

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "Tracker"
22
uuid = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
3-
version = "0.2.19"
3+
version = "0.2.20"
44

55
[deps]
66
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -23,7 +23,7 @@ DiffRules = "1.4"
2323
ForwardDiff = "0.10"
2424
LogExpFunctions = "0.3"
2525
MacroTools = "0.5"
26-
NNlib = "0.6, 0.7, 0.8"
26+
NNlib = "0.7.18, 0.8" # 0.7.18 is the last version which supports Julia 1.3
2727
NaNMath = "0.3, 1"
2828
Requires = "0.5, 1.0"
2929
SpecialFunctions = "0.10, 1, 2"

src/lib/array.jl

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -471,11 +471,31 @@ import NNlib: DenseConvDims, DepthwiseConvDims, PoolDims
471471

472472
softmax(xs::TrackedArray; dims=1) = track(softmax, xs; dims=dims)
473473

474-
@grad softmax(xs; dims=1) = softmax(data(xs); dims=dims), Δ -> (nobacksies(:softmax, ∇softmax(data(Δ), data(xs); dims=dims)),)
474+
if isdefined(NNlib, :∇softmax_data) # use new form to avoid a depwarn, but only possible Julia 1.6+
475+
@eval @grad function softmax(xs; dims=1)
476+
y = softmax(data(xs); dims=dims)
477+
y, Δ -> (nobacksies(:softmax, NNlib.∇softmax_data(data(Δ), data(y); dims=dims)),)
478+
end
479+
else
480+
@eval @grad function softmax(xs; dims=1) # TODO delete this when dropping Julia 1.3 (and increase NNlib bound)
481+
y = softmax(data(xs); dims=dims)
482+
y, Δ -> (nobacksies(:softmax, ∇softmax(data(Δ), data(xs), data(y); dims=dims)),)
483+
end
484+
end
475485

476486
logsoftmax(xs::TrackedArray; dims=1) = track(logsoftmax, xs; dims=dims)
477487

478-
@grad logsoftmax(xs; dims=1) = logsoftmax(data(xs); dims=dims), Δ -> (nobacksies(:logsoftmax, ∇logsoftmax(data(Δ), data(xs); dims=dims)),)
488+
if isdefined(NNlib, :∇logsoftmax_data) # use new form to avoid a depwarn, but only possible Julia 1.6+
489+
@eval @grad function logsoftmax(xs; dims=1)
490+
y = logsoftmax(data(xs); dims=dims)
491+
y, Δ -> (nobacksies(:logsoftmax, NNlib.∇logsoftmax_data(data(Δ), data(y); dims=dims)),)
492+
end
493+
else
494+
@eval @grad function logsoftmax(xs; dims=1)
495+
y = logsoftmax(data(xs); dims=dims)
496+
y, Δ -> (nobacksies(:logsoftmax, ∇logsoftmax(data(Δ), data(xs), data(y); dims=dims)),)
497+
end
498+
end
479499

480500
depthwiseconv(x::TrackedArray, w::TrackedArray, cdims::DepthwiseConvDims; kw...) = track(depthwiseconv, x, w, cdims; kw...)
481501
depthwiseconv(x::AbstractArray, w::TrackedArray, cdims::DepthwiseConvDims; kw...) = track(depthwiseconv, x, w, cdims; kw...)

test/tracker.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@ using Random
1010

1111
gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(sin.(f(xs...))), xs...)
1212
gradtest(f, dims...) = gradtest(f, rand.(Float64, dims)...)
13-
@testset "Tracker" begin
13+
14+
@testset "Tracker" begin # overall testset, rest of the file
15+
1416
@test gradtest((x, W, b) -> σ.(W*x .+ b), 5, (2,5), 2)
1517
@test gradtest((x, W) -> σ.(W*x), 5, (2,5))
1618
@test gradtest((x, W, b) -> σ.(W*x .+ b), (5,3), (2,5), 2)
@@ -478,4 +480,4 @@ end
478480
@test size(y) == (5, 3)
479481
end
480482

481-
end #testset
483+
end # overall testset

0 commit comments

Comments
 (0)