@@ -3,6 +3,7 @@ using Flux: throttle, nfan, glorot_uniform, glorot_normal,
3
3
kaiming_normal, kaiming_uniform, orthogonal, truncated_normal,
4
4
sparse_init, identity_init, unstack, batch, unbatch,
5
5
unsqueeze, params, loadparams!, loadmodel!
6
+ using MLUtils
6
7
using StatsBase: var, std
7
8
using Statistics, LinearAlgebra
8
9
using Random
@@ -326,14 +327,14 @@ end
326
327
327
328
@testset " Stacking" begin
328
329
x = randn (3 ,3 )
329
- stacked = Flux . MLUtils. stack ([x, x], dims= 2 )
330
+ stacked = MLUtils. stack ([x, x], dims= 2 )
330
331
@test size (stacked) == (3 ,2 ,3 )
331
332
332
333
stacked_array= [ 8 9 3 5 ; 9 6 6 9 ; 9 1 7 2 ; 7 4 10 6 ]
333
334
unstacked_array= [[8 , 9 , 9 , 7 ], [9 , 6 , 1 , 4 ], [3 , 6 , 7 , 10 ], [5 , 9 , 2 , 6 ]]
334
335
@test unstack (stacked_array, dims= 2 ) == unstacked_array
335
- @test Flux . MLUtils. stack (unstacked_array, dims= 2 ) == stacked_array
336
- @test Flux . MLUtils. stack (unstack (stacked_array, dims= 1 ), dims= 1 ) == stacked_array
336
+ @test MLUtils. stack (unstacked_array, dims= 2 ) == stacked_array
337
+ @test MLUtils. stack (unstack (stacked_array, dims= 1 ), dims= 1 ) == stacked_array
337
338
end
338
339
339
340
@testset " Batching" begin
0 commit comments