Skip to content

Commit 629475a

Browse files
authored
nonzero beta + flipkernel bugfix (#519)
* nonzero beta + flipkernel bugfix * conv! alpha/beta tests added, conv_filter_direct flipkernel with view
1 parent ace7d53 commit 629475a

File tree

2 files changed

+100
-3
lines changed

2 files changed

+100
-3
lines changed

src/impl/conv_direct.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -191,9 +191,11 @@ function ∇conv_filter_direct!(dw::AbstractArray{wT,5}, x::AbstractArray{xT,5},
191191
dy = transpose_swapbatch(predilate(dy, stride(cdims)))
192192
ctdims = DenseConvDims(dy, x; padding=transpose_pad(cdims),
193193
stride=dilation(cdims))
194-
conv_direct!(dw, dy, x, ctdims; alpha=alpha, beta=beta)
195-
if flipkernel(cdims)
196-
dw .= dw[end:-1:1, end:-1:1, end:-1:1, :, :]
194+
dw_ = if flipkernel(cdims)
195+
view(dw, reverse(axes(dw, 1)), reverse(axes(dw, 2)), reverse(axes(dw, 3)), :, :)
196+
else
197+
dw
197198
end
199+
conv_direct!(dw_, dy, x, ctdims; alpha=alpha, beta=beta)
198200
return dw
199201
end

test/conv.jl

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,49 @@ ddims(x) = dropdims(x, dims=(ndims(x)-1, ndims(x)))
310310
end
311311
end
312312

313+
# Test all in-place implementations/interfaces
314+
convs = [NNlib.conv!, NNlib.conv_im2col!, NNlib.conv_direct!,]
315+
NNlib.is_nnpack_available() && push!(convs, NNlib.conv_nnpack!)
316+
for conv! in convs
317+
if NNlib.is_nnpack_available()
318+
if conv! == NNlib.conv_nnpack! && !NNlib.nnpack_supported_operation(DenseConvDims(x, w))
319+
continue
320+
end
321+
end
322+
α, β = 2e0, -1e0
323+
324+
@testset "$(conv!)" begin
325+
# First, your basic convolution with no parameters
326+
cdims = DenseConvDims(x, w)
327+
y0 = rand(rng, -9e0:9e0, size(y_plain)..., 1, 1)
328+
@test isapprox(ddims(conv!(copy(y0), x, w, cdims; alpha=α, beta=β)), α*y_plain + β*y0, rtol = 1.0e-7)
329+
330+
# Next, test convolution on views and alternate datatypes:
331+
@test isapprox(ddims(conv!(copy(y0), view(x, repeat([:], ndims(x))...), w, cdims; alpha=α, beta=β)), α*y_plain + β*y0, rtol = 1.0e-7)
332+
@test isapprox(ddims(conv!(Float32.(copy(y0)), Float32.(x), Float32.(w), cdims; alpha=Float32(α), beta=Float32(β))), Float32.(α*y_plain + β*y0), rtol = 1.0e-7)
333+
334+
# Next, introduce stride:
335+
cdims = DenseConvDims(x, w; stride=2)
336+
y0 = rand(rng, -9e0:9e0, size(y_stride)..., 1, 1)
337+
@test isapprox(ddims(conv!(copy(y0), x, w, cdims; alpha=α, beta=β)), α*y_stride + β*y0, rtol = 1.0e-7)
338+
339+
# Next, introduce dilation:
340+
cdims = DenseConvDims(x, w; dilation=2)
341+
y0 = rand(rng, -9e0:9e0, size(y_dil)..., 1, 1)
342+
@test isapprox(ddims(conv!(copy(y0), x, w, cdims; alpha=α, beta=β)), α*y_dil + β*y0, rtol = 1.0e-7)
343+
344+
# Next, introduce padding:
345+
cdims = DenseConvDims(x, w; padding=1)
346+
y0 = rand(rng, -9e0:9e0, size(y_pad)..., 1, 1)
347+
@test isapprox(ddims(conv!(copy(y0), x, w, cdims; alpha=α, beta=β)), α*y_pad + β*y0, rtol = 1.0e-7)
348+
349+
# Next, test crosscor/conv with a flipped kernel
350+
cdims = DenseConvDims(x, w; flipkernel=true)
351+
y0 = rand(rng, -9e0:9e0, size(y_flip)..., 1, 1)
352+
@test isapprox(ddims(conv!(copy(y0), x, w, cdims; alpha=α, beta=β)), α*y_flip + β*y0, rtol = 1.0e-7)
353+
end
354+
end
355+
313356
# Test all implementations/interfaces
314357
for (∇conv_filter, ∇conv_data) in (
315358
(NNlib.∇conv_filter, NNlib.∇conv_data),
@@ -355,6 +398,58 @@ ddims(x) = dropdims(x, dims=(ndims(x)-1, ndims(x)))
355398
@test isapprox(ddims(∇conv_data(dy, w, cdims)), dx_flip, rtol = 1.0e-7)
356399
end
357400
end
401+
402+
# Test all in-place implementations/interfaces
403+
for (∇conv_filter!, ∇conv_data!) in (
404+
(NNlib.∇conv_filter!, NNlib.∇conv_data!),
405+
(NNlib.∇conv_filter_im2col!, NNlib.∇conv_data_im2col!),
406+
(NNlib.∇conv_filter_direct!, NNlib.∇conv_data_direct!),
407+
)
408+
#α, β = 2*rand(rng) - 1, 2*rand(rng) - 1
409+
α, β = 2e0, -1e0
410+
flag = ∇conv_data! in (NNlib.∇conv_data!, NNlib.∇conv_data_im2col!)
411+
412+
@testset "$(∇conv_filter!)/$(∇conv_data!)" begin
413+
# First, your basic convolution with no parameters
414+
cdims = DenseConvDims(x, w)
415+
dy = NNlib.conv(x, w, cdims)
416+
@test isapprox(ddims(∇conv_filter!(copy(w), x, dy, cdims; alpha=α, beta=β)), α*dw + β*w, rtol = 1.0e-7)
417+
@test isapprox(ddims(∇conv_data!(copy(x), dy, w, cdims; alpha=α, beta=β)), α*dx + β*x, rtol = 1.0e-7) broken=flag
418+
419+
# Next, test convolution on views and alternate datatypes:
420+
@test isapprox(ddims(∇conv_filter!(copy(w), x, view(dy, repeat([:], ndims(dy))...), cdims; alpha=α, beta=β)), α*dw + β*w, rtol = 1.0e-7)
421+
@test isapprox(ddims(∇conv_data!(copy(x), view(dy, repeat([:], ndims(dy))...), w, cdims; alpha=α, beta=β)), α*dx + β*x, rtol = 1.0e-7) broken=flag
422+
423+
@test isapprox(ddims(∇conv_filter!(Float32.(copy(w)), Float32.(x), Float32.(dy), cdims; alpha=Float32(α), beta=Float32(β))), α*dw + β*w, rtol = 1.0e-7)
424+
@test isapprox(ddims(∇conv_data!(Float32.(copy(x)), Float32.(dy), Float32.(w), cdims; alpha=Float32(α), beta=Float32(β))), α*dx + β*x, rtol = 1.0e-7) broken=flag
425+
426+
# Next, introduce stride:
427+
cdims = DenseConvDims(x, w; stride=2)
428+
dy = NNlib.conv(x, w, cdims)
429+
flag_ = ∇conv_filter! == NNlib.∇conv_filter_direct! && rank in (1,3)
430+
@test isapprox(ddims(∇conv_filter!(copy(w), x, dy, cdims; alpha=α, beta=β)), α*dw_stride + β*w, rtol = 1.0e-7) broken=flag_
431+
@test isapprox(ddims(∇conv_data!(copy(x), dy, w, cdims; alpha=α, beta=β)), α*dx_stride + β*x, rtol = 1.0e-7) broken=flag
432+
433+
# Next, introduce dilation:
434+
cdims = DenseConvDims(x, w; dilation=2)
435+
dy = NNlib.conv(x, w, cdims)
436+
flag_ = ∇conv_data! == NNlib.∇conv_data_direct! && rank == 3
437+
@test isapprox(ddims(∇conv_filter!(copy(w), x, dy, cdims; alpha=α, beta=β)), α*dw_dil + β*w, rtol = 1.0e-7)
438+
@test isapprox(ddims(∇conv_data!(copy(x), dy, w, cdims; alpha=α, beta=β)), α*dx_dil + β*x, rtol = 1.0e-7) broken=flag || flag_
439+
440+
# Next, introduce padding:
441+
cdims = DenseConvDims(x, w; padding=1)
442+
dy = NNlib.conv(x, w, cdims)
443+
@test isapprox(ddims(∇conv_filter!(copy(w), x, dy, cdims; alpha=α, beta=β)), α*dw_pad + β*w, rtol = 1.0e-7)
444+
@test isapprox(ddims(∇conv_data!(copy(x), dy, w, cdims; alpha=α, beta=β)), α*dx_pad + β*x, rtol = 1.0e-7) broken=flag
445+
446+
# Next, test crosscor/conv with a flipped kernel
447+
cdims = DenseConvDims(x, w; flipkernel=true)
448+
dy = NNlib.conv(x, w, cdims)
449+
@test isapprox(ddims(∇conv_filter!(copy(w), x, dy, cdims; alpha=α, beta=β)), α*dw_flip + β*w, rtol = 1.0e-7)
450+
@test isapprox(ddims(∇conv_data!(copy(x), dy, w, cdims; alpha=α, beta=β)), α*dx_flip + β*x, rtol = 1.0e-7) broken=flag
451+
end
452+
end
358453
end
359454
end
360455
end

0 commit comments

Comments
 (0)