 @pytest.mark.parametrize("long_skip", ["unet", "unetpp", "unet3p", "unet3p-lite"])
 @pytest.mark.parametrize("merge_policy", ["cat", "sum"])
-def test_decoder_fwdbwd(long_skip, merge_policy):
-    enc_channels = (64, 32, 16, 8, 8)
-    out_dims = [256 // 2 ** i for i in range(6)][::-1]
+@pytest.mark.parametrize("use_conv", [True, False])
+@pytest.mark.parametrize("use_tr", [True, False])
+def test_decoder_fwdbwd(long_skip, merge_policy, use_conv, use_tr):
+    enc_channels = (64, 32, 64, 32)
+    out_dims = [32 // 2 ** i for i in range(4)][::-1]

     decoder1_kwargs = {"merge_policy": merge_policy}
     decoder2_kwargs = {"merge_policy": merge_policy}
@@ -23,17 +25,30 @@ def test_decoder_fwdbwd(long_skip, merge_policy):
         decoder5_kwargs,
     )

+    n_layers = None
+    n_blocks = None
+    if use_conv:
+        n_layers = (1, 1, 1, 1)
+        n_blocks = ((2,), (2,), (2,), (2,))
+
+    n_tr_layers = None
+    n_tr_blocks = None
+    if use_tr:
+        n_tr_layers = (1, 1, 1, 1)
+        n_tr_blocks = ((1,), (1,), (1,), (1,))
+
     decoder = Decoder(
         enc_channels=enc_channels,
-        model_input_size=256,
-        out_channels=(64, 32, 16, 8, 8),
-        n_layers=(1, 1, 1, 1, 1),
-        n_blocks=((2,), (2,), (2,), (2,), (2,)),
+        out_channels=(32, 32, 32, 32),
+        n_layers=n_layers,
+        n_blocks=n_blocks,
+        n_transformers=n_tr_layers,
+        n_transformer_blocks=n_tr_blocks,
         long_skip=long_skip,
         stage_params=stage_params,
     )

-    x = [torch.rand([1, enc_channels[i], out_dims[i], out_dims[i]]) for i in range(5)]
+    x = [torch.rand([1, enc_channels[i], out_dims[i], out_dims[i]]) for i in range(4)]
     out = decoder(*x)

     out[-1].mean().backward()