Skip to content

Commit 70d639d

Browse files
committed
Format tests directory
1 parent d5d28f0 commit 70d639d

File tree

4 files changed

+160
-168
lines changed

4 files changed

+160
-168
lines changed

test/convnets.jl

Lines changed: 122 additions & 130 deletions
Original file line numberDiff line numberDiff line change
@@ -5,202 +5,194 @@ using Flux
55
PRETRAINED_MODELS = []
66

77
@testset "AlexNet" begin
8-
model = AlexNet()
9-
@test size(model(x_256)) == (1000, 1)
10-
@test_throws ArgumentError AlexNet(pretrain = true)
11-
@test gradtest(model, x_256)
8+
model = AlexNet()
9+
@test size(model(x_256)) == (1000, 1)
10+
@test_throws ArgumentError AlexNet(pretrain = true)
11+
@test gradtest(model, x_256)
1212
end
1313

1414
GC.safepoint()
1515
GC.gc()
1616

1717
@testset "VGG" begin
18-
@testset "VGG($sz, batchnorm=$bn)" for sz in [11, 13, 16, 19], bn in [true, false]
19-
m = VGG(sz, batchnorm = bn)
20-
21-
@test size(m(x_224)) == (1000, 1)
22-
if (VGG, sz, bn) in PRETRAINED_MODELS
23-
@test (VGG(sz, batchnorm = bn, pretrain = true); true)
24-
else
25-
@test_throws ArgumentError VGG(sz, batchnorm = bn, pretrain = true)
18+
@testset "VGG($sz, batchnorm=$bn)" for sz in [11, 13, 16, 19], bn in [true, false]
19+
m = VGG(sz, batchnorm = bn)
20+
@test size(m(x_224)) == (1000, 1)
21+
if (VGG, sz, bn) in PRETRAINED_MODELS
22+
@test (VGG(sz, batchnorm = bn, pretrain = true); true)
23+
else
24+
@test_throws ArgumentError VGG(sz, batchnorm = bn, pretrain = true)
25+
end
26+
@test gradtest(m, x_224)
27+
GC.safepoint()
28+
GC.gc()
2629
end
27-
@test gradtest(m, x_224)
28-
GC.safepoint()
29-
GC.gc()
30-
end
3130
end
3231

3332
GC.safepoint()
3433
GC.gc()
3534

3635
@testset "ResNet" begin
37-
@testset "ResNet($sz)" for sz in [18, 34, 50, 101, 152]
38-
m = ResNet(sz)
39-
40-
@test size(m(x_256)) == (1000, 1)
41-
if (ResNet, sz) in PRETRAINED_MODELS
42-
@test (ResNet(sz, pretrain = true); true)
43-
else
44-
@test_throws ArgumentError ResNet(sz, pretrain = true)
36+
@testset "ResNet($sz)" for sz in [18, 34, 50, 101, 152]
37+
m = ResNet(sz)
38+
@test size(m(x_256)) == (1000, 1)
39+
if (ResNet, sz) in PRETRAINED_MODELS
40+
@test (ResNet(sz, pretrain = true); true)
41+
else
42+
@test_throws ArgumentError ResNet(sz, pretrain = true)
43+
end
44+
@test gradtest(m, x_256)
45+
GC.safepoint()
46+
GC.gc()
4547
end
46-
@test gradtest(m, x_256)
47-
GC.safepoint()
48-
GC.gc()
49-
end
50-
51-
@testset "Shortcut C" begin
52-
m = Metalhead.resnet(Metalhead.basicblock, :C;
53-
channel_config = [1, 1],
54-
block_config = [2, 2, 2, 2])
5548

56-
@test size(m(x_256)) == (1000, 1)
57-
@test gradtest(m, x_256)
58-
end
49+
@testset "Shortcut C" begin
50+
m = Metalhead.resnet(Metalhead.basicblock, :C;
51+
channel_config = [1, 1],
52+
block_config = [2, 2, 2, 2])
53+
@test size(m(x_256)) == (1000, 1)
54+
@test gradtest(m, x_256)
55+
end
5956
end
6057

6158
GC.safepoint()
6259
GC.gc()
6360

6461
@testset "ResNeXt" begin
65-
@testset for depth in [50, 101, 152]
66-
m = ResNeXt(depth)
67-
68-
@test size(m(x_224)) == (1000, 1)
69-
if ResNeXt in PRETRAINED_MODELS
70-
@test (ResNeXt(depth, pretrain = true); true)
71-
else
72-
@test_throws ArgumentError ResNeXt(depth, pretrain = true)
62+
@testset for depth in [50, 101, 152]
63+
m = ResNeXt(depth)
64+
@test size(m(x_224)) == (1000, 1)
65+
if ResNeXt in PRETRAINED_MODELS
66+
@test (ResNeXt(depth, pretrain = true); true)
67+
else
68+
@test_throws ArgumentError ResNeXt(depth, pretrain = true)
69+
end
70+
@test gradtest(m, x_224)
71+
GC.safepoint()
72+
GC.gc()
7373
end
74-
@test gradtest(m, x_224)
75-
GC.safepoint()
76-
GC.gc()
77-
end
7874
end
7975

