From ada7eda092e55e5790073ee6a63fece48c687105 Mon Sep 17 00:00:00 2001 From: Ranjodh Singh Date: Sun, 19 Oct 2025 10:05:40 +0530 Subject: [PATCH 1/2] Add printInterval to PrintLoss Callback --- .../ensmallen_bits/callbacks/print_loss.hpp | 20 +++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/include/ensmallen_bits/callbacks/print_loss.hpp b/include/ensmallen_bits/callbacks/print_loss.hpp index bc21b90ba..be5a5d3a4 100644 --- a/include/ensmallen_bits/callbacks/print_loss.hpp +++ b/include/ensmallen_bits/callbacks/print_loss.hpp @@ -12,6 +12,8 @@ #ifndef ENSMALLEN_CALLBACKS_PRINT_LOSS_HPP #define ENSMALLEN_CALLBACKS_PRINT_LOSS_HPP +#include + namespace ens { /** @@ -25,8 +27,14 @@ class PrintLoss * * @param ostream Ostream which receives output from this object. */ - PrintLoss(std::ostream& output = arma::get_cout_stream()) : output(output) - { /* Nothing to do here. */ } + PrintLoss(std::ostream& output = arma::get_cout_stream(), + const size_t printInterval = 1) + : output(output), printInterval(printInterval) + { + if (printInterval == 0) + throw std::invalid_argument( + "PrintLoss(): printInterval cannot be zero."); + } /** * Callback function called at the end of a pass over the data. @@ -41,16 +49,20 @@ class PrintLoss bool EndEpoch(OptimizerType& /* optimizer */, FunctionType& /* function */, const MatType& /* coordinates */, - const size_t /* epoch */, + const size_t epoch, const double objective) { - output << objective << std::endl; + // epochs are 1-indexed. + if (epoch % printInterval == 0) + output << objective << std::endl; return false; } private: //! The output stream that all data is to be sent to; example: std::cout. std::ostream& output; + //! The number of iterations in between each print. + const size_t printInterval; }; } // namespace ens From e094b3a1a33d9def5508cf57201af3d7bebf70a9 Mon Sep 17 00:00:00 2001 From: Ranjodh Singh Date: Sun, 19 Oct 2025 10:20:36 +0530 Subject: [PATCH 2/2] Modify Callback Documentation --- doc/callbacks.md | 2 ++ include/ensmallen_bits/callbacks/print_loss.hpp | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/doc/callbacks.md b/doc/callbacks.md index f67912e28..96e0a792f 100644 --- a/doc/callbacks.md +++ b/doc/callbacks.md @@ -249,12 +249,14 @@ Callback that prints loss to stdout or a specified output stream. * `PrintLoss()` * `PrintLoss(`_`output`_`)` + * `PrintLoss(`_`output, printInterval`_`)` #### Attributes | **type** | **name** | **description** | **default** | |----------|----------|-----------------|-------------| | `std::ostream` | **`output`** | Ostream which receives output from this object. | `stdout` | +| `size_t` | **`printInterval`** | The number of epochs between consecutive loss prints. | `1` | #### Examples: diff --git a/include/ensmallen_bits/callbacks/print_loss.hpp b/include/ensmallen_bits/callbacks/print_loss.hpp index be5a5d3a4..fa5c0e27f 100644 --- a/include/ensmallen_bits/callbacks/print_loss.hpp +++ b/include/ensmallen_bits/callbacks/print_loss.hpp @@ -61,7 +61,7 @@ class PrintLoss private: //! The output stream that all data is to be sent to; example: std::cout. std::ostream& output; - //! The number of iterations in between each print. + //! The number of epochs between consecutive loss prints. const size_t printInterval; };