@@ -4,15 +4,15 @@ using Flux: Zygote
4
4
using Images
5
5
6
6
const PRETRAINED_MODELS = [
7
- (VGG, 11 , false ),
8
- (VGG, 13 , false ),
9
- (VGG, 16 , false ),
10
- (VGG, 19 , false ),
11
- (ResNet, 18 ),
12
- (ResNet, 34 ),
13
- (ResNet, 50 ),
14
- (ResNet, 101 ),
15
- (ResNet, 152 ),
7
+ (VGG, 11 , false ),
8
+ (VGG, 13 , false ),
9
+ (VGG, 16 , false ),
10
+ (VGG, 19 , false ),
11
+ (ResNet, 18 ),
12
+ (ResNet, 34 ),
13
+ (ResNet, 50 ),
14
+ (ResNet, 101 ),
15
+ (ResNet, 152 ),
16
16
]
17
17
18
18
function gradtest (model, input)
@@ -24,9 +24,9 @@ function gradtest(model, input)
24
24
end
25
25
26
26
function normalize (data)
27
- cmean = reshape (Float32[0.485 , 0.456 , 0.406 ],(1 ,1 ,3 ,1 ))
28
- cstd = reshape (Float32[0.229 , 0.224 , 0.225 ],(1 ,1 ,3 ,1 ))
29
- return (data .- cmean) ./ cstd
27
+ cmean = reshape (Float32[0.485 , 0.456 , 0.406 ],(1 ,1 ,3 ,1 ))
28
+ cstd = reshape (Float32[0.229 , 0.224 , 0.225 ],(1 ,1 ,3 ,1 ))
29
+ return (data .- cmean) ./ cstd
30
30
end
31
31
32
32
# test image
@@ -39,9 +39,10 @@ const TEST_X = permutedims(convert(Array{Float32}, channelview(TEST_IMG)), (3,2,
39
39
const TEST_LBLS = readlines (download (" https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt" ))
40
40
41
41
function acctest (model)
42
- ypred = Flux. onecold (model (TEST_X), TEST_LBLS)
42
+ ypred = model (TEST_X) |> vec
43
+ top5 = TEST_LBLS[sortperm (ypred; rev = true )]
43
44
44
- return only (ypred) == " acoustic guitar"
45
+ return " acoustic guitar" ∈ top5
45
46
end
46
47
47
48
x_224 = rand (Float32, 224 , 224 , 3 , 1 )
0 commit comments