Skip to content

Commit ca53acb

Browse files
committed
Fix tests, hopefully
1 parent 8c9f73f commit ca53acb

File tree

4 files changed

+18
-25
lines changed

4 files changed

+18
-25
lines changed

.github/workflows/CI.yml

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,15 +27,14 @@ jobs:
2727
- x64
2828
suite:
2929
- '["AlexNet", "VGG"]'
30-
- '["GoogLeNet", "SqueezeNet"]'
31-
- '["EfficientNet", "MobileNet"]'
32-
- '[r"/*/ResNet*", "ResNeXt"]'
33-
- 'r"/*/Inception/Inceptionv*"'
34-
- '["InceptionResNetv2", "Xception"]'
30+
- '["GoogLeNet", "SqueezeNet", "MobileNet"]'
31+
- '["EfficientNet"]'
32+
- '[r"ResNet", r"ResNeXt"]'
33+
- '"Inception"'
3534
- '"DenseNet"'
3635
- '["ConvNeXt", "ConvMixer"]'
37-
- '"ViT"'
38-
- '"Mixers"'
36+
- 'r"ViTs"'
37+
- 'r"Mixers"'
3938
steps:
4039
- uses: actions/checkout@v2
4140
- uses: julia-actions/setup-julia@v1

test/convnets.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ end
107107
end
108108

109109
@testset "EfficientNet" begin
110-
@testset "EfficientNet($name)" for name in [:b0, :b1, :b2, :b3, :b4] #, :b5, :b6, :b7, :b8]
110+
@testset "EfficientNet($name)" for name in [:b0, :b1, :b2, :b3, :b4, :b5, :b6, :b7, :b8]
111111
# preferred image resolution scaling
112112
r = Metalhead.efficientnet_global_configs[name][1]
113113
x = rand(Float32, r, r, 3, 1)

test/mixers.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,14 @@
1010
end
1111

1212
@testset "ResMLP" begin
13-
@testset for mode in [:small, :base, :large, :huge]
14-
@testset for drop_path_rate in [0.0, 0.5]
15-
m = ResMLP(mode; drop_path_rate)
16-
@test size(m(x_224)) == (1000, 1)
17-
@test gradtest(m, x_224)
18-
_gc()
19-
end
20-
end
13+
@testset for mode in [:small, :base, :large, :huge]
14+
@testset for drop_path_rate in [0.0, 0.5]
15+
m = ResMLP(mode; drop_path_rate)
16+
@test size(m(x_224)) == (1000, 1)
17+
@test gradtest(m, x_224)
18+
_gc()
19+
end
20+
end
2121
end
2222

2323
@testset "gMLP" begin

test/runtests.jl

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -29,16 +29,16 @@ function gradtest(model, input)
2929
end
3030

3131
function normalize_imagenet(data)
32-
cmean = reshape(Float32[0.485, 0.456, 0.406],(1,1,3,1))
33-
cstd = reshape(Float32[0.229, 0.224, 0.225],(1,1,3,1))
32+
cmean = reshape(Float32[0.485, 0.456, 0.406], (1, 1, 3, 1))
33+
cstd = reshape(Float32[0.229, 0.224, 0.225], (1, 1, 3, 1))
3434
return (data .- cmean) ./ cstd
3535
end
3636

3737
# test image
3838
const TEST_PATH = download("https://cdn.pixabay.com/photo/2015/05/07/11/02/guitar-756326_960_720.jpg")
3939
const TEST_IMG = imresize(Images.load(TEST_PATH), (224, 224))
4040
# CHW -> WHC
41-
const TEST_X = permutedims(convert(Array{Float32}, channelview(TEST_IMG)), (3,2,1)) |> normalize_imagenet
41+
const TEST_X = permutedims(convert(Array{Float32}, channelview(TEST_IMG)), (3, 2, 1)) |> normalize_imagenet
4242

4343
# image net labels
4444
const TEST_LBLS = readlines(download("https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"))
@@ -58,17 +58,11 @@ x_256 = rand(Float32, 256, 256, 3, 1)
5858
include("convnets.jl")
5959
end
6060

61-
GC.safepoint()
62-
GC.gc()
63-
6461
# Mixer tests
6562
@testset verbose = true "Mixers" begin
6663
include("mixers.jl")
6764
end
6865

69-
GC.safepoint()
70-
GC.gc()
71-
7266
# ViT tests
7367
@testset verbose = true "ViTs" begin
7468
include("vits.jl")

0 commit comments

Comments
 (0)