Skip to content

Commit f20d02e

Browse files
author
Documenter.jl
committed
build based on 220278c
1 parent deffbaa commit f20d02e

File tree

139 files changed

+21233
-21089
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

139 files changed

+21233
-21089
lines changed

dev/.documenter-siteinfo.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
{"documenter":{"julia_version":"1.10.5","generation_timestamp":"2024-10-01T14:37:26","documenter_version":"1.7.0"}}
1+
{"documenter":{"julia_version":"1.10.5","generation_timestamp":"2024-10-27T13:07:03","documenter_version":"1.7.0"}}

dev/index.html

Lines changed: 2 additions & 2 deletions
Large diffs are not rendered by default.

dev/installation/installation/index.html

Lines changed: 2 additions & 2 deletions
Large diffs are not rendered by default.

dev/installation/tutorial/index.html

Lines changed: 1 addition & 1 deletion
Large diffs are not rendered by default.

dev/lecture_01/arrays/index.html

Lines changed: 2 additions & 2 deletions
Large diffs are not rendered by default.

dev/lecture_01/data_structures/index.html

Lines changed: 2 additions & 2 deletions
Large diffs are not rendered by default.

dev/lecture_01/operators/index.html

Lines changed: 2 additions & 2 deletions
Large diffs are not rendered by default.

dev/lecture_01/strings/index.html

Lines changed: 2 additions & 2 deletions
Large diffs are not rendered by default.

dev/lecture_01/variables/index.html

Lines changed: 2 additions & 2 deletions
Large diffs are not rendered by default.

dev/lecture_02/conditions/index.html

Lines changed: 2 additions & 2 deletions
Large diffs are not rendered by default.

dev/lecture_02/exercises/index.html

Lines changed: 2 additions & 2 deletions
Large diffs are not rendered by default.

dev/lecture_02/functions/index.html

Lines changed: 2 additions & 2 deletions
Large diffs are not rendered by default.

dev/lecture_02/loops/index.html

Lines changed: 2 additions & 2 deletions
Large diffs are not rendered by default.

dev/lecture_02/scope/index.html

Lines changed: 2 additions & 2 deletions
Large diffs are not rendered by default.

dev/lecture_02/sin.svg

Lines changed: 38 additions & 38 deletions

dev/lecture_03/DataFrames/a7756259.svg renamed to dev/lecture_03/DataFrames/33387195.svg

Lines changed: 1303 additions & 1303 deletions

dev/lecture_03/DataFrames/10bf9388.svg renamed to dev/lecture_03/DataFrames/95802852.svg

Lines changed: 51 additions & 51 deletions

dev/lecture_03/DataFrames/46846c54.svg renamed to dev/lecture_03/DataFrames/cd8a48ad.svg

Lines changed: 181 additions & 181 deletions

dev/lecture_03/DataFrames/index.html

Lines changed: 13 additions & 13 deletions
Large diffs are not rendered by default.

dev/lecture_03/Gamma_cdf.svg

Lines changed: 38 additions & 38 deletions

dev/lecture_03/Plots/65119d7a.svg renamed to dev/lecture_03/Plots/132e139f.svg

Lines changed: 135 additions & 135 deletions

dev/lecture_03/Plots/08b0a84d.svg renamed to dev/lecture_03/Plots/1cae0458.svg

Lines changed: 135 additions & 135 deletions

dev/lecture_03/Plots/bf42ffb3.svg renamed to dev/lecture_03/Plots/385d28ec.svg

Lines changed: 34 additions & 34 deletions

dev/lecture_03/Plots/16078691.svg renamed to dev/lecture_03/Plots/42afd179.svg

Lines changed: 36 additions & 36 deletions

dev/lecture_03/Plots/1c1efb99.svg renamed to dev/lecture_03/Plots/63113afd.svg

Lines changed: 28 additions & 28 deletions

dev/lecture_03/Plots/5be9e6cd.svg renamed to dev/lecture_03/Plots/6b2f1b01.svg

Lines changed: 384 additions & 384 deletions

dev/lecture_03/Plots/a182aa60.svg renamed to dev/lecture_03/Plots/6da55696.svg

Lines changed: 90 additions & 90 deletions

dev/lecture_03/Plots/ffcfdf24.svg renamed to dev/lecture_03/Plots/6e71ea37.svg

Lines changed: 28 additions & 28 deletions

dev/lecture_03/Plots/487f177b.svg renamed to dev/lecture_03/Plots/70b2003b.svg

