12
12
),
13
13
T in (Float64, ComplexF64),
14
14
sz in [(3 ,), (3 , 3 ), (3 , 2 , 1 )]
15
- println (" starting unexported fnorm=$fnorm , T=$T , sz=$sz " )
16
15
17
16
x = randn (T, sz)
18
17
# finite differences is unstable if maxabs (minabs) values are not well
@@ -44,47 +43,15 @@ println("starting unexported fnorm=$fnorm, T=$T, sz=$sz")
44
43
@test rrule (fnorm, x)[2 ](Zero ())[2 ] isa Zero
45
44
end
46
45
ndims (x) > 1 && @testset " non-strided" begin
47
- println (" ... non-strided" )
48
46
xp = if x isa Matrix
49
47
view (x, [1 ,2 ,3 ], 1 : 3 )
50
48
elseif x isa Array{T,3 }
51
49
PermutedDimsArray (x, (1 ,2 ,3 ))
52
50
end
53
51
@test ! (xp isa StridedArray)
54
- # y = fnorm(x)
55
- # # ẋ = rand(T, size(xp)) # rand_tangent(xp)
56
- # x̄ = rand(T, size(xp)) # rand_tangent(xp)
57
- # ȳ = rand_tangent(y)
58
- # # frule_test(fnorm, (xp, ẋ))
59
- # rrule_test(fnorm, ȳ, (xp, x̄)) # old notation, gives a depwarn
60
- #=
61
- ┌ Warning: `rrule_test(f, ȳ, inputs::Tuple{Any, Any}...; kwargs...)` is deprecated, use `test_rrule(f, (x ⊢ dx for (x, dx) = inputs)...; output_tangent = ȳ, kwargs...)` instead.
62
- │ caller = macro expansion at norm.jl:57 [inlined]
63
- └ @ Core ~/.julia/dev/ChainRules/test/rulesets/LinearAlgebra/norm.jl:57
64
- =#
65
- # @show typeof(xp)
66
- # test_rrule(fnorm, xp) # new notation, gives a spectacular failure:
67
- #=
68
- typeof(xp) = SubArray{Float64, 2, Matrix{Float64}, Tuple{Vector{Int64}, UnitRange{Int64}}, false}
69
- test_rrule: norm1 at ([0.2972879845354616 -0.01044524463737564 2.2950878238373105; 0.3823959677906078 -0.839026854388764 -2.2670863488005306; -0.5976344767282311 0.31111133849833383 0.5299655761667461],): Error During Test at /Users/me/.julia/packages/ChainRulesTestUtils/bDd51/src/testers.jl:168
70
- Got exception outside of a @test
71
- MethodError: no method matching +(::Composite{SubArray{Float64, 2, Matrix{Float64}, Tuple{Vector{Int64}, UnitRange{Int64}}, false}, NamedTuple{(:parent, :indices, :offset1, :stride1), Tuple{Matrix{Float64}, Composite{Tuple{Vector{Int64}, UnitRange{Int64}}, Tuple{Vector{DoesNotExist}, Composite{UnitRange{Int64}, NamedTuple{(:start, :stop), Tuple{DoesNotExist, DoesNotExist}}}}}, DoesNotExist, DoesNotExist}}}, ::Matrix{Float64})
72
- Closest candidates are:
73
- +(::Any, ::Any, ::Any, ::Any...) at operators.jl:560
74
- +(::Composite{P, T} where T, ::Composite{P, T} where T) where P at /Users/me/.julia/packages/ChainRulesCore/1qau5/src/differential_arithmetic.jl:167
75
- +(::Composite, ::AbstractThunk) at /Users/me/.julia/packages/ChainRulesCore/1qau5/src/differential_arithmetic.jl:161
76
- ...
77
- Stacktrace:
78
- [1] +(a::Composite{SubArray{Float64, 2, Matrix{Float64}, Tuple{Vector{Int64}, UnitRange{Int64}}, false}, NamedTuple{(:parent, :indices, :offset1, :stride1), Tuple{Matrix{Float64}, Composite{Tuple{Vector{Int64}, UnitRange{Int64}}, Tuple{Vector{DoesNotExist}, Composite{UnitRange{Int64}, NamedTuple{(:start, :stop), Tuple{DoesNotExist, DoesNotExist}}}}}, DoesNotExist, DoesNotExist}}}, b::InplaceableThunk{Thunk{ChainRules.var"#1798#1801"{Float64, SubArray{Float64, 2, Matrix{Float64}, Tuple{Vector{Int64}, UnitRange{Int64}}, false}, Float64}}, ChainRules.var"#1799#1802"{Float64, SubArray{Float64, 2, Matrix{Float64}, Tuple{Vector{Int64}, UnitRange{Int64}}, false}, Float64}})
79
- @ ChainRulesCore ~/.julia/packages/ChainRulesCore/1qau5/src/differential_arithmetic.jl:161
80
- [2] add!!(x::Composite{SubArray{Float64, 2, Matrix{Float64}, Tuple{Vector{Int64}, UnitRange{Int64}}, false}, NamedTuple{(:parent, :indices, :offset1, :stride1), Tuple{Matrix{Float64}, Composite{Tuple{Vector{Int64}, UnitRange{Int64}}, Tuple{Vector{DoesNotExist}, Composite{UnitRange{Int64}, NamedTuple{(:start, :stop), Tuple{DoesNotExist, DoesNotExist}}}}}, DoesNotExist, DoesNotExist}}}, t::InplaceableThunk{Thunk{ChainRules.var"#1798#1801"{Float64, SubArray{Float64, 2, Matrix{Float64}, Tuple{Vector{Int64}, UnitRange{Int64}}, false}, Float64}}, ChainRules.var"#1799#1802"{Float64, SubArray{Float64, 2, Matrix{Float64}, Tuple{Vector{Int64}, UnitRange{Int64}}, false}, Float64}})
81
- @ ChainRulesCore ~/.julia/packages/ChainRulesCore/1qau5/src/accumulation.jl:23
82
- =#
83
- test_rrule (fnorm, xp ⊢ rand (T, size (xp))) # ok, this passes!
84
-
52
+ test_rrule (fnorm, xp ⊢ rand (T, size (xp)))
85
53
end
86
54
T == Float64 && ndims (x) == 1 && @testset " Integer input" begin
87
- println (" ... integer" )
88
55
x = [1 ,2 ,3 ]
89
56
int_fwd, int_back = rrule (fnorm, x)
90
57
float_fwd, float_back = rrule (fnorm, float (x))
@@ -93,13 +60,12 @@ println("... integer")
93
60
end
94
61
end
95
62
96
- # Next test norm(x , p=2) -- two methods
63
+ # Next test norm(A , p=2) -- two methods
97
64
# =====================================
98
65
99
66
@testset " norm(x::Array{$T ,$(length (sz)) })" for
100
67
T in (Float64, ComplexF64),
101
68
sz in [(0 ,), (3 ,), (3 , 3 ), (3 , 2 , 1 )]
102
- println (" starting exported norm T=$T , sz=$sz " )
103
69
104
70
x = randn (T, sz)
105
71
@@ -121,19 +87,12 @@ println("starting exported norm T=$T, sz=$sz")
121
87
@test rrule (norm, x)[2 ](Zero ())[2 ] isa Zero
122
88
end
123
89
ndims (x) > 1 && @testset " non-strided" begin
124
- println (" ... non-strided'" )
125
90
xp = if x isa Matrix
126
91
view (x, [1 ,2 ,3 ], 1 : 3 )
127
92
elseif x isa Array{T,3 }
128
93
PermutedDimsArray (x, (1 ,2 ,3 ))
129
94
end
130
95
@test ! (xp isa StridedArray)
131
- # y = norm(x)
132
- # ẋ = rand(T, size(xp)) # rand_tangent(xp)
133
- # x̄ = rand(T, size(xp)) # rand_tangent(xp)
134
- # ȳ = rand_tangent(y)
135
- # frule_test(norm, (xp, ẋ))
136
- # rrule_test(norm, ȳ, (xp, x̄))
137
96
test_frule (norm, xp ⊢ rand (T, size (xp)))
138
97
test_rrule (norm, xp ⊢ rand (T, size (xp))) # rand_tangent does not work here
139
98
end
@@ -143,11 +102,10 @@ println("... non-strided'")
143
102
p in (1.0 , 2.0 , Inf , - Inf , 2.5 ),
144
103
T in (Float64, ComplexF64),
145
104
sz in (fnorm === norm ? [(0 ,), (3 ,), (3 , 3 ), (3 , 2 , 1 )] : [(3 ,), (3 , 3 ), (3 , 2 , 1 )])
146
- println (" starting p-norm p=$p , T=$T , sz=$sz " )
147
105
148
106
x = randn (T, sz)
149
107
# finite differences is unstable if maxabs (minabs) values are not well
150
- # separated from other values
108
+ # separated from other values (same as above)
151
109
if p == Inf
152
110
if ! isempty (x)
153
111
x[end ] = 1000 rand (T)
@@ -183,7 +141,6 @@ println("starting p-norm p=$p, T=$T, sz=$sz")
183
141
@testset " norm($fdual (::Vector{$T }), 2.5)" for
184
142
T in (Float64, ComplexF64),
185
143
fdual in (adjoint, transpose)
186
- println (" starting $fdual norm T=$T " )
187
144
188
145
x = fdual (randn (T, 3 ))
189
146
p = 2.5
@@ -198,15 +155,13 @@ println("starting $fdual norm T=$T")
198
155
199
156
@testset " norm(x::$T , p)" for T in (Float64, ComplexF64)
200
157
@testset " p = $p " for p in (- 1.0 , 2.0 , 2.5 )
201
- println (" starting scalar p-norm tests, p=$p , T=$T " )
202
158
test_frule (norm, randn (T), p)
203
159
test_rrule (norm, randn (T), p)
204
160
205
161
_, back = rrule (norm, randn (T), p)
206
162
@test back (Zero ()) == (NO_FIELDS, Zero (), Zero ())
207
163
end
208
164
@testset " p = 0" begin
209
- println (" starting 0-norm tests, T=$T " )
210
165
p = 0.0
211
166
x = randn (T)
212
167
y = norm (x, p)
0 commit comments