Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions doc/callbacks.md
Original file line number Diff line number Diff line change
Expand Up @@ -249,12 +249,14 @@ Callback that prints loss to stdout or a specified output stream.

* `PrintLoss()`
* `PrintLoss(`_`output`_`)`
* `PrintLoss(`_`output, printInterval`_`)`

#### Attributes

| **type** | **name** | **description** | **default** |
|----------|----------|-----------------|-------------|
| `std::ostream` | **`output`** | Ostream which receives output from this object. | `stdout` |
| `size_t` | **`printInterval`** | The number of epochs between consecutive loss prints. | `1` |

#### Examples:

Expand Down
20 changes: 16 additions & 4 deletions include/ensmallen_bits/callbacks/print_loss.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
#ifndef ENSMALLEN_CALLBACKS_PRINT_LOSS_HPP
#define ENSMALLEN_CALLBACKS_PRINT_LOSS_HPP

#include <stdexcept>

namespace ens {

/**
Expand All @@ -25,8 +27,14 @@ class PrintLoss
*
* @param ostream Ostream which receives output from this object.
*/
PrintLoss(std::ostream& output = arma::get_cout_stream()) : output(output)
{ /* Nothing to do here. */ }
PrintLoss(std::ostream& output = arma::get_cout_stream(),
const size_t printInterval = 1)
: output(output), printInterval(printInterval)
{
if (printInterval == 0)
throw std::invalid_argument(
"PrintLoss(): printInterval cannot be zero.");
}

/**
* Callback function called at the end of a pass over the data.
Expand All @@ -41,16 +49,20 @@ class PrintLoss
bool EndEpoch(OptimizerType& /* optimizer */,
FunctionType& /* function */,
const MatType& /* coordinates */,
const size_t /* epoch */,
const size_t epoch,
const double objective)
{
output << objective << std::endl;
// epochs are 1-indexed.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are they? I thought they started at 0. In either case, it doesn't really make a big difference: if they are 0-indexed they will print the first iteration too.

if (epoch % printInterval == 0)
output << objective << std::endl;
return false;
}

private:
//! The output stream that all data is to be sent to; example: std::cout.
std::ostream& output;
//! The number of epochs between consecutive loss prints.
const size_t printInterval;
};

} // namespace ens
Expand Down