Skip to content

Commit 058e065

Browse files
committed
Lecture 10: Scripts
1 parent 6e99627 commit 058e065

File tree

3 files changed

+11
-2
lines changed

3 files changed

+11
-2
lines changed

scripts/lecture_10/Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
[deps]
22
BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
33
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
4+
ImageInspector = "b0ce21f1-0238-464b-b95f-8a4068743199"
45
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
6+
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
57
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
68
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
79

810
[compat]
911
BSON = "= 0.2.6"
1012
Flux = "= 0.11.6"
1113
MLDatasets = "= 0.5.6"
14+
Plots = "= 1.10.3"
1215
julia = "1.5"

scripts/lecture_10/script_init.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,16 @@ using MLDatasets
1010
using Flux
1111
using Flux: onehotbatch, onecold, crossentropy
1212
using Flux.Data: DataLoader
13+
using Plots
14+
using ImageInspector
1315

1416
function reshape_data(X::AbstractArray{<:Real, 3})
1517
s = size(X)
1618
return reshape(X, s[1], s[2], 1, s[3])
1719
end
1820

21+
reshape_data(X::AbstractArray{<:Real, 4}) = X
22+
1923
function load_data(dataset; T=Float32, onehot=false, classes=0:9)
2024
X_train, y_train = dataset.traindata(T)
2125
X_test, y_test = dataset.testdata(T)
@@ -33,7 +37,7 @@ end
3337

3438
T = Float32
3539
X_train, y_train, X_test, y_test = load_data(MLDatasets.MNIST; T=T, onehot=true);
36-
load_data(MLDatasets.CIFAR; T=T, onehot=true);
40+
load_data(MLDatasets.CIFAR10; T=T, onehot=true);
3741

3842
inds = findall(y_train .== 0)[1:15]
3943
imageplot(1 .- X_train, inds; nrows=3, size=(800,480))
@@ -71,4 +75,4 @@ train_model!(m, L, X_train, y_train; n_epochs=1)
7175

7276
accuracy(x, y) = mean(onecold(m(x)) .== onecold(y))
7377

74-
string(accuracy(X_test, y_test))
78+
accuracy(X_test, y_test)

scripts/lecture_10/script_sol.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ using MLDatasets
66
using Flux
77
using Flux: onehotbatch, onecold, crossentropy
88
using Flux.Data: DataLoader
9+
using Plots
10+
using ImageInspector
911

1012
T = Float32
1113
X_train, y_train = MLDatasets.MNIST.traindata(T)

0 commit comments

Comments
 (0)