|
1 | 1 | using Flux
|
2 | 2 | using Flux: throttle, nfan, glorot_uniform, glorot_normal,
|
3 |
| - kaiming_normal, kaiming_uniform, orthogonal, |
4 |
| - sparse_init, stack, unstack, Zeros, batch, unbatch |
| 3 | + kaiming_normal, kaiming_uniform, orthogonal, |
| 4 | + sparse_init, stack, unstack, Zeros, batch, unbatch, |
| 5 | + unsqueeze |
5 | 6 | using StatsBase: var, std
|
6 | 7 | using Random
|
7 | 8 | using Test
|
8 | 9 |
|
| 10 | +@testset "unsqueeze" begin |
| 11 | + x = randn(2, 3, 2) |
| 12 | + @test @inferred(unsqueeze(x, 1)) == reshape(x, 1, 2, 3, 2) |
| 13 | + @test @inferred(unsqueeze(x, 2)) == reshape(x, 2, 1, 3, 2) |
| 14 | + @test @inferred(unsqueeze(x, 3)) == reshape(x, 2, 3, 1, 2) |
| 15 | + @test @inferred(unsqueeze(x, 4)) == reshape(x, 2, 3, 2, 1) |
| 16 | +end |
| 17 | + |
9 | 18 | @testset "Throttle" begin
|
10 | 19 | @testset "default behaviour" begin
|
11 | 20 | a = []
|
@@ -178,10 +187,10 @@ end
|
178 | 187 |
|
179 | 188 | @testset "$layer ID mapping with kernelsize $kernelsize" for layer in (Conv, ConvTranspose, CrossCor), kernelsize in (
|
180 | 189 | (1,),
|
181 |
| - (3,), |
182 |
| - (1, 3), |
183 |
| - (3, 5), |
184 |
| - (3, 5, 7)) |
| 190 | + (3,), |
| 191 | + (1, 3), |
| 192 | + (3, 5), |
| 193 | + (3, 5, 7)) |
185 | 194 | nch = 3
|
186 | 195 | l = layer(kernelsize, nch=>nch, init=identity_init, pad=SamePad())
|
187 | 196 |
|
|
333 | 342 |
|
334 | 343 |
|
335 | 344 | @testset "Batching" begin
|
336 |
| - stacked_array=[ 8 9 3 5 |
337 |
| - 9 6 6 9 |
338 |
| - 9 1 7 2 |
| 345 | + stacked_array=[ 8 9 3 5 |
| 346 | + 9 6 6 9 |
| 347 | + 9 1 7 2 |
339 | 348 | 7 4 10 6 ]
|
340 | 349 | unstacked_array=[[8, 9, 9, 7], [9, 6, 1, 4], [3, 6, 7, 10], [5, 9, 2, 6]]
|
341 | 350 | @test unbatch(stacked_array) == unstacked_array
|
|
445 | 454 |
|
446 | 455 | modules = Flux.modules(Chain(SkipConnection(
|
447 | 456 | Conv((2,3), 4=>5; pad=6, stride=7),
|
448 |
| - +), |
| 457 | + +), |
449 | 458 | LayerNorm(8)))
|
450 | 459 | @test length(modules) == 5
|
451 | 460 | end
|
@@ -475,16 +484,16 @@ end
|
475 | 484 | @testset "early stopping" begin
|
476 | 485 | @testset "args & kwargs" begin
|
477 | 486 | es = Flux.early_stopping((x; y = 1) -> x + y, 10; min_dist=3)
|
478 |
| - |
| 487 | + |
479 | 488 | n_iter = 0
|
480 | 489 | while n_iter < 99
|
481 | 490 | es(-n_iter; y=-n_iter) && break
|
482 | 491 | n_iter += 1
|
483 | 492 | end
|
484 |
| - |
| 493 | + |
485 | 494 | @test n_iter == 9
|
486 | 495 | end
|
487 |
| - |
| 496 | + |
488 | 497 | @testset "distance" begin
|
489 | 498 | es = Flux.early_stopping(identity, 10; distance=(best_score, score) -> score - best_score)
|
490 | 499 |
|
|
496 | 505 |
|
497 | 506 | @test n_iter == 99
|
498 | 507 | end
|
499 |
| - |
| 508 | + |
500 | 509 | @testset "init_score" begin
|
501 | 510 | es = Flux.early_stopping(identity, 10; init_score=10)
|
502 | 511 |
|
|
0 commit comments