
Commit 6473c45

Don't load Yota at all (#166)
* Update runtests.jl
* Update runtests.jl
* Update Project.toml
* Update runtests.jl
* Update rules.jl
* Update destructure.jl
* test is no longer broken
* Update index.md
1 parent 88b527c commit 6473c45

File tree

5 files changed (+15, -33 lines)


Project.toml

Lines changed: 1 addition & 3 deletions
@@ -14,15 +14,13 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 ChainRulesCore = "1"
 Functors = "0.4"
 Statistics = "1"
-Yota = "0.8.2"
 Zygote = "0.6.40"
 julia = "1.6"
 
 [extras]
 StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
-Yota = "cd998857-8626-517d-b929-70ad188a48f0"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [targets]
-test = ["Test", "StaticArrays", "Yota", "Zygote"]
+test = ["Test", "StaticArrays", "Zygote"]
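With Yota removed from `[compat]`, `[extras]` and `[targets]`, the package's test environment no longer resolves Yota at all. A minimal sketch of how one might confirm that locally; the checkout path is hypothetical, not part of this commit:

```julia
# Hedged sketch: run the test suite from a local checkout of Optimisers.jl.
# After this change, only Test, StaticArrays and Zygote are instantiated
# from [extras] for the "test" target.
using Pkg
Pkg.activate("path/to/Optimisers.jl")   # hypothetical local checkout
Pkg.test()                              # runs test/runtests.jl in the test env
```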

docs/src/index.md

Lines changed: 0 additions & 19 deletions
@@ -80,25 +80,6 @@ Zygote's "implicit" mode `gradient(() -> loss(...), Flux.params(model))` -- see
 [Zygote's documentation](https://fluxml.ai/Zygote.jl/dev/#Explicit-and-Implicit-Parameters-1) for more about this difference.
 
 
-## Usage with [Yota.jl](https://github.com/dfdx/Yota.jl)
-
-Yota is another modern automatic differentiation package, an alternative to Zygote.
-
-Its main function is `Yota.grad`, which returns the loss as well as the gradient (like `Zygote.withgradient`)
-but also returns a gradient component for the loss function.
-To extract what Optimisers.jl needs, you can write (for the Flux model above):
-
-```julia
-using Yota
-
-loss, (∇function, ∇model, ∇image) = Yota.grad(model, image) do m, x
-  sum(m(x))
-end;
-
-# Or else, this may save computing ∇image:
-loss, (_, ∇model) = grad(m -> sum(m(image)), model);
-```
-
 ## Usage with [Lux.jl](https://github.com/avik-pal/Lux.jl)
 
 The main design difference of Lux from Flux is that the tree of parameters is separate from
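The deleted section pointed to `Zygote.withgradient` as the analogous call, and that is the pattern the docs keep. For reference, a minimal sketch of it, assuming the same `model` and `image` objects used elsewhere in the docs:

```julia
using Zygote

# Loss plus one gradient per argument, like the removed Yota.grad example:
loss, (∇model, ∇image) = Zygote.withgradient((m, x) -> sum(m(x)), model, image)

# Closing over `image` avoids computing ∇image:
loss, (∇model,) = Zygote.withgradient(m -> sum(m(image)), model)
```

Either way, `∇model` is what `Optimisers.update(state, model, ∇model)` expects.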

test/destructure.jl

Lines changed: 3 additions & 3 deletions
@@ -98,14 +98,14 @@ end
     sum(gradient(m -> sum(destructure(m)[1])^3, (v, [4,5,6.0]))[1][1])
   end[1] == [378, 378, 378]
 
-  @test_broken gradient([1,2,3.0]) do v
+  VERSION >= v"1.10" && @test gradient([1,2,3.0]) do v
     sum(abs2, gradient(m -> sum(abs2, destructure(m)[1]), (x = v, y = sin, z = [4,5,6.0]))[1][1])
   end[1] ≈ [8,16,24]
   # Zygote error in (::typeof(∂(canonicalize)))(Δ::NamedTuple{(:backing,), Tuple{NamedTuple{(:x, :y, :z)
   # Diffractor error in perform_optic_transform
 end
 
-VERSION < v"1.9-" && @testset "using Yota" begin
+false && @testset "using Yota" begin
   @test Yota_gradient(m -> destructure(m)[1][1], m1)[1] == [1,0,0]
   @test Yota_gradient(m -> destructure(m)[1][2], m2)[1] == ([0,1,0], [0,0,0])
   @test Yota_gradient(m -> destructure(m)[1][3], (m1, m1))[1] == ([0,0,1], nothing)
@@ -175,7 +175,7 @@ end
   # Zygote.@adjoint Tangent{T,B}(x::NamedTuple) where {T,B<:NamedTuple} = Tangent{T,B}(x), dx -> (dx,)
 end
 
-VERSION < v"1.9-" && @testset "using Yota" begin
+false && @testset "using Yota" begin
   re1 = destructure(m1)[2]
   @test Yota_gradient(x -> re1(x)[1], rand(3))[1] == [1,0,0]
   re2 = destructure(m2)[2]
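The `false && @testset ...` guard keeps the Yota testsets in the file without ever running them: `&&` short-circuits before the block executes, so the `Yota_gradient` helper (commented out in runtests.jl below) is never looked up. A standalone sketch of the same trick, with a made-up function name:

```julia
using Test

# `does_not_exist` is hypothetical; the testset parses fine but is skipped
# at runtime because `false &&` short-circuits before the block runs.
false && @testset "skipped Yota-style tests" begin
    @test does_not_exist([1, 2, 3]) == [1, 0, 0]
end
```

Unlike deleting the tests outright, this leaves them easy to re-enable once loading Yota is sorted out.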

test/rules.jl

Lines changed: 2 additions & 2 deletions
@@ -230,7 +230,7 @@ end
   end
 end
 
-VERSION < v"1.9-" && @testset "using Yota" begin
+false && @testset "using Yota" begin
   @testset "$(name(o))" for o in RULES
     w′ = (abc = (α = rand(3, 3), β = rand(3, 3), γ = rand(3)), d = (δ = rand(3), ε = eps))
     w = (abc = (α = 5rand(3, 3), β = rand(3, 3), γ = rand(3)), d = (δ = rand(3), ε = eps))
@@ -266,4 +266,4 @@ end
 
     tree, x4 = Optimisers.update(tree, x3, g4)
     @test x4 ≈ x3
-end
+end
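The surviving context lines use the non-mutating `Optimisers.update`, which returns both a new state tree and new parameters. A self-contained sketch of that call pattern with toy values (not taken from the test file):

```julia
using Optimisers

x = [1.0, 2.0, 3.0]                       # toy parameters
g = [0.1, 0.1, 0.1]                       # toy gradient of some loss w.r.t. x
tree = Optimisers.setup(Descent(0.1), x)  # optimiser state for x

tree, x = Optimisers.update(tree, x, g)   # returns (new state, new parameters)
```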

test/runtests.jl

Lines changed: 9 additions & 6 deletions
@@ -1,5 +1,5 @@
 using Optimisers
-using ChainRulesCore, Functors, StaticArrays, Zygote, Yota
+using ChainRulesCore, Functors, StaticArrays, Zygote
 using LinearAlgebra, Statistics, Test, Random
 using Optimisers: @.., @lazy
 using Base.Broadcast: broadcasted, instantiate, Broadcasted
@@ -38,12 +38,15 @@ function Optimisers.apply!(o::BiRule, state, x, dx, dx2)
   return state, dx
 end
 
-# Make Yota's output look like Zygote's:
+# if VERSION < v"1.9-"
+#   using Yota
+# end
+# # Make Yota's output look like Zygote's:
 
-Yota_gradient(f, xs...) = map(y2z, Base.tail(Yota.grad(f, xs...)[2]))
-y2z(::AbstractZero) = nothing # we don't care about different flavours of zero
-y2z(t::Tangent) = map(y2z, ChainRulesCore.backing(canonicalize(t))) # namedtuples!
-y2z(x) = x
+# Yota_gradient(f, xs...) = map(y2z, Base.tail(Yota.grad(f, xs...)[2]))
+# y2z(::AbstractZero) = nothing # we don't care about different flavours of zero
+# y2z(t::Tangent) = map(y2z, ChainRulesCore.backing(canonicalize(t))) # namedtuples!
+# y2z(x) = x
 
 @testset verbose=true "Optimisers.jl" begin
   @testset verbose=true "Features" begin
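The now-commented `y2z` helpers translated ChainRulesCore tangent types into the plain nested namedtuples that Zygote produces, so Yota's output could be compared against the same expected values. A small sketch of what that conversion does on a hand-built `Tangent`, using only ChainRulesCore (no Yota) and a made-up struct:

```julia
using ChainRulesCore

struct Point; x; y; end                      # hypothetical primal type

t = Tangent{Point}(x = 1.0)                  # y is implicitly zero
nt = ChainRulesCore.backing(canonicalize(t)) # (x = 1.0, y = ZeroTangent())

# y2z would then map ZeroTangent()/NoTangent() (both AbstractZero) to
# `nothing`, matching Zygote's convention for "no gradient here".
```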
