Skip to content

Commit 3d381ab

Browse files
authored
Merge pull request #626 from Jutho/patch-1
restrict strided array multiplication `rrule`
2 parents d0bcfc5 + 542abbf commit 3d381ab

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

src/rulesets/Base/arraymath.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,9 @@ end
4646
# https://github.com/JuliaDiff/ChainRulesCore.jl/issues/411
4747
function rrule(
4848
::typeof(*),
49-
A::StridedMatrix{<:CommutativeMulNumber},
50-
B::StridedVecOrMat{<:CommutativeMulNumber},
51-
)
49+
A::StridedMatrix{T},
50+
B::StridedVecOrMat{T},
51+
) where {T<:CommutativeMulNumber}
5252
function times_pullback(ȳ)
5353
= unthunk(ȳ)
5454
dA = InplaceableThunk(

0 commit comments

Comments
 (0)