Lines changed: 38 additions & 38 deletions

dev/lecture_03/Plots/9a98b0b6.svg renamed to dev/lecture_03/Plots/9b8264ce.svg

Lines changed: 117 additions & 117 deletions

dev/lecture_03/Plots/5d061910.svg renamed to dev/lecture_03/Plots/a0ddf56e.svg

Lines changed: 40 additions & 40 deletions

dev/lecture_03/Plots/e7396e2d.svg renamed to dev/lecture_03/Plots/c66b02fc.svg

Lines changed: 40 additions & 40 deletions

dev/lecture_03/Plots/8b88b2cf.svg renamed to dev/lecture_03/Plots/c98c920d.svg

Lines changed: 431 additions & 431 deletions

dev/lecture_03/Plots/e51bdfbf.svg renamed to dev/lecture_03/Plots/d8ea0035.svg

Lines changed: 38 additions & 38 deletions

dev/lecture_03/Plots/a8ca44dd.svg renamed to dev/lecture_03/Plots/e5564108.svg

Lines changed: 232 additions & 232 deletions

dev/lecture_03/Plots/9244d45d.svg renamed to dev/lecture_03/Plots/eddbaf52.svg

Lines changed: 117 additions & 117 deletions

dev/lecture_03/Plots/index.html

Lines changed: 13 additions & 13 deletions
Large diffs are not rendered by default.

dev/lecture_03/dataframe.csv

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
A,B,C
2-
5,M,0.7779143593875743
3-
2,F,0.5792191895468449
4-
3,F,0.0005709130583591016
5-
4,M,0.4552062639856772
2+
5,M,0.28163543799864654
3+
2,F,0.06663919293071718
4+
3,F,0.2877039560566471
5+
4,M,0.36552725097757577

dev/lecture_03/interaction/index.html

Lines changed: 2 additions & 2 deletions
Large diffs are not rendered by default.

dev/lecture_03/otherpackages/cf86de8f.svg renamed to dev/lecture_03/otherpackages/0962606d.svg

Lines changed: 38 additions & 38 deletions

dev/lecture_03/otherpackages/feae94e6.svg renamed to dev/lecture_03/otherpackages/8d0b7eaf.svg

Lines changed: 65 additions & 65 deletions

dev/lecture_03/otherpackages/8f118df2.svg

Lines changed: 0 additions & 160 deletions
This file was deleted.

dev/lecture_03/otherpackages/2b325d62.svg renamed to dev/lecture_03/otherpackages/9958c395.svg

Lines changed: 40 additions & 40 deletions

dev/lecture_03/otherpackages/b690d6e4.svg

Lines changed: 163 additions & 0 deletions

dev/lecture_03/otherpackages/index.html

Lines changed: 12 additions & 12 deletions
Large diffs are not rendered by default.

dev/lecture_03/pkg/index.html

Lines changed: 2 additions & 2 deletions
Large diffs are not rendered by default.

dev/lecture_03/plot_exercise1.svg

Lines changed: 1004 additions & 1004 deletions

dev/lecture_03/plot_exercise2.svg

Lines changed: 6 additions & 6 deletions

dev/lecture_03/standardlibrary/index.html

Lines changed: 2 additions & 2 deletions
Large diffs are not rendered by default.

dev/lecture_04/Plots.svg

Lines changed: 32 additions & 32 deletions

dev/lecture_04/exceptions/index.html

Lines changed: 2 additions & 2 deletions
Large diffs are not rendered by default.

dev/lecture_04/exercises/index.html

Lines changed: 2 additions & 2 deletions
Large diffs are not rendered by default.

dev/lecture_04/functions/index.html

Lines changed: 2 additions & 2 deletions
Large diffs are not rendered by default.

dev/lecture_04/gauss.svg

Lines changed: 40 additions & 40 deletions

dev/lecture_04/methods/index.html

Lines changed: 2 additions & 2 deletions
Large diffs are not rendered by default.

dev/lecture_04/scope/index.html

Lines changed: 2 additions & 2 deletions
Large diffs are not rendered by default.

dev/lecture_05/compositetypes/index.html

Lines changed: 2 additions & 2 deletions
Large diffs are not rendered by default.

dev/lecture_05/currencies/index.html

Lines changed: 2 additions & 2 deletions
Large diffs are not rendered by default.

dev/lecture_06/compatibility/index.html

