@@ -6,6 +6,7 @@ function frule((_, Δx), ::typeof(norm), x)
6
6
y = norm (x)
7
7
return y, _norm2_forward (x, Δx, norm (x))
8
8
end
9
+
9
10
function frule ((_, Δx), :: typeof (norm), x:: Number , p:: Real )
10
11
y = norm (x, p)
11
12
∂y = if iszero (Δx) || iszero (p)
20
21
function rrule (:: typeof (norm), x:: AbstractArray{<:Number} , p:: Real )
21
22
y = LinearAlgebra. norm (x, p)
22
23
function norm_pullback_p (Δy)
23
- # ∂x = InplaceableThunk(
24
+ ∂x = InplaceableThunk (
24
25
# out-of-place versions
25
- ∂x = @thunk (if isempty (x) || p == 0
26
+ @thunk (if isempty (x) || p == 0
26
27
zero .(x) .* (zero (y) * zero (real (Δy)))
27
28
elseif p == 2
28
29
_norm2_back (x, y, Δy)
@@ -35,48 +36,50 @@ function rrule(::typeof(norm), x::AbstractArray{<:Number}, p::Real)
35
36
else
36
37
_normp_back_x (x, p, y, Δy)
37
38
end )
38
- # , # in-place versions -- can be fixed when actually useful?
39
- # dx -> if isempty(x) || p == 0
40
- # dx
41
- # elseif p == 2
42
- # _norm2_back!(dx, x, y, Δy)
43
- # elseif p == 1
44
- # _norm1_back!(dx, x, y, Δy)
45
- # elseif p == Inf
46
- # dx .+= _normInf_back(x, y, Δy) # not really in-place! could perhaps be improved
47
- # elseif p == -Inf
48
- # dx .+= _normInf_back(x, y, Δy)
49
- # else
50
- # dx .+= _normp_back_x(x, p, y, Δy)
51
- # end
52
- # )
39
+ , # in-place versions -- can be fixed when actually useful?
40
+ dx -> if isempty (x) || p == 0
41
+ dx
42
+ elseif p == 2
43
+ _norm2_back! (dx, x, y, Δy)
44
+ elseif p == 1
45
+ _norm1_back! (dx, x, y, Δy)
46
+ elseif p == Inf
47
+ dx .+ = _normInf_back (x, y, Δy) # not really in-place! could perhaps be improved
48
+ elseif p == - Inf
49
+ dx .+ = _normInf_back (x, y, Δy)
50
+ else
51
+ dx .+ = _normp_back_x (x, p, y, Δy)
52
+ end
53
+ )
53
54
∂p = @thunk _normp_back_p (x, p, y, Δy)
54
55
return (NO_FIELDS, ∂x, ∂p)
55
56
end
56
57
norm_pullback_p (:: Zero ) = (NO_FIELDS, Zero (), Zero ())
57
58
return y, norm_pullback_p
58
59
end
60
+
59
61
function rrule (:: typeof (norm), x:: AbstractArray{<:Number} )
60
62
y = LinearAlgebra. norm (x)
61
63
function norm_pullback_2 (Δy)
62
- # ∂x = InplaceableThunk(
63
- ∂x = @thunk (if isempty (x)
64
+ ∂x = InplaceableThunk (
65
+ @thunk (if isempty (x)
64
66
zero .(x) .* (zero (y) * zero (real (Δy)))
65
67
else
66
68
_norm2_back (x, y, Δy)
67
69
end )
68
- # ,
69
- # dx -> if isempty(x)
70
- # dx
71
- # else
72
- # _norm2_back!(dx, x, y, Δy)
73
- # end
74
- # )
70
+ ,
71
+ dx -> if isempty (x)
72
+ dx
73
+ else
74
+ _norm2_back! (dx, x, y, Δy)
75
+ end
76
+ )
75
77
return (NO_FIELDS, ∂x)
76
78
end
77
79
norm_pullback_2 (:: Zero ) = (NO_FIELDS, Zero ())
78
80
return y, norm_pullback_2
79
81
end
82
+
80
83
function rrule (:: typeof (norm), x:: LinearAlgebra.AdjOrTransAbsVec{<:Number} , p:: Real )
81
84
y, inner_pullback = rrule (norm, parent (x), p)
82
85
function norm_pullback (Δy)
@@ -87,6 +90,7 @@ function rrule(::typeof(norm), x::LinearAlgebra.AdjOrTransAbsVec{<:Number}, p::R
87
90
end
88
91
return y, norm_pullback
89
92
end
93
+
90
94
function rrule (:: typeof (norm), x:: Number , p:: Real )
91
95
y = norm (x, p)
92
96
function norm_pullback (Δy)
@@ -126,6 +130,7 @@ function _normp_back_x(x, p, y, Δy)
126
130
end
127
131
return ∂x
128
132
end
133
+
129
134
function _normp_back_x (x:: WithSomeZeros , p, y, Δy) # Diagonal, UpperTriangular, etc.
130
135
c = real (Δy) / y
131
136
∂x_data = map (parent (x)) do xi
@@ -261,6 +266,7 @@ function rrule(::typeof(normalize), x::AbstractVector{<:Number}, p::Real)
261
266
normalize_pullback (:: Zero ) = (NO_FIELDS, Zero (), Zero ())
262
267
return y, normalize_pullback
263
268
end
269
+
264
270
function rrule (:: typeof (normalize), x:: AbstractVector{<:Number} )
265
271
nrm = LinearAlgebra. norm2 (x)
266
272
Ty = typeof (first (x) / nrm)
0 commit comments