8076
GC.safepoint()
8177
GC.gc()
8278

8379
@testset "GoogLeNet" begin
84-
m = GoogLeNet()
85-
@test size(m(x_224)) == (1000, 1)
86-
@test_throws ArgumentError (GoogLeNet(pretrain = true); true)
87-
@test gradtest(m, x_224)
80+
m = GoogLeNet()
81+
@test size(m(x_224)) == (1000, 1)
82+
@test_throws ArgumentError (GoogLeNet(pretrain = true); true)
83+
@test gradtest(m, x_224)
8884
end
8985

9086
GC.safepoint()
9187
GC.gc()
9288

9389
@testset "Inception3" begin
94-
m = Inception3()
95-
@test size(m(x_224)) == (1000, 1)
96-
@test_throws ArgumentError Inception3(pretrain = true)
97-
@test gradtest(m, x_224)
90+
m = Inception3()
91+
@test size(m(x_224)) == (1000, 1)
92+
@test_throws ArgumentError Inception3(pretrain = true)
93+
@test gradtest(m, x_224)
9894
end
9995

10096
GC.safepoint()
10197
GC.gc()
10298

10399
@testset "SqueezeNet" begin
104-
m = SqueezeNet()
105-
@test size(m(x_224)) == (1000, 1)
106-
@test_throws ArgumentError (SqueezeNet(pretrain = true); true)
107-
@test gradtest(m, x_224)
100+
m = SqueezeNet()
101+
@test size(m(x_224)) == (1000, 1)
102+
@test_throws ArgumentError (SqueezeNet(pretrain = true); true)
103+
@test gradtest(m, x_224)
108104
end
109105

110106
GC.safepoint()
111107
GC.gc()
112108

113109
@testset "DenseNet" begin
114-
@testset for sz in [121, 161, 169, 201]
115-
m = DenseNet(sz)
116-
117-
@test size(m(x_224)) == (1000, 1)
118-
if (DenseNet, sz) in PRETRAINED_MODELS
119-
@test (DenseNet(sz, pretrain = true); true)
120-
else
121-
@test_throws ArgumentError DenseNet(sz, pretrain = true)
110+
@testset for sz in [121, 161, 169, 201]
111+
m = DenseNet(sz)
112+
@test size(m(x_224)) == (1000, 1)
113+
if (DenseNet, sz) in PRETRAINED_MODELS
114+
@test (DenseNet(sz, pretrain = true); true)
115+
else
116+
@test_throws ArgumentError DenseNet(sz, pretrain = true)
117+
end
118+
@test gradtest(m, x_224)
119+
GC.safepoint()
120+
GC.gc()
122121
end
123-
@test gradtest(m, x_224)
124-
GC.safepoint()
125-
GC.gc()
126-
end
127122
end
128123

129124
GC.safepoint()
130125
GC.gc()
131126

132127
@testset "MobileNet" verbose = true begin
133-
@testset "MobileNetv1" begin
134-
m = MobileNetv1()
135-
136-
@test size(m(x_224)) == (1000, 1)
137-
if MobileNetv1 in PRETRAINED_MODELS
138-
@test (MobileNetv1(pretrain = true); true)
139-
else
140-
@test_throws ArgumentError MobileNetv1(pretrain = true)
128+
@testset "MobileNetv1" begin
129+
m = MobileNetv1()
130+
@test size(m(x_224)) == (1000, 1)
131+
if MobileNetv1 in PRETRAINED_MODELS
132+
@test (MobileNetv1(pretrain = true); true)
133+
else
134+
@test_throws ArgumentError MobileNetv1(pretrain = true)
135+
end
136+
@test gradtest(m, x_224)
141137
end
142-
@test gradtest(m, x_224)
143-
end
144138

145-
GC.safepoint()
146-
GC.gc()
139+
GC.safepoint()
140+
GC.gc()
141+
142+
@testset "MobileNetv2" begin
143+
m = MobileNetv2()
144+
@test size(m(x_224)) == (1000, 1)
145+
if MobileNetv2 in PRETRAINED_MODELS
146+
@test (MobileNetv2(pretrain = true); true)
147+
else
148+
@test_throws ArgumentError MobileNetv2(pretrain = true)
149+
end
150+
@test gradtest(m, x_224)
151+
end
147152

148-
@testset "MobileNetv2" begin
149-
m = MobileNetv2()
153+
GC.safepoint()
154+
GC.gc()
150155

