Skip to content

Commit 47988ef

Browse files
committed
Check top-5 for accuracy instead
1 parent dd0bb6d commit 47988ef

File tree

1 file changed

+15
-14
lines changed

1 file changed

+15
-14
lines changed

test/runtests.jl

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,15 @@ using Flux: Zygote
44
using Images
55

66
const PRETRAINED_MODELS = [
7-
(VGG, 11, false),
8-
(VGG, 13, false),
9-
(VGG, 16, false),
10-
(VGG, 19, false),
11-
(ResNet, 18),
12-
(ResNet, 34),
13-
(ResNet, 50),
14-
(ResNet, 101),
15-
(ResNet, 152),
7+
(VGG, 11, false),
8+
(VGG, 13, false),
9+
(VGG, 16, false),
10+
(VGG, 19, false),
11+
(ResNet, 18),
12+
(ResNet, 34),
13+
(ResNet, 50),
14+
(ResNet, 101),
15+
(ResNet, 152),
1616
]
1717

1818
function gradtest(model, input)
@@ -24,9 +24,9 @@ function gradtest(model, input)
2424
end
2525

2626
function normalize(data)
27-
cmean = reshape(Float32[0.485, 0.456, 0.406],(1,1,3,1))
28-
cstd = reshape(Float32[0.229, 0.224, 0.225],(1,1,3,1))
29-
return (data .- cmean) ./ cstd
27+
cmean = reshape(Float32[0.485, 0.456, 0.406],(1,1,3,1))
28+
cstd = reshape(Float32[0.229, 0.224, 0.225],(1,1,3,1))
29+
return (data .- cmean) ./ cstd
3030
end
3131

3232
# test image
@@ -39,9 +39,10 @@ const TEST_X = permutedims(convert(Array{Float32}, channelview(TEST_IMG)), (3,2,
3939
const TEST_LBLS = readlines(download("https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"))
4040

4141
function acctest(model)
42-
ypred = Flux.onecold(model(TEST_X), TEST_LBLS)
42+
ypred = model(TEST_X) |> vec
43+
top5 = TEST_LBLS[sortperm(ypred; rev = true)]
4344

44-
return only(ypred) == "acoustic guitar"
45+
return "acoustic guitar" top5
4546
end
4647

4748
x_224 = rand(Float32, 224, 224, 3, 1)

0 commit comments

Comments
 (0)