@@ -16,54 +16,37 @@ Notice:
16
16
17
17
# we propagate `NotImplemented` (e.g., in `@scalar_rule`)
18
18
# this requires the following definitions (see also #337)
19
- Base.:+ (x:: NotImplemented , :: ZeroTangent ) = x
20
- Base.:+ (:: ZeroTangent , x:: NotImplemented ) = x
21
19
Base.:+ (x:: NotImplemented , :: NotImplemented ) = x
22
- Base.:* (:: NotImplemented , :: ZeroTangent ) = ZeroTangent ()
23
- Base.:* (:: ZeroTangent , :: NotImplemented ) = ZeroTangent ()
24
- for T in (:NoTangent , :AbstractThunk , :Tangent , :Any )
20
+ Base.:* (x:: NotImplemented , :: NotImplemented ) = x
21
+ LinearAlgebra. dot (x:: NotImplemented , :: NotImplemented ) = x
22
+ # `NotImplemented` always "wins" +
23
+ for T in (:ZeroTangent , :NoTangent , :AbstractThunk , :Tangent , :Any )
25
24
@eval Base.:+ (x:: NotImplemented , :: $T ) = x
26
25
@eval Base.:+ (:: $T , x:: NotImplemented ) = x
26
+ end
27
+ # `NotImplemented` "loses" * and dot against NoTangent and ZeroTangent
28
+ # this can be used to ignore partial derivatives that are not implemented
29
+ for T in (:ZeroTangent , :NoTangent )
30
+ @eval Base.:* (:: NotImplemented , :: $T ) = $ T ()
31
+ @eval Base.:* (:: $T , :: NotImplemented ) = $ T ()
32
+ @eval LinearAlgebra. dot (:: NotImplemented , :: $T ) = $ T ()
33
+ @eval LinearAlgebra. dot (:: $T , :: NotImplemented ) = $ T ()
34
+ end
35
+ # `NotImplemented` "wins" * and dot for other types
36
+ for T in (:AbstractThunk , :Tangent , :Any )
27
37
@eval Base.:* (x:: NotImplemented , :: $T ) = x
38
+ @eval Base.:* (:: $T , x:: NotImplemented ) = x
39
+ @eval LinearAlgebra. dot (x:: NotImplemented , :: $T ) = x
40
+ @eval LinearAlgebra. dot (:: $T , x:: NotImplemented ) = x
28
41
end
29
- Base. muladd (x:: NotImplemented , y, z) = x
30
- Base. muladd (:: NotImplemented , :: ZeroTangent , z) = z
31
- Base. muladd (x:: NotImplemented , y, :: ZeroTangent ) = x
32
- Base. muladd (:: NotImplemented , :: ZeroTangent , :: ZeroTangent ) = ZeroTangent ()
33
- Base. muladd (x, y:: NotImplemented , z) = y
34
- Base. muladd (:: ZeroTangent , :: NotImplemented , z) = z
35
- Base. muladd (x, y:: NotImplemented , :: ZeroTangent ) = y
36
- Base. muladd (:: ZeroTangent , :: NotImplemented , :: ZeroTangent ) = ZeroTangent ()
37
- Base. muladd (x, y, z:: NotImplemented ) = z
38
- Base. muladd (:: ZeroTangent , y, z:: NotImplemented ) = z
39
- Base. muladd (x, :: ZeroTangent , z:: NotImplemented ) = z
40
- Base. muladd (:: ZeroTangent , :: ZeroTangent , z:: NotImplemented ) = z
41
- Base. muladd (x:: NotImplemented , :: NotImplemented , z) = x
42
- Base. muladd (x:: NotImplemented , :: NotImplemented , :: ZeroTangent ) = x
43
- Base. muladd (x:: NotImplemented , y, :: NotImplemented ) = x
44
- Base. muladd (:: NotImplemented , :: ZeroTangent , z:: NotImplemented ) = z
45
- Base. muladd (x, y:: NotImplemented , :: NotImplemented ) = y
46
- Base. muladd (:: ZeroTangent , :: NotImplemented , z:: NotImplemented ) = z
47
- Base. muladd (x:: NotImplemented , :: NotImplemented , :: NotImplemented ) = x
48
- LinearAlgebra. dot (:: NotImplemented , :: ZeroTangent ) = ZeroTangent ()
49
- LinearAlgebra. dot (:: ZeroTangent , :: NotImplemented ) = ZeroTangent ()
50
-
51
- # other common operations throw an exception
52
- Base.:+ (x:: NotImplemented ) = throw (NotImplementedException (x))
42
+
43
+ # subtraction throws an exception: in AD we add tangents but do not subtract them
44
+ # subtraction happens eg. in gradient descent which can't be performed with `NotImplemented`
53
45
Base.:- (x:: NotImplemented ) = throw (NotImplementedException (x))
54
- Base.:- (x:: NotImplemented , :: ZeroTangent ) = throw (NotImplementedException (x))
55
- Base.:- (:: ZeroTangent , x:: NotImplemented ) = throw (NotImplementedException (x))
56
46
Base.:- (x:: NotImplemented , :: NotImplemented ) = throw (NotImplementedException (x))
57
- Base.:* (x:: NotImplemented , :: NotImplemented ) = throw (NotImplementedException (x))
58
- function LinearAlgebra. dot (x:: NotImplemented , :: NotImplemented )
59
- return throw (NotImplementedException (x))
60
- end
61
- for T in (:NoTangent , :AbstractThunk , :Tangent , :Any )
47
+ for T in (:ZeroTangent , :NoTangent , :AbstractThunk , :Tangent , :Any )
62
48
@eval Base.:- (x:: NotImplemented , :: $T ) = throw (NotImplementedException (x))
63
49
@eval Base.:- (:: $T , x:: NotImplemented ) = throw (NotImplementedException (x))
64
- @eval Base.:* (:: $T , x:: NotImplemented ) = throw (NotImplementedException (x))
65
- @eval LinearAlgebra. dot (x:: NotImplemented , :: $T ) = throw (NotImplementedException (x))
66
- @eval LinearAlgebra. dot (:: $T , x:: NotImplemented ) = throw (NotImplementedException (x))
67
50
end
68
51
69
52
Base.:+ (:: NoTangent , :: NoTangent ) = NoTangent ()
0 commit comments