Commit fafb8b9

🚀 Add simple-optimizers example (#2)

1 parent c43f861 · commit fafb8b9

File tree

11 files changed: +396 −112 lines changed


Justfile

Lines changed: 5 additions & 0 deletions

```diff
@@ -53,3 +53,8 @@ fmt:
 fmt-check:
     find crates -type f \( -name '*.cpp' -o -name '*.hpp' \) -exec clang-format --dry-run --Werror {} +
 
+# -------------------------
+# Book
+# -------------------------
+serve-book:
+    mdbook serve book
```

book/src/SUMMARY.md

Lines changed: 1 addition & 3 deletions

```diff
@@ -1,5 +1,3 @@
 - [Introduction](introduction.md)
 - [Kalman filter](kf_linear.md)
-  - [Class definition]()
-  - [Class implementation]()
-  - [Python bindings]()
+- [Simple optimizers](simple_optimizers.md)
```

book/src/simple_optimizers.md

Lines changed: 116 additions & 0 deletions

# Optimizers

This chapter documents the small optimization module used in the project: a minimal runtime‑polymorphic interface `Optimizer` with two concrete implementations, Gradient Descent and Momentum. It is designed for clarity and for easy swapping of algorithms in training loops.

## Problem setting

Given parameters $\mathbf{w}\in\mathbb{R}^d$ and a loss $\mathcal{L}(\mathbf{w})$, an optimizer updates the weights using the gradient

$$
\mathbf{g}_t=\nabla_{\mathbf{w}}\mathcal{L}(\mathbf{w}_t).
$$

Each algorithm defines an update rule $\mathbf{w}_{t+1} = \Phi(\mathbf{w}_t,\mathbf{g}_t,\theta)$ with hyper‑parameters $\theta$ (e.g., learning rate, momentum).
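
As a concrete illustration (not taken from the module), consider the quadratic loss $\mathcal{L}(\mathbf{w})=\tfrac{1}{2}\lVert\mathbf{w}\rVert^2$, for which $\mathbf{g}_t=\mathbf{w}_t$. Plain gradient descent with learning rate $\eta$ then contracts the weights geometrically:

$$
\mathbf{w}_{t+1}=\mathbf{w}_t-\eta\,\mathbf{w}_t=(1-\eta)\,\mathbf{w}_t .
$$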

## API overview

<details>
<summary>Click here to view the full implementation: <b>include/cppx/opt/optimizers.hpp</b>. We break it down in the rest of this section.</summary>

```cpp
{{#include ../../crates/simple_optimizers/include/optimizers.hpp}}
```
</details>

Design choices

- A small virtual interface to enable swapping algorithms at runtime.
- `std::unique_ptr<Optimizer>` for owning polymorphism; borrowing functions accept `Optimizer&`.
- Exceptions (`std::invalid_argument`) signal size mismatches.
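
Given those choices, the interface itself stays tiny. As a rough sketch (the authoritative version is the header included above):

```cpp
#include <vector>

namespace cppx::opt {

// Abstract base class: a single virtual step() shared by all algorithms.
class Optimizer {
public:
  virtual ~Optimizer() = default;

  // Update w in place from the gradient g.
  // Implementations throw std::invalid_argument on size mismatches.
  virtual void step(std::vector<double>& w, const std::vector<double>& g) = 0;
};

} // namespace cppx::opt
```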
29+
30+
31+
## Gradient descent
32+
33+
Update rule
34+
$$
35+
\mathbf{w}_{t+1}=\mathbf{w}_{t}-\eta\,\mathbf{g}_t ,
36+
$$
37+
with learning rate $\eta>0$.
38+
39+
Implementation
40+
```cpp
41+
void GradientDescent::step(std::vector<double>& w,
42+
const std::vector<double>& g) {
43+
if (w.size() != g.size()) throw std::invalid_argument("size mismatch");
44+
for (std::size_t i = 0; i < w.size(); ++i) {
45+
w[i] -= lr_ * g[i];
46+
}
47+
}
48+
```
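
As a quick sanity check (illustrative values; a single-argument constructor taking the learning rate is assumed from the header):

```cpp
GradientDescent opt(/*lr=*/0.1); // assumed constructor signature
std::vector<double> w{1.0}, g{2.0};
opt.step(w, g); // w[0] == 1.0 - 0.1 * 2.0 == 0.8
```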

## Momentum-based gradient descent

Update rule

$$
\begin{aligned}
\mathbf{v}_{t+1} &= \mu\,\mathbf{v}_{t} + \eta\,\mathbf{g}_t, \\\\
\mathbf{w}_{t+1} &= \mathbf{w}_{t} - \mathbf{v}_{t+1},
\end{aligned}
$$

with momentum $\mu\in[0,1)$ and learning rate $\eta>0$.

Implementation

```cpp
Momentum::Momentum(double learning_rate, double momentum, std::size_t dim)
    : lr_(learning_rate), mu_(momentum), v_(dim, 0.0) {}

void Momentum::step(std::vector<double>& w, const std::vector<double>& g) {
  if (w.size() != g.size()) throw std::invalid_argument("size mismatch");
  if (v_.size() != w.size()) throw std::invalid_argument("velocity size mismatch");

  for (std::size_t i = 0; i < w.size(); ++i) {
    v_[i] = mu_ * v_[i] + lr_ * g[i];
    w[i] -= v_[i];
  }
}
```
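
Note that the learning rate is folded into the velocity update here ($\mathbf{v}_{t+1}=\mu\,\mathbf{v}_t+\eta\,\mathbf{g}_t$). An equally common variant accumulates raw gradients, $\mathbf{v}_{t+1}=\mu\,\mathbf{v}_t+\mathbf{g}_t$, and applies $\eta$ only at the weight update; the two coincide for a constant $\eta$ but differ when $\eta$ changes during training.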

## Using the optimizers

### Owning an optimizer (runtime polymorphism)

```cpp
#include <memory>
#include <vector>

#include "cppx/opt/optimizers.hpp"

using namespace cppx::opt;

const std::size_t d = 10; // parameter dimension (illustrative)
std::vector<double> w(d, 0.0), g(d, 0.0);

// Choose an algorithm at runtime:
std::unique_ptr<Optimizer> opt =
    std::make_unique<Momentum>(/*lr=*/0.1, /*mu=*/0.9, /*dim=*/w.size());

for (int epoch = 0; epoch < 100; ++epoch) {
  // ... compute gradients into g ...
  opt->step(w, g); // updates w in place
}
```
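
Swapping algorithms is then a one-line change at the `std::make_unique` call; nothing else in the training loop needs to know which update rule runs.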

### Borrowing an optimizer (no ownership transfer)

```cpp
void train_one_epoch(Optimizer& opt,
                     std::vector<double>& w,
                     std::vector<double>& g) {
  // ... fill g ...
  opt.step(w, g);
}
```
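
The two styles compose: the owned optimizer from the previous snippet can be lent to such a function by reference.

```cpp
// Hypothetical call site, reusing opt, w, g from the owning example above.
train_one_epoch(*opt, w, g);
```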

### API variations (optional)

If C++20 is available, `std::span` can make the interface container‑agnostic:

```cpp
// virtual void step(std::span<double> w, std::span<const double> g) = 0;
```
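
A fuller sketch of that variant (hypothetical, not part of the module) might look like:

```cpp
#include <cstddef>
#include <span>
#include <stdexcept>

// Span-based interface: works with std::vector, std::array, or raw
// buffers without copying.
class SpanOptimizer {
public:
  virtual ~SpanOptimizer() = default;
  virtual void step(std::span<double> w, std::span<const double> g) = 0;
};

class SpanGradientDescent final : public SpanOptimizer {
public:
  explicit SpanGradientDescent(double lr) : lr_(lr) {}

  void step(std::span<double> w, std::span<const double> g) override {
    if (w.size() != g.size()) throw std::invalid_argument("size mismatch");
    for (std::size_t i = 0; i < w.size(); ++i) w[i] -= lr_ * g[i];
  }

private:
  double lr_;
};
```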

crates/CMakeLists.txt

Lines changed: 2 additions & 1 deletion

```diff
@@ -1 +1,2 @@
-add_subdirectory("kf_linear")
+add_subdirectory("kf_linear")
+add_subdirectory("simple_optimizers")
```
Lines changed: 45 additions & 55 deletions

```diff
@@ -1,10 +1,10 @@
 #pragma once
 
 #include <Eigen/Dense>
+#include <iostream>
 #include <optional>
-#include <vector>
 #include <stdexcept>
-#include <iostream>
+#include <vector>
 
 /**
  * @brief Generic linear Kalman filter (templated, no control term).
@@ -17,43 +17,34 @@
  * Nx = state dimension (int or Eigen::Dynamic)
  * Ny = measurement dimension(int or Eigen::Dynamic)
  */
-template<int Nx, int Ny>
-class KFLinear {
-public:
+template <int Nx, int Ny> class KFLinear {
+public:
   using StateVec = Eigen::Matrix<double, Nx, 1>;
   using StateMat = Eigen::Matrix<double, Nx, Nx>;
-  using MeasVec = Eigen::Matrix<double, Ny, 1>;
-  using MeasMat = Eigen::Matrix<double, Ny, Ny>;
-  using ObsMat = Eigen::Matrix<double, Ny, Nx>;
+  using MeasVec = Eigen::Matrix<double, Ny, 1>;
+  using MeasMat = Eigen::Matrix<double, Ny, Ny>;
+  using ObsMat = Eigen::Matrix<double, Ny, Nx>;
 
   /// Construct filter with initial condition and model matrices.
-  KFLinear(const StateVec& initial_state,
-           const StateMat& initial_covariance,
-           const StateMat& transition_matrix,
-           const ObsMat& observation_matrix,
-           const StateMat& process_covariance,
-           const MeasMat& measurement_covariance)
-      : x_(initial_state),
-        P_(initial_covariance),
-        A_(transition_matrix),
-        H_(observation_matrix),
-        Q_(process_covariance),
-        R_(measurement_covariance)
-  {
-    std::cout << "DEBUG" << std::endl;
-    const auto n = x_.rows();
-    if (A_.rows() != n || A_.cols() != n)
-      throw std::invalid_argument("A must be n×n and match x dimension");
-    if (P_.rows() != n || P_.cols() != n)
-      throw std::invalid_argument("P must be n×n");
-    if (Q_.rows() != n || Q_.cols() != n)
-      throw std::invalid_argument("Q must be n×n");
-    if (H_.cols() != n)
-      throw std::invalid_argument("H must have n columns");
-    const auto m = H_.rows();
-    if (R_.rows() != m || R_.cols() != m)
-      throw std::invalid_argument("R must be m×m with m = H.rows()");
-  }
+  KFLinear(const StateVec &initial_state, const StateMat &initial_covariance,
+           const StateMat &transition_matrix, const ObsMat &observation_matrix,
+           const StateMat &process_covariance, const MeasMat &measurement_covariance)
+      : x_(initial_state), P_(initial_covariance), A_(transition_matrix), H_(observation_matrix),
+        Q_(process_covariance), R_(measurement_covariance) {
+    std::cout << "DEBUG" << std::endl;
+    const auto n = x_.rows();
+    if (A_.rows() != n || A_.cols() != n)
+      throw std::invalid_argument("A must be n×n and match x dimension");
+    if (P_.rows() != n || P_.cols() != n)
+      throw std::invalid_argument("P must be n×n");
+    if (Q_.rows() != n || Q_.cols() != n)
+      throw std::invalid_argument("Q must be n×n");
+    if (H_.cols() != n)
+      throw std::invalid_argument("H must have n columns");
+    const auto m = H_.rows();
+    if (R_.rows() != m || R_.cols() != m)
+      throw std::invalid_argument("R must be m×m with m = H.rows()");
+  }
 
   /// Predict step (no control).
   void predict() {
@@ -62,7 +53,7 @@ class KFLinear {
   }
 
   /// Update step with a measurement z.
-  void update(const MeasVec& z) {
+  void update(const MeasVec &z) {
     // Innovation
     MeasVec nu = z - H_ * x_;
 
@@ -76,9 +67,8 @@ class KFLinear {
     }
 
     // K = P H^T S^{-1} via solve: S * (K^T) = (P H^T)^T
-    const auto PHt = P_ * H_.transpose(); // (Nx × Ny)
-    Eigen::Matrix<double, Nx, Ny> K =
-        ldlt.solve(PHt.transpose()).transpose(); // (Nx × Ny)
+    const auto PHt = P_ * H_.transpose();                                      // (Nx × Ny)
+    Eigen::Matrix<double, Nx, Ny> K = ldlt.solve(PHt.transpose()).transpose(); // (Nx × Ny)
 
     // State update
     x_ += K * nu;
@@ -92,42 +82,42 @@ class KFLinear {
   }
 
   /// One full step: predict then (optionally) update.
-  void step(const std::optional<MeasVec>& measurement) {
+  void step(const std::optional<MeasVec> &measurement) {
     predict();
     if (measurement) {
       update(*measurement);
     }
   }
 
   /// Run over a sequence of (optional) measurements.
-  std::vector<StateVec> filter(const std::vector<std::optional<MeasVec>>& measurements) {
+  std::vector<StateVec> filter(const std::vector<std::optional<MeasVec>> &measurements) {
     std::vector<StateVec> out;
     out.reserve(measurements.size());
-    for (const auto& z : measurements) {
+    for (const auto &z : measurements) {
       step(z);
       out.push_back(x_);
     }
     return out;
   }
 
   // Accessors
-  [[nodiscard]] const StateVec& state() const { return x_; }
-  [[nodiscard]] const StateMat& covariance() const { return P_; }
+  [[nodiscard]] const StateVec &state() const { return x_; }
+  [[nodiscard]] const StateMat &covariance() const { return P_; }
 
   // (Optional) setters if you want to tweak model online
-  void set_transition(const StateMat& A) { A_ = A; }
-  void set_observation(const ObsMat& H) { H_ = H; }
-  void set_process_noise(const StateMat& Q) { Q_ = Q; }
-  void set_measurement_noise(const MeasMat& R) { R_ = R; }
+  void set_transition(const StateMat &A) { A_ = A; }
+  void set_observation(const ObsMat &H) { H_ = H; }
+  void set_process_noise(const StateMat &Q) { Q_ = Q; }
+  void set_measurement_noise(const MeasMat &R) { R_ = R; }
 
-private:
+private:
   // Model
-  StateMat A_; ///< State transition
-  ObsMat H_; ///< Observation
-  StateMat Q_; ///< Process noise covariance
-  MeasMat R_; ///< Measurement noise covariance
+  StateMat A_; ///< State transition
+  ObsMat H_;   ///< Observation
+  StateMat Q_; ///< Process noise covariance
+  MeasMat R_;  ///< Measurement noise covariance
 
   // Estimates
-  StateVec x_; ///< State mean
-  StateMat P_; ///< State covariance
+  StateVec x_; ///< State mean
+  StateMat P_; ///< State covariance
 };
```