Lines changed: 78 additions & 0 deletions
Large diffs are not rendered by default.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.

dev/lecture_06_07/modules/index.html renamed to dev/lecture_06/modules/index.html

Lines changed: 2 additions & 2 deletions
Large diffs are not rendered by default.

dev/lecture_06/structure/index.html

Lines changed: 45 additions & 0 deletions
Large diffs are not rendered by default.

dev/lecture_06/workflow/index.html

Lines changed: 42 additions & 0 deletions
Large diffs are not rendered by default.

dev/lecture_06_07/develop/index.html

Lines changed: 0 additions & 373 deletions
This file was deleted.

dev/lecture_07/documentation/index.html

Lines changed: 116 additions & 0 deletions
Large diffs are not rendered by default.

dev/lecture_07/extensions/index.html

Lines changed: 93 additions & 0 deletions
Large diffs are not rendered by default.
File renamed without changes.
File renamed without changes.

dev/lecture_07/image_8.svg

Lines changed: 13 additions & 0 deletions

dev/lecture_07/tests/index.html

Lines changed: 119 additions & 0 deletions
Large diffs are not rendered by default.

dev/lecture_08/constrained/index.html

Lines changed: 2 additions & 2 deletions
Large diffs are not rendered by default.

dev/lecture_08/exercises/index.html

Lines changed: 2 additions & 2 deletions
Large diffs are not rendered by default.

dev/lecture_08/grad1.svg

Lines changed: 91 additions & 91 deletions

dev/lecture_08/grad2.svg

Lines changed: 28 additions & 28 deletions

dev/lecture_08/grad3.svg

Lines changed: 1291 additions & 1291 deletions

dev/lecture_08/gradients/109ed296.svg renamed to dev/lecture_08/gradients/5f66e55c.svg

Lines changed: 91 additions & 91 deletions

dev/lecture_08/gradients/index.html

Lines changed: 3 additions & 3 deletions
Large diffs are not rendered by default.

dev/lecture_08/obj.svg

Lines changed: 28 additions & 28 deletions

dev/lecture_08/theory/index.html

Lines changed: 2 additions & 2 deletions
Large diffs are not rendered by default.

dev/lecture_08/unconstrained/index.html

Lines changed: 2 additions & 2 deletions
Large diffs are not rendered by default.

dev/lecture_09/exercises/index.html

Lines changed: 2 additions & 2 deletions
Large diffs are not rendered by default.

dev/lecture_09/iris1.svg

Lines changed: 132 additions & 132 deletions

dev/lecture_09/iris2.svg

Lines changed: 132 additions & 132 deletions

dev/lecture_09/iris_lin1.svg

Lines changed: 181 additions & 181 deletions

dev/lecture_09/iris_lin2.svg

Lines changed: 188 additions & 188 deletions

dev/lecture_09/linear/index.html

Lines changed: 2 additions & 2 deletions
Large diffs are not rendered by default.

dev/lecture_09/logistic/index.html

Lines changed: 2 additions & 2 deletions
Large diffs are not rendered by default.

dev/lecture_09/sigmoid.svg

Lines changed: 28 additions & 28 deletions

dev/lecture_09/theory/index.html

Lines changed: 1 addition & 1 deletion
Large diffs are not rendered by default.

dev/lecture_10/Activation.svg

Lines changed: 111 additions & 111 deletions

dev/lecture_10/Overfit.svg

Lines changed: 44 additions & 44 deletions

dev/lecture_10/Separation.svg

Lines changed: 42 additions & 42 deletions

dev/lecture_10/Separation2.svg

Lines changed: 42 additions & 42 deletions

dev/lecture_10/Separation3.svg

Lines changed: 132 additions & 132 deletions

dev/lecture_10/Train_test.svg

Lines changed: 38 additions & 38 deletions

dev/lecture_10/exercises/index.html

Lines changed: 2 additions & 2 deletions
Large diffs are not rendered by default.

dev/lecture_10/loss.svg

Lines changed: 28 additions & 28 deletions

dev/lecture_10/nn/index.html

Lines changed: 2 additions & 2 deletions
Large diffs are not rendered by default.

dev/lecture_10/theory/index.html

Lines changed: 2 additions & 2 deletions
Large diffs are not rendered by default.

dev/lecture_11/Iris_train_test_acc.svg

Lines changed: 36 additions & 36 deletions

dev/lecture_11/data/mnist.bson

-20.2 KB
Binary file not shown.

