Skip to content

Commit 8654721

Browse files
authored
Restore some support for Tracker.jl (#2387)
* restore and test some support for Tracker.jl * bump Tracker compat * Update Project.toml
1 parent c4a0ee4 commit 8654721

File tree

5 files changed

+50
-4
lines changed

5 files changed

+50
-4
lines changed

Project.toml

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "Flux"
22
uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c"
3-
version = "0.14.14"
3+
version = "0.14.15"
44

55
[deps]
66
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -54,6 +54,7 @@ ProgressLogging = "0.1"
5454
Reexport = "1.0"
5555
SpecialFunctions = "2.1.2"
5656
Statistics = "1"
57+
Tracker = "0.2.33"
5758
Zygote = "0.6.67"
5859
cuDNN = "1"
5960
julia = "1.9"
@@ -72,9 +73,11 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
7273
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
7374
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
7475
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
76+
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
7577
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
7678

7779
[targets]
7880
test = ["Test", "Documenter", "IterTools", "LinearAlgebra", "FillArrays",
7981
"ComponentArrays", "BSON", "Pkg", "CUDA", "cuDNN", "Metal", "AMDGPU",
80-
"Enzyme", "FiniteDifferences"]
82+
"Enzyme", "FiniteDifferences", "Tracker"]
83+

src/layers/stateless.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,8 @@ end
8181
_match_eltype(layer, ::Type, x::AbstractArray) = x
8282

8383
# 2-arg method, for common layers with layer.weight
84-
_match_eltype(layer, x) = _match_eltype(layer, eltype(layer.weight), x)
84+
# NB using _eltype gets Float64 from Tracker.TrackedArray{Float64}, not TrackedReal
85+
_match_eltype(layer, x) = _match_eltype(layer, _eltype(layer.weight), x)
8586

8687
# Trivial rule:
8788
function ChainRulesCore.rrule(::typeof(_match_eltype), layer, ::Type{T}, x::AbstractArray) where {T}

src/utils.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -501,9 +501,12 @@ function create_bias(weights::AbstractArray, bias::Bool, dims::Integer...)
501501
end
502502
function create_bias(weights::AbstractArray, bias::AbstractArray, dims::Integer...)
503503
size(bias) == dims || throw(DimensionMismatch("expected bias of size $(dims), got size $(size(bias))"))
504-
convert(AbstractArray{eltype(weights)}, bias)
504+
convert(AbstractArray{_eltype(weights)}, bias)
505505
end
506506

507+
# This avoids the issue that Tracker.TrackedArray{Float64} declares eltype() = TrackedReal
508+
_eltype(::AbstractArray{T}) where T = T
509+
507510

508511
# Other
509512

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ Random.seed!(0)
2828
@testset "Optimise / Train" begin
2929
include("optimise.jl")
3030
include("train.jl")
31+
include("tracker.jl")
3132
end
3233

3334
@testset "Data" begin

test/tracker.jl

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
using Tracker: withgradient
2+
using Zygote: gradient
3+
using Functors: fmapstructure
4+
using Flux
5+
6+
@testset "Tracker.jl" begin
7+
@testset "some simple models" begin
8+
m1 = Dense(ones32(2,3), fill(0.1f0,2), abs2)
9+
x1 = Float32[1,2,3]
10+
(_, v1), g1 = withgradient(m1, x1) do m, x
11+
y1 = m(x)
12+
sum(abs2, y1 .- [4, 5]), y1
13+
end
14+
@test v1 ≈ m1(x1)
15+
g1z = gradient(m1, x1) do m, x
16+
sum(abs2, m(x) .- [4, 5])
17+
end
18+
@test g1[1].weight ≈ g1z[1].weight
19+
@test g1[1].bias ≈ g1z[1].bias
20+
21+
m2 = Chain(Conv((2,2), 3 => 1, relu), Flux.flatten, Dense(20 => 1, tanh), only)
22+
x2 = randn32(5,6,3,1)
23+
v2, g2 = withgradient(m -> m(x2), m2)
24+
g2z = gradient(m -> m(x2), m2)
25+
@test g2[1].layers[1].weight ≈ g2z[1].layers[1].weight
26+
@test g2[1].layers[1].bias ≈ g2z[1].layers[1].bias
27+
@test g2[1].layers[3].weight ≈ g2z[1].layers[3].weight
28+
end
29+
30+
@testset "Dropout" begin
31+
g1z = gradient(sum∘Dropout(0.5), ones(1000))
32+
v1, g1 = withgradient(sum∘Dropout(0.5), ones(1000))
33+
@test 800<v1<1200
34+
@test sum(g1[1]) ≈ v1
35+
@test 400 < count(iszero, g1[1]) < 600
36+
end
37+
end
38+

0 commit comments

Comments
 (0)