Skip to content

Commit d9aaaf7

Browse files
Merge pull request #271 from mcabbott/batch3
Fix stride & size inference of 'T'/'N' in `batched_mul`
2 parents e1a5945 + 5df84c4 commit d9aaaf7

File tree

3 files changed

+40
-15
lines changed

3 files changed

+40
-15
lines changed

src/batched/batchedadjtrans.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,8 @@ Base.axes(m::BatchedAdjOrTrans) = (axes(m.parent, 2), axes(m.parent, 1), axes(m.
6464
Base.IndexStyle(::Type{<:BatchedAdjOrTrans}) = IndexCartesian()
6565
Base.@propagate_inbounds Base.getindex(m::BatchedTranspose, i::Int, j::Int, k::Int) = getindex(m.parent, j, i, k)
6666
Base.@propagate_inbounds Base.getindex(m::BatchedAdjoint, i::Int, j::Int, k::Int) = adjoint(getindex(m.parent, j, i, k))
67-
Base.@propagate_inbounds Base.setindex!(m::BatchedAdjOrTrans, v, i::Int, j::Int, k::Int) = setindex!(m.parent, v, j, i, k)
67+
Base.@propagate_inbounds Base.setindex!(m::BatchedTranspose, v, i::Int, j::Int, k::Int) = setindex!(m.parent, v, j, i, k)
68+
Base.@propagate_inbounds Base.setindex!(m::BatchedAdjoint, v, i::Int, j::Int, k::Int) = setindex!(m.parent, adjoint(v), j, i, k)
6869

6970
Base.similar(A::BatchedAdjOrTrans, T::Type, dims::Dims) = similar(A.parent, T, dims)
7071
Base.similar(A::BatchedAdjOrTrans, dims::Dims) = similar(A.parent, dims)

src/batched/batchedmul.jl

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -224,34 +224,27 @@ function _batched_try_gemm!(::Type{DT}, C, A, B, α::Number, β::Number) where {
224224
alpha, beta = promote(α, β, zero(T))
225225
alpha isa T && beta isa T || return batched_mul_generic!(C, A, B, α, β)
226226

227-
are_strided(C, _unbatch(A), _unbatch(B)) || return batched_mul_generic!(C, A, B, α, β)
228-
229-
if Base.stride(C,1) == 1
230-
elseif Base.stride(C,2) == 1
231-
@debug "transforming C = A * B into C' = B' * A'" size(C) strides(C)
232-
return batched_mul!(batched_adjoint(C), batched_adjoint(B), batched_adjoint(A), α, β)
233-
else
234-
return batched_mul_generic!(C, A, B, α, β)
235-
end
227+
are_strided(_unbatch(A), _unbatch(B)) || return batched_mul_generic!(C, A, B, α, β)
228+
C isa StridedArray || return batched_mul_generic!(C, A, B, α, β)
236229

237230
blasA, transA = if A isa BatchedAdjoint && T <: Complex
238231
Base.stride(parent(A),1) == 1 || return batched_mul_generic!(C, A, B, α, β)
239232
parent(A), 'C'
233+
elseif Base.stride(A,2) == 1 && size(A,1) > 1
234+
batched_transpose(A), 'T'
240235
elseif Base.stride(A,1) == 1
241236
A, 'N'
242-
elseif Base.stride(A,2) == 1
243-
batched_transpose(A), 'T'
244237
else
245238
return batched_mul_generic!(C, A, B, α, β)
246239
end
247240

248241
blasB, transB = if B isa BatchedAdjoint && T <: Complex
249242
Base.stride(parent(B),1) == 1 || return batched_mul_generic!(C, A, B, α, β)
250243
parent(B), 'C'
244+
elseif Base.stride(B,2) == 1 && size(B,1) > 1
245+
batched_transpose(B), 'T'
251246
elseif Base.stride(B,1) == 1
252247
B, 'N'
253-
elseif Base.stride(B,2) == 1
254-
batched_transpose(B), 'T'
255248
else
256249
return batched_mul_generic!(C, A, B, α, β)
257250
end

test/batchedmul.jl

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
using NNlib, Test, LinearAlgebra
22
using NNlib: storage_type, storage_typejoin, is_strided,
3-
batched_mul!, _unbatch, _copy_if_faster,
3+
batched_mul!, batched_mul_generic!, _unbatch, _copy_if_faster,
44
BatchedAdjoint, BatchedTranspose
55

66
function bmm_test(a,b; transA = false, transB = false)
@@ -119,6 +119,37 @@ end
119119
end
120120
end
121121

122+
@testset "batched_mul: trivial dimensions & unit strides, $T" for T in [Float64, ComplexF64]
123+
@testset "$tA(rand$((sA...,2))) ⊠ $tB(rand$((sB...,2)))" for
124+
tA in [identity, batched_adjoint, batched_transpose], sA in [(1,1), (1,3), (3,1), (3,3)],
125+
tB in [identity, batched_adjoint, batched_transpose], sB in [(1,1), (1,3), (3,1), (3,3)]
126+
127+
A = tA(rand(T, sA..., 2))
128+
B = tB(rand(T, sB..., 2))
129+
size(A,2) == size(B,1) || continue
130+
131+
C = cat(A[:,:,1] * B[:,:,1], A[:,:,2] * B[:,:,2]; dims=3)
132+
@test A B C
133+
134+
# In-place batched_mul!
135+
α, β = rand(T), rand(T)
136+
D = rand(T, size(C))
137+
@test batched_mul!(copy(D), A, B, α, β) α .* C .+ β .* D
138+
@test batched_mul_generic!(copy(D), A, B, α, β) α .* C .+ β .* D
139+
140+
# ... and with weird LHS -- all to batched_mul_generic! right now
141+
C2 = batched_transpose(permutedims(C, (2,1,3)))
142+
C3 = batched_adjoint(permutedims(conj(C), (2,1,3)))
143+
@test C2 == C3 == C
144+
C2 .= D
145+
C3 .= D
146+
@test batched_mul!(C2, A, B, α, β) α .* C .+ β .* D
147+
@test C2 α .* C .+ β .* D
148+
@test batched_mul!(C3, A, B, α, β) α .* C .+ β .* D
149+
@test C3 α .* C .+ β .* D
150+
end
151+
end
152+
122153
@testset "BatchedAdjOrTrans interface * $TB" for TB in [Float64, Float32]
123154
A = randn(7,5,3)
124155
B = randn(TB, 5,7,3)

0 commit comments

Comments
 (0)