dev/lecture_11/data/mnist.jl

Lines changed: 27 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -1,108 +1,54 @@
1-
using BSON
21
using Flux
3-
using Flux: onehotbatch, onecold
2+
using Flux: onecold
43
using MLDatasets
54

6-
Core.eval(Main, :(using Flux)) # hide
7-
8-
function reshape_data(X::AbstractArray{T, 3}, y::AbstractVector) where T
9-
s = size(X)
10-
return reshape(X, s[1], s[2], 1, s[3]), reshape(y, 1, :)
11-
end
12-
13-
function train_or_load!(file_name, m, X, y; force=false, kwargs...)
14-
15-
!isdir(dirname(file_name)) && mkpath(dirname(file_name))
16-
17-
if force || !isfile(file_name)
18-
train_model!(m, X, y; file_name=file_name, kwargs...)
19-
else
20-
m_loaded = BSON.load(file_name)[:m]
21-
Flux.loadparams!(m, params(m_loaded))
22-
end
23-
end
24-
25-
function load_data(dataset; onehot=false, T=Float32)
26-
classes = 0:9
27-
X_train, y_train = reshape_data(dataset(T, :train)[:]...)
28-
X_test, y_test = reshape_data(dataset(T, :test)[:]...)
29-
y_train = T.(y_train)
30-
y_test = T.(y_test)
31-
32-
if onehot
33-
y_train = onehotbatch(y_train[:], classes)
34-
y_test = onehotbatch(y_test[:], classes)
35-
end
36-
37-
return X_train, y_train, X_test, y_test
38-
end
39-
40-
using Plots
41-
42-
plot_image(x::AbstractArray{T, 2}) where T = plot(Gray.(x'), axis=nothing)
43-
44-
function plot_image(x::AbstractArray{T, 4}) where T
45-
@assert size(x,4) == 1
46-
plot_image(x[:,:,:,1])
47-
end
48-
49-
function plot_image(x::AbstractArray{T, 3}) where T
50-
@assert size(x,3) == 1
51-
plot_image(x[:,:,1])
52-
end
53-
5+
include(joinpath(dirname(@__FILE__), "utilities.jl"))
546

557
T = Float32
568
dataset = MLDatasets.MNIST
579

5810
X_train, y_train, X_test, y_test = load_data(dataset; T=T, onehot=true)
5911

12+
model = Chain(
13+
Conv((2, 2), 1 => 16, sigmoid),
14+
MaxPool((2, 2)),
15+
Conv((2, 2), 16 => 8, sigmoid),
16+
MaxPool((2, 2)),
17+
Flux.flatten,
18+
Dense(288, size(y_train, 1)),
19+
softmax,
20+
)
6021

61-
62-
63-
64-
m = Chain(
65-
Conv((2,2), 1=>16, sigmoid),
66-
MaxPool((2,2)),
67-
Conv((2,2), 16=>8, sigmoid),
68-
MaxPool((2,2)),
69-
flatten,
70-
Dense(288, size(y_train,1)), softmax)
71-
72-
file_name = joinpath("data", "mnist_sigmoid.bson")
73-
train_or_load!(file_name, m, X_train, y_train)
74-
75-
76-
22+
file_name = joinpath("data", "mnist_sigmoid.jld2")
23+
train_or_load!(file_name, model, X_train, y_train)
7724

7825
ii1 = findall(onecold(y_train, 0:9) .== 1)[1:5]
7926
ii2 = findall(onecold(y_train, 0:9) .== 9)[1:5]
8027

81-
8228
for qwe = 0:9
8329
ii0 = findall(onecold(y_train, 0:9) .== qwe)[1:5]
8430

85-
p0 = [plot_image(X_train[:,:,:,i:i][:,:,1,1]) for i in ii0]
86-
p1 = [plot_image((m[1:2](X_train[:,:,:,i:i]))[:,:,1,1]) for i in ii0]
87-
p2 = [plot_image((m[1:4](X_train[:,:,:,i:i]))[:,:,1,1]) for i in ii0]
31+
p0 = [plot_image(X_train[:, :, :, i:i][:, :, 1, 1]) for i in ii0]
32+
p1 = [plot_image((model[1:2](X_train[:, :, :, i:i]))[:, :, 1, 1]) for i in ii0]
33+
p2 = [plot_image((model[1:4](X_train[:, :, :, i:i]))[:, :, 1, 1]) for i in ii0]
8834

89-
p = plot(p0..., p1..., p2...; layout=(3,5))
35+
p = plot(p0..., p1..., p2...; layout=(3, 5))
9036
display(p)
9137
end
9238

93-
p0 = [plot_image(X_train[:,:,:,i:i][:,:,1,1]) for i in ii1]
94-
p1 = [plot_image((m[1:2](X_train[:,:,:,i:i]))[:,:,1,1]) for i in ii1]
95-
p2 = [plot_image((m[1:4](X_train[:,:,:,i:i]))[:,:,1,1]) for i in ii1]
39+
p0 = [plot_image(X_train[:, :, :, i:i][:, :, 1, 1]) for i in ii1]
40+
p1 = [plot_image((model[1:2](X_train[:, :, :, i:i]))[:, :, 1, 1]) for i in ii1]
41+
p2 = [plot_image((model[1:4](X_train[:, :, :, i:i]))[:, :, 1, 1]) for i in ii1]
9642

97-
plot(p0..., p1..., p2...; layout=(3,5))
43+
plot(p0..., p1..., p2...; layout=(3, 5))
9844

9945

100-
p0 = [plot_image(X_train[:,:,:,i:i][:,:,1,1]) for i in ii2]
101-
p1 = [plot_image((m[1:2](X_train[:,:,:,i:i]))[:,:,1,1]) for i in ii2]
102-
p2 = [plot_image((m[1:4](X_train[:,:,:,i:i]))[:,:,1,1]) for i in ii2]
46+
p0 = [plot_image(X_train[:, :, :, i:i][:, :, 1, 1]) for i in ii2]
47+
p1 = [plot_image((model[1:2](X_train[:, :, :, i:i]))[:, :, 1, 1]) for i in ii2]
48+
p2 = [plot_image((model[1:4](X_train[:, :, :, i:i]))[:, :, 1, 1]) for i in ii2]
10349

104-
plot(p0..., p1..., p2...; layout=(3,5))
50+
plot(p0..., p1..., p2...; layout=(3, 5))
10551

106-
for i in 1:length(m)
107-
println(size(m[1:i](X_train[:,:,:,1:1])))
52+
for i in 1:length(model)
53+
println(size(model[1:i](X_train[:, :, :, 1:1])))
10854
end

dev/lecture_11/data/mnist.jld2

25.8 KB
Binary file not shown.

dev/lecture_11/data/mnist_gpu.jl

Lines changed: 13 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -1,69 +1,23 @@
11
using MLDatasets
22
using Flux
3-
using BSON
4-
using Random
5-
using Statistics
6-
using Base.Iterators: partition
7-
using Flux: crossentropy, onehotbatch, onecold
83

9-
10-
accuracy(x, y) = mean(onecold(cpu(m(x))) .== onecold(cpu(y)))
11-
12-
function reshape_data(X::AbstractArray{T, 3}, y::AbstractVector) where T
13-
s = size(X)
14-
return reshape(X, s[1], s[2], 1, s[3]), reshape(y, 1, :)
15-
end
16-
17-
function train_model!(m, X, y;
18-
opt=ADAM(0.001),
19-
batch_size=128,
20-
n_epochs=10,
21-
file_name="")
22-
23-
loss(x, y) = crossentropy(m(x), y)
24-
25-
batches_train = map(partition(randperm(size(y, 2)), batch_size)) do inds
26-
return (gpu(X[:, :, :, inds]), gpu(y[:, inds]))
27-
end
28-
29-
for i in 1:n_epochs
30-
println("Iteration " * string(i))
31-
Flux.train!(loss, params(m), batches_train, opt)
32-
end
33-
34-
!isempty(file_name) && BSON.bson(file_name, m=m|>cpu)
35-
36-
return
37-
end
38-
39-
function load_data(dataset; onehot=false, T=Float32)
40-
classes = 0:9
41-
X_train, y_train = reshape_data(dataset(T, :train)[:]...)
42-
X_test, y_test = reshape_data(dataset(T, :test)[:]...)
43-
y_train = T.(y_train)
44-
y_test = T.(y_test)
45-
46-
if onehot
47-
y_train = onehotbatch(y_train[:], classes)
48-
y_test = onehotbatch(y_test[:], classes)
49-
end
50-
51-
return X_train, y_train, X_test, y_test
52-
end
4+
include(joinpath(dirname(@__FILE__), "utilities.jl"))
535

546
dataset = MLDatasets.MNIST
557
T = Float32
568
X_train, y_train, X_test, y_test = load_data(dataset; T=T, onehot=true)
579

58-
m = Chain(
59-
Conv((2,2), 1=>16, sigmoid),
60-
MaxPool((2,2)),
61-
Conv((2,2), 16=>8, sigmoid),
62-
MaxPool((2,2)),
63-
flatten,
64-
Dense(288, size(y_train,1)), softmax) |> gpu
10+
model = Chain(
11+
Conv((2, 2), 1 => 16, sigmoid),
12+
MaxPool((2, 2)),
13+
Conv((2, 2), 16 => 8, sigmoid),
14+
MaxPool((2, 2)),
15+
Flux.flatten,
16+
Dense(288, size(y_train, 1)),
17+
softmax,
18+
) |> gpu
6519

66-
file_name = joinpath("data", "mnist_sigmoid.bson")
67-
train_model!(m, X_train, y_train; file_name=file_name, n_epochs=100)
20+
file_name = joinpath("data", "mnist_sigmoid.jld2")
21+
train_model!(model, X_train, y_train; file_name=file_name, n_epochs=100)
6822

69-
accuracy(X_test |> gpu, y_test |> gpu)
23+
accuracy(model, X_test, y_test)
-20.1 KB
Binary file not shown.
25.8 KB
Binary file not shown.

dev/lecture_11/data/utilities.jl

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
using MLDatasets
2+
using Flux
3+
using JLD2
4+
using Random
5+
using Statistics
6+
using Base.Iterators: partition
7+
using Flux: crossentropy, onehotbatch, onecold
8+
using Plots
9+
using Pkg
10+
11+
if haskey(Pkg.project().dependencies, "CUDA")
12+
using CUDA
13+
else
14+
gpu(x) = x
15+
end
16+
17+
accuracy(model, x, y) = mean(onecold(cpu(model(x))) .== onecold(cpu(y)))
18+
19+
function reshape_data(X::AbstractArray{T,3}, y::AbstractVector) where {T}
20+
s = size(X)
21+
return reshape(X, s[1], s[2], 1, s[3]), reshape(y, 1, :)
22+
end
23+
24+
function load_data(dataset; onehot=false, T=Float32)
25+
classes = 0:9
26+
X_train, y_train = reshape_data(dataset(T, :train)[:]...)
27+
X_test, y_test = reshape_data(dataset(T, :test)[:]...)
28+
y_train = T.(y_train)
29+
y_test = T.(y_test)
30+
31+
if onehot
32+
y_train = onehotbatch(y_train[:], classes)
33+
y_test = onehotbatch(y_test[:], classes)
34+
end
35+
36+
return X_train, y_train, X_test, y_test
37+
end
38+
39+
function train_model!(
40+
model,
41+
X,
42+
y;
43+
opt=Adam(0.001),
44+
batch_size=128,
45+
n_epochs=10,
46+
file_name="",
47+
)
48+
49+
loss(x, y) = crossentropy(model(x), y)
50+
51+
batches_train = map(partition(randperm(size(y, 2)), batch_size)) do inds
52+
return (gpu(X[:, :, :, inds]), gpu(y[:, inds]))
53+
end
54+
55+
for epoch in 1:n_epochs
56+
@show epoch
57+
Flux.train!(loss, Flux.params(model), batches_train, opt)
58+
end
59+
60+
!isempty(file_name) && jldsave(file_name; model_state=Flux.state(model) |> cpu)
61+
62+
return
63+
end
64+
65+
function train_or_load!(file_name, model, args...; force=false, kwargs...)
66+
67+
!isdir(dirname(file_name)) && mkpath(dirname(file_name))
68+
69+
if force || !isfile(file_name)
70+
train_model!(model, args...; file_name=file_name, kwargs...)
71+
else
72+
model_state = JLD2.load(file_name, "model_state")
73+
Flux.loadmodel!(model, model_state)
74+
end
75+
end
76+
77+
plot_image(x::AbstractArray{T,2}) where {T} = plot(Gray.(x'), axis=nothing)
78+
79+
function plot_image(x::AbstractArray{T,4}) where {T}
80+
@assert size(x, 4) == 1
81+
plot_image(x[:, :, :, 1])
82+
end
83+
84+
function plot_image(x::AbstractArray{T,3}) where {T}
85+
@assert size(x, 3) == 1
86+
plot_image(x[:, :, 1])
87+
end

0 commit comments

Comments
 (0)