
Commit 3c87993

Lecture 10: Polishing touches
1 parent 2d9e973 commit 3c87993

File tree: 6 files changed (+165 −176 lines)

docs/src/lecture_10/data/iris.bson (binary, -6.21 KB; file not shown)

docs/src/lecture_10/exercises.md (54 additions, 51 deletions)
@@ -3,65 +3,54 @@ using BSON
 using Flux
 using MLDatasets
 using DataFrames
-
+using Plots
 using Flux: onehotbatch, onecold, flatten

 Core.eval(Main, :(using Flux)) # hide
 ENV["DATADEPS_ALWAYS_ACCEPT"] = true
 MNIST.traindata()

-function reshape_data(X::AbstractArray{T, 3}, y::AbstractVector) where T
+function reshape_data(X::AbstractArray{<:Real, 3})
     s = size(X)
-    return reshape(X, s[1], s[2], 1, s[3]), reshape(y, 1, :)
+    return reshape(X, s[1], s[2], 1, s[3])
 end

-function train_or_load!(file_name, m, X, y; force=false, kwargs...)
-
+function train_or_load!(file_name, m, args...; force=false, kwargs...)
+
     !isdir(dirname(file_name)) && mkpath(dirname(file_name))

     if force || !isfile(file_name)
-        train_model!(m, X, y; file_name=file_name, kwargs...)
+        train_model!(m, args...; file_name=file_name, kwargs...)
     else
-        m_loaded = BSON.load(file_name)[:m]
-        Flux.loadparams!(m, params(m_loaded))
+        m_weights = BSON.load(file_name)[:m]
+        Flux.loadparams!(m, params(m_weights))
     end
 end

 function load_data(dataset; T=Float32, onehot=false, classes=0:9)
-    X_train, y_train = reshape_data(dataset.traindata(T)...)
-    X_test, y_test = reshape_data(dataset.testdata(T)...)
-    y_train = T.(y_train)
-    y_test = T.(y_test)
+    X_train, y_train = dataset.traindata(T)
+    X_test, y_test = dataset.testdata(T)
+
+    X_train = reshape_data(X_train)
+    X_test = reshape_data(X_test)

     if onehot
-        y_train = onehotbatch(y_train[:], classes)
-        y_test = onehotbatch(y_test[:], classes)
+        y_train = onehotbatch(y_train, classes)
+        y_test = onehotbatch(y_test, classes)
     end

     return X_train, y_train, X_test, y_test
 end

-using Plots
-
-plot_image(x::AbstractArray{T, 2}) where T = plot(Gray.(1 .-x'), axis=false, ticks=false)
-
-function plot_image(x::AbstractArray{T, 3}) where T
-    size(x,3) == 1 || error("Image is not grayscale.")
-    plot_image(x[:,:,1])
-end
-
-
 T = Float32
-dataset = MLDatasets.MNIST
-
-X_train, y_train, X_test, y_test = load_data(dataset; T=T, onehot=true)
+X_train, y_train, X_test, y_test = load_data(MLDatasets.MNIST; T=T, onehot=true)
 ```
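
For orientation, a minimal sketch of what the refactored `reshape_data` does, assuming the standard 28×28×N MNIST layout (an illustration, not part of the commit):

```julia
# Flux's Conv layers expect width × height × channels × batch,
# so a 3D grayscale stack gets an explicit singleton channel dimension.
X = rand(Float32, 28, 28, 60000)   # stand-in for the raw MNIST training images
X4 = reshape(X, 28, 28, 1, 60000)  # what reshape_data(X) returns
size(X4)                           # (28, 28, 1, 60000)
```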



 # Exercises

-The first two exercises handle training neural networks on GPUs instead of CPUs. Even though this is extremely important for reducing the training time, we postponed it to the exercises because some course participants may not have a compatible GPU for training. If you are not able to do these two exercises for this reason, we apologize.
+The first two exercises handle training neural networks on GPUs instead of CPUs. Even though this is extremely important for reducing the training time, we postponed it to the exercises because some course participants may not have a compatible GPU for training. If anyone is not able to do these two exercises, we apologize.


 ```@raw html
@@ -70,7 +59,7 @@ The first two exercises handle training neural networks on GPUs instead of CPUs.
 ```
 While most computer operations are performed on CPUs (central processing unit), neural networks are trained on other hardware such as GPUs (graphics processing unit) or specialized hardware such as TPUs.

-To use GPUs, include packages Flux and CUDA. Then generate a random matrix ``A\in \mathbb{R}^{100\times 100}`` and a random vector ``b\in \mathbb{R}^{100}``. They will be stored in the memory (RAM) and the computation will be performed on CPU. To move them to the GPU memory and allow computations on GPU, use ```gpu(A)``` or the more commonly used ```A |> gpu```.
+To use GPUs, include packages Flux and CUDA. Then generate a random matrix ``A\in \mathbb{R}^{100\times 100}`` and a random vector ``b\in \mathbb{R}^{100}``. They will be stored in the memory (RAM), and the computation will be performed on CPU. To move them to the GPU memory and allow computations on GPU, use ```gpu(A)``` or the more commonly used ```A |> gpu```.

 Investigate how long it takes to perform multiplication ``Ab`` if both objects are on CPU, GPU or if they are saved differently. Check that both multiplications resulted in the same vector.
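
A minimal sketch of this setup (assuming a CUDA-capable machine; Flux's `gpu` falls back to the CPU when no compatible device is present):

```julia
using Flux, CUDA

A = randn(Float32, 100, 100)
b = randn(Float32, 100)

A_g = A |> gpu   # same as gpu(A); returns a CuArray on a CUDA-capable machine
b_g = b |> gpu

A_g * b_g        # the multiplication now runs on the GPU
```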
 ```@raw html
@@ -99,7 +88,7 @@ To test the time, we measure the time for multiplication
   0.806913 seconds (419.70 k allocations: 22.046 MiB)
   0.709140 seconds (720.01 k allocations: 34.860 MiB, 1.53% gc time)
 ```
-We see that all three times are different. Can we infer anything from it? No! The problem is that during a first call to a function, some compilation usually takes place. We should always compare only the second time.
+We see that all three times are different. Can we infer anything from it? No! The problem is that during the first call to a function, some compilation usually takes place. We should always compare only the second time.
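
The same pitfall can be sidestepped with the BenchmarkTools package (an assumed extra dependency, not required by the lecture), which excludes compilation from its timings:

```julia
using BenchmarkTools

@btime $A * $b;      # $-interpolation avoids timing access to global variables
@btime $A_g * $b_g;  # compilation overhead is excluded automatically
```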
 ```julia
 @time A*b;
 @time A_g*b_g;
@@ -110,7 +99,7 @@ We see that all three times are different. Can we infer anything from it? No! Th
   0.000154 seconds (11 allocations: 272 bytes)
   0.475280 seconds (10.20 k allocations: 957.125 KiB)
 ```
-We conclude that while the computation on CPU and GPU takes approximately the same time, when using the mixed types, it takes much longer.
+We conclude that while the computation on CPU and GPU takes approximately the same time, it takes much longer when using the mixed types.

 To compare the results, the first idea would be to run
 ```julia
@@ -144,7 +133,7 @@ we realize that one of the arrays is stored in ```Float64``` while the second on



-The previous exercise did not show any differences when performing a matrix-vector multiplication. The probable reason was that the running times were too short. The next exercise shows the time difference when applied to a larger problem.
+The previous exercise did not show any differences when performing a matrix-vector multiplication. The probable reason was that the running times were too short. The following exercise shows the time difference when applied to a larger problem.
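
Returning to the comparison: once the element types differ, a robust check moves the GPU result back to host memory and compares with a tolerance. A sketch, with an assumed tolerance (`cpu` is Flux's transfer function):

```julia
using LinearAlgebra

err = norm(A * b - cpu(A_g * b_g))  # cpu(...) copies the CuArray back to RAM
err < 1e-4                          # Float32 arithmetic; exact equality may fail
```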

@@ -179,12 +168,12 @@ m = Chain(
 )

 file_name = joinpath("data", "mnist.bson")
-train_or_load!(file_name, m, X_train, y_train)
+train_or_load!(file_name, m)

 m_g = m |> gpu
 X_test_g = X_test |> gpu
 ```
-Now we can measure the evaluation time. Remember that before doing so, we need to compile all the functions by evaluating at least one sample.
+Now we can measure the evaluation time. Remember that we need to compile all the functions by evaluating at least one sample before doing so.
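
After the warm-up below, the measurement itself might look as follows (a sketch; `CUDA.@time` synchronizes the device so the GPU work is fully counted):

```julia
@time m(X_test);           # evaluation of the whole test set on the CPU
CUDA.@time m_g(X_test_g);  # the same evaluation on the GPU
```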
 ```julia
 m(X_test[:,:,:,1:1])
 m_g(X_test_g[:,:,:,1:1])
@@ -264,7 +253,7 @@ m = Chain(
 )

 file_name = joinpath("data", "mnist.bson")
-train_or_load!(file_name, m, X_train, y_train)
+train_or_load!(file_name, m)
 ```
 When creating a table, we specify that its entries are ```Int```. We save the predictions ```y_hat``` and labels ```y```. Since we do not use the second argument to ```onecold```, the entries of ```y_hat``` and ```y``` are between 1 and 10. Then we run a for loop over all misclassified samples and add to the error counts.
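
A minimal sketch of such an error-count table (the variable names are assumptions, not the lecture's exact code):

```julia
y_hat = onecold(m(X_test))   # predicted labels encoded as 1:10
y = onecold(y_test)          # true labels encoded as 1:10

errors = zeros(Int, 10, 10)  # errors[i, j] counts true class i predicted as j
for (t, p) in zip(y, y_hat)
    t != p && (errors[t, p] += 1)
end
```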
 ```@example gpuu
@@ -315,15 +304,18 @@ Plot all images which are ``9`` but were classified as ``7``.
 <details class = "solution-body">
 <summary class = "solution-header">Solution:</summary><p>
 ```
-To plot all these misclassified images, we find their indices and use the function ```plot_image```. Since ```y``` are stored in the 1:10 format, we need to shift the indices by one. Since there are 11 of these images, and since 11 is a prime number, we cannot plot it in a ```layout```. We use a hack and add an empty plot ```p_empty```. When plotting, we specify ```layout``` and to minimize the empty space between images also ```size```.
+
+To plot all these misclassified images, we find their indices and use the function `imageplot`. Since `y` are stored in the 1:10 format, we need to specify `classes`.
+
 ```@example gpuu
-i1 = 9
-i2 = 7
+using ImageInspector

-p = [plot_image(X_test[:,:,:,i]) for i in findall((y.==i1+1) .& (y_hat.==i2+1))]
-p_empty = plot(legend=false,grid=false,foreground_color_subplot=:white)
+classes = 0:9

-plot(p..., p_empty; layout=(3,4), size=(800,600))
+targets = onecold(y_test, classes)
+predicts = onecold(m(X_test), classes)
+
+imageplot(1 .- X_test, findall((targets .== 9) .& (predicts .== 7)); nrows=3)

 savefig("miss.svg") # hide
 ```
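
The role of the `classes` argument in `onecold` can be seen in isolation (a small sketch):

```julia
using Flux: onehotbatch, onecold

y = [0, 9, 4]
Y = onehotbatch(y, 0:9)  # 10×3 one-hot matrix
onecold(Y)               # [1, 10, 5]: default decoding into indices 1:10
onecold(Y, 0:9)          # [0, 9, 4]: decoding back into the original labels
```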
@@ -380,29 +372,40 @@ m = Chain(
 )

 file_name = joinpath("data", "mnist_sigmoid.bson")
-train_or_load!(file_name, m, X_train, y_train)
+train_or_load!(file_name, m)
 ```
-Before plotting, we perform a for loop over the digits. Then ```onecold(y_train, classes) .== i``` creates a ```BitArray``` with ones if the condition is satisfied, and zeros if the condition is not satisfied. Then ```findall(???)``` selects all ones, and ```???[1:5]``` finds the first five indices. Since we need to plot the original image, and the images after the second and fourth layer (there is always a convolutional layer before the pooling layer), we save these values into ```z1```, ```z2``` and ```z3```. Since ```plot_image(z1[:,:,1,i])``` plots the first channel of the ``i^{\rm th}`` samples from ```z1```, we create an array of plots by ```p1 = [plot_image(z1[:,:,1,i]) for i in 1:size(z1,4)]```. As the length of ```z1``` is five, the length of ```p1``` is also five. This is the first row of the final plot. We create the other rows in the same way. To plot the final plot, we do ```plot(p1..., p2a..., p2b..., p3a..., p3b...)```, which unpacks the 5 arrays into 25 inputs to the ```plot``` function.
+
+Before plotting, we perform a for loop over the digits. Then ```onecold(y_train, classes) .== i``` creates a ```BitArray``` with ones if the condition is satisfied, and zeros if the condition is not satisfied. Then ```findall(???)``` selects all ones, and ```???[1:5]``` finds the first five indices. Since we need to plot the original image, and the images after the second and fourth layer (there is always a convolutional layer before the pooling layer), we save these values into ```z1```, ```z2``` and ```z3```. Then we need to access the desired channels and plot them via the `ImageInspector` package.
+
 ```@example gpuu
+using ImageInspector
+
 classes = 0:9
+plts = []
 for i in classes
-    ii = findall(onecold(y_train, classes) .== i)[1:5]
+    jj = 1:5
+    ii = findall(onecold(y_train, classes) .== i)[jj]

     z1 = X_train[:,:,:,ii]
     z2 = m[1:2](X_train[:,:,:,ii])
     z3 = m[1:4](X_train[:,:,:,ii])

-    p1 = [plot_image(z1[:,:,1,i]) for i in 1:size(z1,4)]
-    p2a = [plot_image(z2[:,:,1,i]) for i in 1:size(z2,4)]
-    p3a = [plot_image(z3[:,:,1,i]) for i in 1:size(z3,4)]
-    p2b = [plot_image(z2[:,:,end,i]) for i in 1:size(z2,4)]
-    p3b = [plot_image(z3[:,:,end,i]) for i in 1:size(z3,4)]
-
-    plot(p1..., p2a..., p2b..., p3a..., p3b...; layout=(5,5), size=(600,600))
+    kwargs = (nrows = 1, size = (600, 140))
+    plot(
+        imageplot(1 .- z1[:, :, 1, :], jj; kwargs...),
+        imageplot(1 .- z2[:, :, 1, :], jj; kwargs...),
+        imageplot(1 .- z2[:, :, end, :], jj; kwargs...),
+        imageplot(1 .- z3[:, :, 1, :], jj; kwargs...),
+        imageplot(1 .- z3[:, :, end, :], jj; kwargs...);
+        layout = (5,1),
+        size=(700,800)
+    )
     savefig("Layers_$(i).svg")
 end
 ```
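
One detail worth making explicit: indexing a Flux `Chain` returns a sub-chain, so intermediate activations are obtained by applying a slice of the model. A sketch with an assumed architecture (not necessarily the lecture's exact model):

```julia
using Flux

m = Chain(
    Conv((2, 2), 1 => 16, sigmoid), MaxPool((2, 2)),
    Conv((2, 2), 16 => 8, sigmoid), MaxPool((2, 2)),
    Flux.flatten, Dense(288, 10),
)

x = rand(Float32, 28, 28, 1, 5)  # five fake grayscale images
z2 = m[1:2](x)                   # activations after the first conv + pool pair
z3 = m[1:4](x)                   # activations after the second conv + pool pair
size(z2), size(z3)
```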
+
 We plot and comment on three selected digits below.
+
 ```@raw html
 </p></details>
 ```