151-
@test size(m(x_224)) == (1000, 1)
152-
if MobileNetv2 in PRETRAINED_MODELS
153-
@test (MobileNetv2(pretrain = true); true)
154-
else
155-
@test_throws ArgumentError MobileNetv2(pretrain = true)
156+
@testset "MobileNetv3" verbose = true begin
157+
@testset for mode in [:small, :large]
158+
m = MobileNetv3(mode)
159+
160+
@test size(m(x_224)) == (1000, 1)
161+
if MobileNetv3 in PRETRAINED_MODELS
162+
@test (MobileNetv3(mode; pretrain = true); true)
163+
else
164+
@test_throws ArgumentError MobileNetv3(mode; pretrain = true)
165+
end
166+
@test gradtest(m, x_224)
167+
end
156168
end
157-
@test gradtest(m, x_224)
158-
end
159-
160-
GC.safepoint()
161-
GC.gc()
162-
163-
@testset "MobileNetv3" verbose = true begin
164-
@testset for mode in [:small, :large]
165-
m = MobileNetv3(mode)
166-
167-
@test size(m(x_224)) == (1000, 1)
168-
if MobileNetv3 in PRETRAINED_MODELS
169-
@test (MobileNetv3(mode; pretrain = true); true)
170-
else
171-
@test_throws ArgumentError MobileNetv3(mode; pretrain = true)
172-
end
173-
@test gradtest(m, x_224)
174169
end
175-
end
176-
end
177-
178-
GC.safepoint()
179-
GC.gc()
180170

181-
@testset "ConvNeXt" verbose = true begin
182-
@testset for mode in [:small, :base, :large] # :tiny, #, :xlarge]
183-
@testset for drop_path_rate in [0.0, 0.5]
184-
m = ConvNeXt(mode; drop_path_rate)
171+
GC.safepoint()
172+
GC.gc()
185173

186-
@test size(m(x_224)) == (1000, 1)
187-
@test gradtest(m, x_224)
188-
GC.safepoint()
189-
GC.gc()
190-
end
191-
end
174+
@testset "ConvNeXt" verbose = true begin
175+
@testset for mode in [:small, :base, :large] # :tiny, #, :xlarge]
176+
@testset for drop_path_rate in [0.0, 0.5]
177+
m = ConvNeXt(mode; drop_path_rate)
178+
@test size(m(x_224)) == (1000, 1)
179+
@test gradtest(m, x_224)
180+
GC.safepoint()
181+
GC.gc()
182+
end
183+
end
192184
end
193185

194186
GC.safepoint()
195187
GC.gc()
196188

197189
@testset "ConvMixer" verbose = true begin
198-
@testset for mode in [:small, :base, :large]
199-
m = ConvMixer(mode)
190+
@testset for mode in [:small, :base, :large]
191+
m = ConvMixer(mode)
200192

201-
@test size(m(x_224)) == (1000, 1)
202-
@test gradtest(m, x_224)
203-
GC.safepoint()
204-
GC.gc()
205-
end
193+
@test size(m(x_224)) == (1000, 1)
194+
@test gradtest(m, x_224)
195+
GC.safepoint()
196+
GC.gc()
197+
end
206198
end

test/other.jl

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2,37 +2,37 @@ using Metalhead, Test
22
using Flux
33

44
@testset "MLPMixer" begin
5-
@testset for mode in [:small, :base, :large] # :huge]
6-
@testset for drop_path_rate in [0.0, 0.5]
7-
m = MLPMixer(mode; drop_path_rate)
8-
@test size(m(x_224)) == (1000, 1)
9-
@test gradtest(m, x_224)
10-
GC.safepoint()
11-
GC.gc()
12-
end
13-
end
5+
@testset for mode in [:small, :base, :large] # :huge]
6+
@testset for drop_path_rate in [0.0, 0.5]
7+
m = MLPMixer(mode; drop_path_rate)
8+
@test size(m(x_224)) == (1000, 1)
9+
@test gradtest(m, x_224)
10+
GC.safepoint()
11+
GC.gc()
12+
end
13+
end
1414
end
1515

1616
@testset "ResMLP" begin
1717
@testset for mode in [:small, :base, :large] # :huge]
18-
@testset for drop_path_rate in [0.0, 0.5]
19-
m = ResMLP(mode; drop_path_rate)
20-
@test size(m(x_224)) == (1000, 1)
21-
@test gradtest(m, x_224)
22-
GC.safepoint()
23-
GC.gc()
24-
end
18+
@testset for drop_path_rate in [0.0, 0.5]
19+
m = ResMLP(mode; drop_path_rate)
20+
@test size(m(x_224)) == (1000, 1)
21+
@test gradtest(m, x_224)
22+
GC.safepoint()
23+
GC.gc()
24+
end
2525
end
2626
end
2727

2828
@testset "gMLP" begin
29-
@testset for mode in [:small, :base, :large] # :huge]
30-
@testset for drop_path_rate in [0.0, 0.5]
31-
m = gMLP(mode; drop_path_rate)
32-
@test size(m(x_224)) == (1000, 1)
33-
@test gradtest(m, x_224)
34-
GC.safepoint()
35-
GC.gc()
29+
@testset for mode in [:small, :base, :large] # :huge]
30+
@testset for drop_path_rate in [0.0, 0.5]
31+
m = gMLP(mode; drop_path_rate)
32+
@test size(m(x_224)) == (1000, 1)
33+
@test gradtest(m, x_224)
34+
GC.safepoint()
35+
GC.gc()
36+
end
3637
end
37-
end
3838
end

0 commit comments

Comments
 (0)