Skip to content

Commit b3e3c75

Browse files
committed
Add Tracker support
1 parent 02dcfa0 commit b3e3c75

File tree

3 files changed

+5
-0
lines changed

3 files changed

+5
-0
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
2424
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
2525
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
2626
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
27+
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
2728
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2829

2930
[compat]

src/optimise/Optimise.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ using Flux
44
using MacroTools: @forward
55
import Zygote
66
import Zygote: Params, gradient
7+
import Tracker
78
using AbstractDifferentiation
89
import Optimisers
910
import Optimisers: update, update!

src/optimise/gradients.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,6 @@ AD.@primitive pullback_function(ad::ZygoteExplicitBackend, f, xs...) =
2121
# this is a hack to get around
2222
# https://github.com/JuliaDiff/AbstractDifferentiation.jl/issues/63#issuecomment-1225959150
2323
AD.gradient(::ZygoteExplicitBackend, f, xs...) = Zygote.gradient(f, xs...)
24+
25+
# this is to work around AD.TrackerBackend only supporting vectors of params
26+
AD.gradient(::AD.TrackerBackend, f, xs...) = Tracker.withgradient(f, xs...).grad

0 commit comments

Comments
 (0)