@@ -310,6 +310,49 @@ ddims(x) = dropdims(x, dims=(ndims(x)-1, ndims(x)))
310
310
end
311
311
end
312
312
313
+ # Test all in-place implementations/interfaces
314
+ convs = [NNlib. conv!, NNlib. conv_im2col!, NNlib. conv_direct!,]
315
+ NNlib. is_nnpack_available () && push! (convs, NNlib. conv_nnpack!)
316
+ for conv! in convs
317
+ if NNlib. is_nnpack_available ()
318
+ if conv! == NNlib. conv_nnpack! && ! NNlib. nnpack_supported_operation (DenseConvDims (x, w))
319
+ continue
320
+ end
321
+ end
322
+ α, β = 2e0 , - 1e0
323
+
324
+ @testset " $(conv!) " begin
325
+ # First, your basic convolution with no parameters
326
+ cdims = DenseConvDims (x, w)
327
+ y0 = rand (rng, - 9e0 : 9e0 , size (y_plain)... , 1 , 1 )
328
+ @test isapprox (ddims (conv! (copy (y0), x, w, cdims; alpha= α, beta= β)), α* y_plain + β* y0, rtol = 1.0e-7 )
329
+
330
+ # Next, test convolution on views and alternate datatypes:
331
+ @test isapprox (ddims (conv! (copy (y0), view (x, repeat ([:], ndims (x))... ), w, cdims; alpha= α, beta= β)), α* y_plain + β* y0, rtol = 1.0e-7 )
332
+ @test isapprox (ddims (conv! (Float32 .(copy (y0)), Float32 .(x), Float32 .(w), cdims; alpha= Float32 (α), beta= Float32 (β))), Float32 .(α* y_plain + β* y0), rtol = 1.0e-7 )
333
+
334
+ # Next, introduce stride:
335
+ cdims = DenseConvDims (x, w; stride= 2 )
336
+ y0 = rand (rng, - 9e0 : 9e0 , size (y_stride)... , 1 , 1 )
337
+ @test isapprox (ddims (conv! (copy (y0), x, w, cdims; alpha= α, beta= β)), α* y_stride + β* y0, rtol = 1.0e-7 )
338
+
339
+ # Next, introduce dilation:
340
+ cdims = DenseConvDims (x, w; dilation= 2 )
341
+ y0 = rand (rng, - 9e0 : 9e0 , size (y_dil)... , 1 , 1 )
342
+ @test isapprox (ddims (conv! (copy (y0), x, w, cdims; alpha= α, beta= β)), α* y_dil + β* y0, rtol = 1.0e-7 )
343
+
344
+ # Next, introduce padding:
345
+ cdims = DenseConvDims (x, w; padding= 1 )
346
+ y0 = rand (rng, - 9e0 : 9e0 , size (y_pad)... , 1 , 1 )
347
+ @test isapprox (ddims (conv! (copy (y0), x, w, cdims; alpha= α, beta= β)), α* y_pad + β* y0, rtol = 1.0e-7 )
348
+
349
+ # Next, test crosscor/conv with a flipped kernel
350
+ cdims = DenseConvDims (x, w; flipkernel= true )
351
+ y0 = rand (rng, - 9e0 : 9e0 , size (y_flip)... , 1 , 1 )
352
+ @test isapprox (ddims (conv! (copy (y0), x, w, cdims; alpha= α, beta= β)), α* y_flip + β* y0, rtol = 1.0e-7 )
353
+ end
354
+ end
355
+
313
356
# Test all implementations/interfaces
314
357
for (∇conv_filter, ∇conv_data) in (
315
358
(NNlib.∇conv_filter, NNlib.∇conv_data),
@@ -355,6 +398,58 @@ ddims(x) = dropdims(x, dims=(ndims(x)-1, ndims(x)))
355
398
@test isapprox (ddims (∇conv_data (dy, w, cdims)), dx_flip, rtol = 1.0e-7 )
356
399
end
357
400
end
401
+
402
+ # Test all in-place implementations/interfaces
403
+ for (∇conv_filter!, ∇conv_data!) in (
404
+ (NNlib.∇conv_filter!, NNlib.∇conv_data!),
405
+ (NNlib.∇conv_filter_im2col!, NNlib.∇conv_data_im2col!),
406
+ (NNlib.∇conv_filter_direct!, NNlib.∇conv_data_direct!),
407
+ )
408
+ # α, β = 2*rand(rng) - 1, 2*rand(rng) - 1
409
+ α, β = 2e0 , - 1e0
410
+ flag = ∇conv_data! in (NNlib.∇conv_data!, NNlib.∇conv_data_im2col!)
411
+
412
+ @testset " $(∇conv_filter!) /$(∇conv_data!) " begin
413
+ # First, your basic convolution with no parameters
414
+ cdims = DenseConvDims (x, w)
415
+ dy = NNlib. conv (x, w, cdims)
416
+ @test isapprox (ddims (∇conv_filter! (copy (w), x, dy, cdims; alpha= α, beta= β)), α* dw + β* w, rtol = 1.0e-7 )
417
+ @test isapprox (ddims (∇conv_data! (copy (x), dy, w, cdims; alpha= α, beta= β)), α* dx + β* x, rtol = 1.0e-7 ) broken= flag
418
+
419
+ # Next, test convolution on views and alternate datatypes:
420
+ @test isapprox (ddims (∇conv_filter! (copy (w), x, view (dy, repeat ([:], ndims (dy))... ), cdims; alpha= α, beta= β)), α* dw + β* w, rtol = 1.0e-7 )
421
+ @test isapprox (ddims (∇conv_data! (copy (x), view (dy, repeat ([:], ndims (dy))... ), w, cdims; alpha= α, beta= β)), α* dx + β* x, rtol = 1.0e-7 ) broken= flag
422
+
423
+ @test isapprox (ddims (∇conv_filter! (Float32 .(copy (w)), Float32 .(x), Float32 .(dy), cdims; alpha= Float32 (α), beta= Float32 (β))), α* dw + β* w, rtol = 1.0e-7 )
424
+ @test isapprox (ddims (∇conv_data! (Float32 .(copy (x)), Float32 .(dy), Float32 .(w), cdims; alpha= Float32 (α), beta= Float32 (β))), α* dx + β* x, rtol = 1.0e-7 ) broken= flag
425
+
426
+ # Next, introduce stride:
427
+ cdims = DenseConvDims (x, w; stride= 2 )
428
+ dy = NNlib. conv (x, w, cdims)
429
+ flag_ = ∇conv_filter! == NNlib.∇conv_filter_direct! && rank in (1 ,3 )
430
+ @test isapprox (ddims (∇conv_filter! (copy (w), x, dy, cdims; alpha= α, beta= β)), α* dw_stride + β* w, rtol = 1.0e-7 ) broken= flag_
431
+ @test isapprox (ddims (∇conv_data! (copy (x), dy, w, cdims; alpha= α, beta= β)), α* dx_stride + β* x, rtol = 1.0e-7 ) broken= flag
432
+
433
+ # Next, introduce dilation:
434
+ cdims = DenseConvDims (x, w; dilation= 2 )
435
+ dy = NNlib. conv (x, w, cdims)
436
+ flag_ = ∇conv_data! == NNlib.∇conv_data_direct! && rank == 3
437
+ @test isapprox (ddims (∇conv_filter! (copy (w), x, dy, cdims; alpha= α, beta= β)), α* dw_dil + β* w, rtol = 1.0e-7 )
438
+ @test isapprox (ddims (∇conv_data! (copy (x), dy, w, cdims; alpha= α, beta= β)), α* dx_dil + β* x, rtol = 1.0e-7 ) broken= flag || flag_
439
+
440
+ # Next, introduce padding:
441
+ cdims = DenseConvDims (x, w; padding= 1 )
442
+ dy = NNlib. conv (x, w, cdims)
443
+ @test isapprox (ddims (∇conv_filter! (copy (w), x, dy, cdims; alpha= α, beta= β)), α* dw_pad + β* w, rtol = 1.0e-7 )
444
+ @test isapprox (ddims (∇conv_data! (copy (x), dy, w, cdims; alpha= α, beta= β)), α* dx_pad + β* x, rtol = 1.0e-7 ) broken= flag
445
+
446
+ # Next, test crosscor/conv with a flipped kernel
447
+ cdims = DenseConvDims (x, w; flipkernel= true )
448
+ dy = NNlib. conv (x, w, cdims)
449
+ @test isapprox (ddims (∇conv_filter! (copy (w), x, dy, cdims; alpha= α, beta= β)), α* dw_flip + β* w, rtol = 1.0e-7 )
450
+ @test isapprox (ddims (∇conv_data! (copy (x), dy, w, cdims; alpha= α, beta= β)), α* dx_flip + β* x, rtol = 1.0e-7 ) broken= flag
451
+ end
452
+ end
358
453
end
359
454
end
360
455
end
0 commit comments