
Commit 281f8c9

Remove usage of global variables in linear and logistic regression tutorial training functions (#2537)
1 parent f5d25e5 commit 281f8c9
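
In both tutorials the change follows one pattern: the training step used to read and mutate global variables, and now receives everything it touches as arguments, with a `!` suffix following the Julia convention for mutating functions. Below is a minimal, self-contained sketch of that before/after; the toy data and the names `train_before!`/`train_after!` are illustrative, not the tutorials' exact doctest code.

```julia
using Flux

# Toy parameters and data, only so both variants below are runnable.
x, y = rand(Float32, 1, 10), rand(Float32, 1, 10)
W, b = rand(Float32, 1, 1), [0.0f0]
model(W, b, x) = W * x .+ b
loss(W, b, x, y) = Flux.mse(model(W, b, x), y)

# Before this commit: the training step reached for the globals W, b, x, y.
function train_before!()
    dLdW, dLdb, _, _ = Flux.gradient(loss, W, b, x, y)
    @. W = W - 0.1 * dLdW
    @. b = b - 0.1 * dLdb
end

# After: everything the step reads or mutates is passed in explicitly,
# so the function no longer depends on global state and is reusable.
function train_after!(f_loss, weights, biases, features, labels)
    dLdW, dLdb, _, _ = Flux.gradient(f_loss, weights, biases, features, labels)
    @. weights = weights - 0.1 * dLdW
    @. biases = biases - 0.1 * dLdb
end

train_before!()                  # only works with these particular globals
train_after!(loss, W, b, x, y)   # works with any loss, parameters and data
```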

File tree (3 files changed, +48 -48 lines changed):

docs/src/tutorials/linear_regression.md
docs/src/tutorials/logistic_regression.md
src/gradient.jl

docs/src/tutorials/linear_regression.md

Lines changed: 23 additions & 23 deletions
@@ -6,7 +6,7 @@ Flux is a pure Julia ML stack that allows you to build predictive models. Here a
 - Build a model with configurable parameters to make predictions
 - Iteratively train the model by tweaking the parameters to improve predictions
 - Verify your model
-
+
 Under the hood, Flux uses a technique called automatic differentiation to take gradients that help improve predictions. Flux is also fully written in Julia so you can easily replace any layer of Flux with your own code to improve your understanding or satisfy special requirements.
 
 The following page contains a step-by-step walkthrough of the linear regression algorithm in `Julia` using `Flux`! We will start by creating a simple linear regression model for dummy data and then move on to a real dataset. The first part would involve writing some parts of the model on our own, which will later be replaced by `Flux`.
@@ -104,9 +104,9 @@ julia> custom_model(W, b, x)[1], y[1]
 It does! But the predictions are way off. We need to train the model to improve the predictions, but before training the model we need to define the loss function. The loss function would ideally output a quantity that we will try to minimize during the entire training process. Here we will use the mean sum squared error loss function.
 
 ```jldoctest linear_regression_simple; filter = r"[+-]?([0-9]*[.])?[0-9]+(f[+-]*[0-9])?"
-julia> function custom_loss(W, b, x, y)
-           ŷ = custom_model(W, b, x)
-           sum((y .- ŷ).^2) / length(x)
+julia> function custom_loss(weights, biases, features, labels)
+           ŷ = custom_model(weights, biases, features)
+           sum((labels .- ŷ).^2) / length(weights)
        end;
 
 julia> custom_loss(W, b, x, y)
@@ -115,7 +115,7 @@ julia> custom_loss(W, b, x, y)
 
 Calling the loss function on our `x`s and `y`s shows how far our predictions (`ŷ`) are from the real labels. More precisely, it calculates the sum of the squares of residuals and divides it by the total number of data points.
 
-We have successfully defined our model and the loss function, but surprisingly, we haven't used `Flux` anywhere till now. Let's see how we can write the same code using `Flux`.
+We have successfully defined our model and the loss function, but surprisingly, we haven't used `Flux` anywhere till now. Let's see how we can write the same code using `Flux`.
 
 ```jldoctest linear_regression_simple
 julia> flux_model = Dense(1 => 1)
@@ -142,9 +142,9 @@ julia> flux_model(x)[1], y[1]
 It is! The next step would be defining the loss function using `Flux`'s functions -
 
 ```jldoctest linear_regression_simple; filter = r"[+-]?([0-9]*[.])?[0-9]+(f[+-]*[0-9])?"
-julia> function flux_loss(flux_model, x, y)
-           ŷ = flux_model(x)
-           Flux.mse(ŷ, y)
+julia> function flux_loss(flux_model, features, labels)
+           ŷ = flux_model(features)
+           Flux.mse(ŷ, labels)
        end;
 
 julia> flux_loss(flux_model, x, y)
@@ -214,13 +214,13 @@ The loss went down! This means that we successfully trained our model for one ep
 Let's plug our super training logic inside a function and test it again -
 
 ```jldoctest linear_regression_simple; filter = r"[+-]?([0-9]*[.])?[0-9]+(f[+-]*[0-9])?"
-julia> function train_custom_model()
-           dLdW, dLdb, _, _ = gradient(custom_loss, W, b, x, y)
-           @. W = W - 0.1 * dLdW
-           @. b = b - 0.1 * dLdb
+julia> function train_custom_model!(f_loss, weights, biases, features, labels)
+           dLdW, dLdb, _, _ = gradient(f_loss, weights, biases, features, labels)
+           @. weights = weights - 0.1 * dLdW
+           @. biases = biases - 0.1 * dLdb
        end;
 
-julia> train_custom_model();
+julia> train_custom_model!(custom_loss, W, b, x, y);
 
 julia> W, b, custom_loss(W, b, x, y)
 (Float32[2.340657], Float32[0.7516814], 13.64972f0)
@@ -230,7 +230,7 @@ It works, and the loss went down again! This was the second epoch of our trainin
 
 ```jldoctest linear_regression_simple; filter = r"[+-]?([0-9]*[.])?[0-9]+(f[+-]*[0-9])?"
 julia> for i = 1:40
-           train_custom_model()
+           train_custom_model!(custom_loss, W, b, x, y)
        end
 
 julia> W, b, custom_loss(W, b, x, y)
@@ -266,7 +266,7 @@ julia> using Flux, Statistics, MLDatasets, DataFrames
 ```
 
 ## Gathering real data
-Let's start by initializing our dataset. We will be using the [`BostonHousing`](https://juliaml.github.io/MLDatasets.jl/stable/datasets/misc/#MLDatasets.BostonHousing) dataset consisting of `506` data points. Each of these data points has `13` features and a corresponding label, the house's price. The `x`s are still mapped to a single `y`, but now, a single `x` data point has 13 features.
+Let's start by initializing our dataset. We will be using the [`BostonHousing`](https://juliaml.github.io/MLDatasets.jl/stable/datasets/misc/#MLDatasets.BostonHousing) dataset consisting of `506` data points. Each of these data points has `13` features and a corresponding label, the house's price. The `x`s are still mapped to a single `y`, but now, a single `x` data point has 13 features.
 
 ```jldoctest linear_regression_complex
 julia> dataset = BostonHousing();
@@ -314,9 +314,9 @@ Dense(13 => 1) # 14 parameters
 Same as before, our next step would be to define a loss function to quantify our accuracy somehow. The lower the loss, the better the model!
 
 ```jldoctest linear_regression_complex; filter = r"[+-]?([0-9]*[.])?[0-9]+(f[+-]*[0-9])?"
-julia> function loss(model, x, y)
-           ŷ = model(x)
-           Flux.mse(ŷ, y)
+julia> function loss(model, features, labels)
+           ŷ = model(features)
+           Flux.mse(ŷ, labels)
        end;
 
 julia> loss(model, x_train_n, y_train)
@@ -330,8 +330,8 @@ We can now proceed to the training phase!
 The training procedure would make use of the same mathematics, but now we can pass in the model inside the `gradient` call and let `Flux` and `Zygote` handle the derivatives!
 
 ```jldoctest linear_regression_complex
-julia> function train_model()
-           dLdm, _, _ = gradient(loss, model, x_train_n, y_train)
+julia> function train_model!(f_loss, model, features, labels)
+           dLdm, _, _ = gradient(f_loss, model, features, labels)
            @. model.weight = model.weight - 0.000001 * dLdm.weight
            @. model.bias = model.bias - 0.000001 * dLdm.bias
        end;
@@ -344,7 +344,7 @@ We can write such custom training loops effortlessly using `Flux` and plain `Jul
 julia> loss_init = Inf;
 
 julia> while true
-           train_model()
+           train_model!(loss, model, x_train_n, y_train)
            if loss_init == Inf
                loss_init = loss(model, x_train_n, y_train)
                continue
@@ -385,9 +385,9 @@ The loss is not as small as the loss of the training data, but it looks good! Th
 
 ---
 
-Summarising this tutorial, we started by generating a random yet correlated dataset for our `custom model`. We then saw how a simple linear regression model could be built with and without `Flux`, and how they were almost identical.
+Summarising this tutorial, we started by generating a random yet correlated dataset for our `custom model`. We then saw how a simple linear regression model could be built with and without `Flux`, and how they were almost identical.
 
-Next, we trained the model by manually writing down the Gradient Descent algorithm and optimising the loss. We also saw how `Flux` provides various wrapper functionalities and keeps the API extremely intuitive and simple for the users.
+Next, we trained the model by manually writing down the Gradient Descent algorithm and optimising the loss. We also saw how `Flux` provides various wrapper functionalities and keeps the API extremely intuitive and simple for the users.
 
 After getting familiar with the basics of `Flux` and `Julia`, we moved ahead to build a machine learning model for a real dataset. We repeated the exact same steps, but this time with a lot more features and data points, and by harnessing `Flux`'s full capabilities. In the end, we developed a training loop that was smarter than the hardcoded one and ran the model on our normalised dataset to conclude the tutorial.
 
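
For reference, the renamed pieces from this file assemble into the following self-contained sketch of the post-commit training loop. The random stand-in data and its shapes are assumptions made for illustration; the tutorial itself uses the normalised BostonHousing features and doctest-checked values.

```julia
using Flux

# Stand-in data shaped like the tutorial's normalised BostonHousing split:
# 13 features per data point, one price label each (values here are random).
x_train_n = rand(Float32, 13, 354)
y_train = rand(Float32, 1, 354)

model = Dense(13 => 1)

# Same signature as the refactored loss: the model and the data are arguments.
function loss(model, features, labels)
    ŷ = model(features)
    Flux.mse(ŷ, labels)
end

# The training step now receives the loss function, the model and the data,
# instead of closing over globals as the pre-commit train_model() did.
function train_model!(f_loss, model, features, labels)
    dLdm, _, _ = Flux.gradient(f_loss, model, features, labels)
    @. model.weight = model.weight - 0.000001 * dLdm.weight
    @. model.bias = model.bias - 0.000001 * dLdm.bias
end

for _ in 1:100
    train_model!(loss, model, x_train_n, y_train)
end
```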

docs/src/tutorials/logistic_regression.md

Lines changed: 23 additions & 23 deletions
@@ -1,6 +1,6 @@
 # Logistic Regression
 
-The following page contains a step-by-step walkthrough of the logistic regression algorithm in Julia using Flux. We will then create a simple logistic regression model without any usage of Flux and compare the different working parts with Flux's implementation.
+The following page contains a step-by-step walkthrough of the logistic regression algorithm in Julia using Flux. We will then create a simple logistic regression model without any usage of Flux and compare the different working parts with Flux's implementation.
 
 Let's start by importing the required Julia packages.
 
@@ -9,7 +9,7 @@ julia> using Flux, Statistics, MLDatasets, DataFrames, OneHotArrays
 ```
 
 ## Dataset
-Let's start by importing a dataset from MLDatasets.jl. We will use the `Iris` dataset that contains the data of three different `Iris` species. The data consists of 150 data points (`x`s), each having four features. Each of these `x` is mapped to `y`, the name of a particular `Iris` specie. The following code will download the `Iris` dataset when run for the first time.
+Let's start by importing a dataset from MLDatasets.jl. We will use the `Iris` dataset that contains the data of three different `Iris` species. The data consists of 150 data points (`x`s), each having four features. Each of these `x` is mapped to a label (or target) `y`, the name of a particular `Iris` species. The following code will download the `Iris` dataset when run for the first time.
 
 ```jldoctest logistic_regression
 julia> Iris()
@@ -141,7 +141,7 @@ julia> flux_model = Chain(Dense(4 => 3), softmax)
 Chain(
   Dense(4 => 3),                        # 15 parameters
   softmax,
-)
+)
 ```
 
 A [`Dense(4 => 3)`](@ref Dense) layer denotes a layer with four inputs (four features in every data point) and three outputs (three classes or labels). This layer is the same as the mathematical model defined by us above. Under the hood, Flux too calculates the output using the same expression, but we don't have to initialize the parameters ourselves this time, instead Flux does it for us.
@@ -170,9 +170,9 @@ julia> custom_logitcrossentropy(ŷ, y) = mean(.-sum(y .* logsoftmax(ŷ; dims = 1
 Now we can wrap the `custom_logitcrossentropy` inside a function that takes in the model parameters, `x`s, and `y`s, and returns the loss value.
 
 ```jldoctest logistic_regression; filter = r"[+-]?([0-9]*[.])?[0-9]+(f[+-]*[0-9])?"
-julia> function custom_loss(W, b, x, y)
-           ŷ = custom_model(W, b, x)
-           custom_logitcrossentropy(ŷ, y)
+julia> function custom_loss(weights, biases, features, labels_onehot)
+           ŷ = custom_model(weights, biases, features)
+           custom_logitcrossentropy(ŷ, labels_onehot)
        end;
 
 julia> custom_loss(W, b, x, custom_y_onehot)
@@ -184,9 +184,9 @@ The loss function works!
 Flux provides us with many minimal yet elegant loss functions. In fact, the `custom_logitcrossentropy` defined above has been taken directly from Flux. The functions present in Flux includes sanity checks, ensures efficient performance, and behaves well with the overall FluxML ecosystem.
 
 ```jldoctest logistic_regression; filter = r"[+-]?([0-9]*[.])?[0-9]+(f[+-]*[0-9])?"
-julia> function flux_loss(flux_model, x, y)
-           ŷ = flux_model(x)
-           Flux.logitcrossentropy(ŷ, y)
+julia> function flux_loss(flux_model, features, labels_onehot)
+           ŷ = flux_model(features)
+           Flux.logitcrossentropy(ŷ, labels_onehot)
        end;
 
 julia> flux_loss(flux_model, x, flux_y_onehot)
@@ -214,9 +214,9 @@ julia> max_idx = [x[1] for x in argmax(custom_y_onehot; dims=1)]
 Now we can write a function that calculates the indices of the maximum element in each column, and maps them to a class name.
 
 ```jldoctest logistic_regression
-julia> function custom_onecold(custom_y_onehot)
-           max_idx = [x[1] for x in argmax(custom_y_onehot; dims=1)]
-           vec(classes[max_idx])
+julia> function custom_onecold(labels_onehot)
+           max_idx = [x[1] for x in argmax(labels_onehot; dims=1)]
+           return vec(classes[max_idx])
        end;
 
 julia> custom_onecold(custom_y_onehot)
@@ -313,21 +313,21 @@ julia> custom_loss(W, b, x, custom_y_onehot)
 The loss went down! Let's plug our super training logic inside a function.
 
 ```jldoctest logistic_regression
-julia> function train_custom_model()
-           dLdW, dLdb, _, _ = gradient(custom_loss, W, b, x, custom_y_onehot)
-           W .= W .- 0.1 .* dLdW
-           b .= b .- 0.1 .* dLdb
+julia> function train_custom_model!(f_loss, weights, biases, features, labels_onehot)
+           dLdW, dLdb, _, _ = gradient(f_loss, weights, biases, features, labels_onehot)
+           weights .= weights .- 0.1 .* dLdW
+           biases .= biases .- 0.1 .* dLdb
        end;
 ```
 
 We can plug the training function inside a loop and train the model for more epochs. The loop can be tailored to suit the user's needs, and the conditions can be specified in plain Julia. Here we will train the model for a maximum of `500` epochs, but to ensure that the model does not overfit, we will break as soon as our accuracy value crosses or becomes equal to `0.98`.
 
 ```jldoctest logistic_regression; filter = r"[+-]?([0-9]*[.])?[0-9]+(f[+-]*[0-9])?"
 julia> for i = 1:500
-           train_custom_model();
+           train_custom_model!(custom_loss, W, b, x, custom_y_onehot);
            custom_accuracy(W, b, x, y) >= 0.98 && break
        end
-
+
 julia> @show custom_accuracy(W, b, x, y);
 custom_accuracy(W, b, x, y) = 0.98
 ```
@@ -347,14 +347,14 @@ We can write a similar-looking training loop for our `flux_model` and train it s
 julia> flux_loss(flux_model, x, flux_y_onehot)
 1.215731131385928
 
-julia> function train_flux_model()
-           dLdm, _, _ = gradient(flux_loss, flux_model, x, flux_y_onehot)
-           @. flux_model[1].weight = flux_model[1].weight - 0.1 * dLdm[:layers][1][:weight]
-           @. flux_model[1].bias = flux_model[1].bias - 0.1 * dLdm[:layers][1][:bias]
+julia> function train_flux_model!(f_loss, model, features, labels_onehot)
+           dLdm, _, _ = gradient(f_loss, model, features, labels_onehot)
+           @. model[1].weight = model[1].weight - 0.1 * dLdm[:layers][1][:weight]
+           @. model[1].bias = model[1].bias - 0.1 * dLdm[:layers][1][:bias]
        end;
 
 julia> for i = 1:500
-           train_flux_model();
+           train_flux_model!(flux_loss, flux_model, x, flux_y_onehot);
            flux_accuracy(x, y) >= 0.98 && break
        end
 ```
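
Likewise, the logistic-regression pieces above can be assembled into a runnable sketch. The random Iris-like data and the way the one-hot matrix is built here are illustrative assumptions; the tutorial loads the real `Iris` data from MLDatasets and constructs `custom_y_onehot` differently.

```julia
using Flux

# Iris-like stand-ins: 4 features, 150 points, 3 classes (random, not the real data).
x = rand(Float32, 4, 150)
targets = rand(1:3, 150)
y_onehot = Float32.(1:3 .== permutedims(targets))  # 3×150 one-hot matrix

W = rand(Float32, 3, 4)
b = zeros(Float32, 3)

custom_model(W, b, x) = W * x .+ b

# Cross-entropy on the raw scores, standing in for custom_logitcrossentropy.
custom_loss(weights, biases, features, labels_onehot) =
    Flux.logitcrossentropy(custom_model(weights, biases, features), labels_onehot)

# Post-commit pattern: the mutated parameters and the data are all arguments.
function train_custom_model!(f_loss, weights, biases, features, labels_onehot)
    dLdW, dLdb, _, _ = Flux.gradient(f_loss, weights, biases, features, labels_onehot)
    weights .= weights .- 0.1 .* dLdW
    biases .= biases .- 0.1 .* dLdb
end

for _ in 1:500
    train_custom_model!(custom_loss, W, b, x, y_onehot)
end
```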

src/gradient.jl

Lines changed: 2 additions & 2 deletions
@@ -37,7 +37,7 @@ function gradient(f, args...; zero::Bool=true)
     end
     if Zygote.isderiving()
         error("""`Flux.gradient` does not support use within a Zygote gradient.
-              If what you are doing worked on Flux < 0.14, then calling `Zygote.gradiet` directly should still work.
+              If what you are doing worked on Flux < 0.14, then calling `Zygote.gradient` directly should still work.
               If you are writing new code, then Zygote over Zygote is heavily discouraged.
               """)
     end
@@ -175,7 +175,7 @@ function withgradient(f, args...; zero::Bool=true)
     end
     if Zygote.isderiving()
         error("""`Flux.withgradient` does not support use within a Zygote gradient.
-              If what you are doing worked on Flux < 0.14, then calling `Zygote.gradiet` directly should still work.
+              If what you are doing worked on Flux < 0.14, then calling `Zygote.withgradient` directly should still work.
               If you are writing new code, then Zygote over Zygote is heavily discouraged.
               """)
     end
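
The two `src/gradient.jl` hunks only correct the `Zygote.gradiet` typo in these error messages, pointing to `Zygote.gradient` and `Zygote.withgradient` respectively. For context, a hedged sketch of the situation that reaches this error path, namely calling `Flux.gradient` while Zygote is already differentiating:

```julia
using Flux

f(x) = sum(abs2, x)

# Top-level use is fine:
Flux.gradient(f, [1.0, 2.0])[1]  # [2.0, 4.0]

# Nesting Flux.gradient inside another Zygote differentiation is what triggers
# the error whose message this commit corrects; left commented out on purpose.
# Zygote.gradient(x -> sum(Flux.gradient(f, x)[1]), [1.0, 2.0])
```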
