diff --git a/HISTORY.md b/HISTORY.md index 2c45eac78..cb339a632 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,5 +1,11 @@ ### ensmallen ?.??.?: "???" ###### ????-??-?? +* Refactor `GradientDescent` into + `GradientDescentType`. + Add the `DeltaBarDelta` optimizer, which implements Jacob's delta-bar-delta + update through `GradientDescentType` with `DeltaBarDeltaUpdate` and `NoDecay` + policies. ([#440](https://github.com/mlpack/ensmallen/pull/440)) + See the documentation for more details. ### ensmallen 3.10.0: "Unexpected Rain" ###### 2025-09-25 @@ -44,6 +50,7 @@ ActiveCMAES opt(lambda, BoundaryBoxConstraint(lowerBound, upperBound), ...); ``` + * Add proximal gradient optimizers for L1-constrained and other related problems: `FBS`, `FISTA`, and `FASTA` ([#427](https://github.com/mlpack/ensmallen/pull/427)). See the diff --git a/doc/function_types.md b/doc/function_types.md index de581c6e2..d0db04ace 100644 --- a/doc/function_types.md +++ b/doc/function_types.md @@ -135,6 +135,7 @@ The following optimizers can be used with differentiable functions: * [Fast Adaptive Shrinkage/Thresholding Algorithm (FASTA)](#fast-adaptive-shrinkage-thresholding-algorithm-fasta) (`ens::FASTA`) * [FrankWolfe](#frank-wolfe) (`ens::FrankWolfe`) * [GradientDescent](#gradient-descent) (`ens::GradientDescent`) + * [DeltaBarDelta](#delta-bar-delta) (`ens::DeltaBarDelta`) - Any optimizer for [arbitrary functions](#arbitrary-functions) Each of these optimizers has an `Optimize()` function that is called as diff --git a/doc/optimizers.md b/doc/optimizers.md index aeb57aab6..a3594525a 100644 --- a/doc/optimizers.md +++ b/doc/optimizers.md @@ -823,8 +823,6 @@ parameters. If `lambda` and `sigma` are not specified, then 0 is used as the initial value for all Lagrange multipliers and 10 is used as the initial penalty parameter. - - #### Examples
@@ -1261,6 +1259,62 @@ optimizer.Optimize(f, coordinates); * [Differential Evolution in Wikipedia](https://en.wikipedia.org/wiki/Differential_Evolution) * [Arbitrary functions](#arbitrary-functions) +## DeltaBarDelta + +*An optimizer for [differentiable functions](#differentiable-functions).* + +A Gradient Descent variant that adapts learning rates for each parameter to improve convergence. If the current gradient and the exponential average of past gradients corresponding to a parameter have the same sign, then the step size for that parameter is incremented by `kappa`. Otherwise, it is decreased by a proportion `phi` of its current value (additive increase, multiplicative decrease). + +***Note:*** DeltaBarDelta is very sensitive to its parameters (`kappa` and `phi`) hence a good hyperparameter selection is necessary as its default may not fit every case. Typically, `kappa` should be smaller than the step size. + +#### Constructors + + * `DeltaBarDelta()` + * `DeltaBarDelta(`_`stepSize`_`)` + * `DeltaBarDelta(`_`stepSize, maxIterations, tolerance`_`)` + * `DeltaBarDelta(`_`stepSize, maxIterations, tolerance, kappa, phi, theta, minStepSize, resetPolicy`_`)` + +Note that `DeltaBarDelta` is based on the templated type +`GradientDescentType<`_`UpdatePolicyType, DecayPolicyType`_`>` with _`UpdatePolicyType`_` = +`DeltaBarDeltaUpdate` and _`DecayPolicyType`_` = NoDecay`. + +#### Attributes + +| **type** | **name** | **description** | **default** | +|----------|----------|-----------------|-------------| +| `double` | **`stepSize`** | Initial step size. | `0.01` | +| `size_t` | **`maxIterations`** | Maximum number of iterations allowed (0 means no limit). | `100000` | +| `double` | **`tolerance`** | Maximum absolute tolerance to terminate algorithm. | `1e-5` | +| `double` | **`kappa`** | Additive increase constant for step size when gradient signs persist. | `0.002` | +| `double` | **`phi`** | Multiplicative decrease factor for step size when gradient signs flip. | `0.2` | +| `double` | **`theta`** | Decay rate for computing the exponential average of past gradients. | `0.8` | +| `double` | **`minStepSize`** | Minimum allowed step size for any parameter. | `1e-8` | +| `bool` | **`resetPolicy`** | If true, parameters are reset before every Optimize call. | `true` | + +Attributes of the optimizer may be accessed and modified via member functions of the same name. + +#### Examples: + +
+<details open>
+<summary>Click to collapse/expand example code.
+</summary>
+
+```c++
+RosenbrockFunction f;
+arma::mat coordinates = f.GetInitialPoint();
+
+DeltaBarDelta optimizer(0.001, 0, 1e-15, 0.0001, 0.2, 0.8);
+optimizer.Optimize(f, coordinates);
+```
+
+</details>
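A minimal sketch of tuning the optimizer through the member accessors listed in the attribute table above, assuming the same `RosenbrockFunction` setup as the previous example; only accessors that `DeltaBarDelta` exposes are used here.

```c++
RosenbrockFunction f;
arma::mat coordinates = f.GetInitialPoint();

// Construct with an initial step size, then adjust the remaining
// hyperparameters through the accessors of the same name.
DeltaBarDelta optimizer(0.001);
optimizer.Kappa() = 0.0001;     // additive increase when gradient signs persist
optimizer.Phi() = 0.2;          // multiplicative decrease when signs flip
optimizer.Theta() = 0.8;        // decay rate of the gradient average
optimizer.MaxIterations() = 0;  // no iteration limit
optimizer.Tolerance() = 1e-15;

optimizer.Optimize(f, coordinates);
```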
+ +#### See also: + + * [Increased rates of convergence through learning rate adaptation (pdf)](https://www.academia.edu/download/32005051/Jacobs.NN88.pdf) + * [Differentiable functions](#differentiable-functions) + * [Gradient Descent](#gradient-descent) + ## DemonAdam *An optimizer for [differentiable separable functions](#differentiable-separable-functions).* @@ -1899,6 +1953,11 @@ negative of the gradient of the function at the current point. * `GradientDescent()` * `GradientDescent(`_`stepSize`_`)` * `GradientDescent(`_`stepSize, maxIterations, tolerance`_`)` + * `GradientDescent(`_`stepSize, maxIterations, tolerance, updatePolicy, decayPolicy, resetPolicy`_`)` + +Note that `GradientDescent` is based on the templated type +`GradientDescentType<`_`UpdatePolicyType, DecayPolicyType`_`>` with _`UpdatePolicyType`_` = +VanillaUpdate` and _`DecayPolicyType`_` = NoDecay`. #### Attributes @@ -1907,9 +1966,14 @@ negative of the gradient of the function at the current point. | `double` | **`stepSize`** | Step size for each iteration. | `0.01` | | `size_t` | **`maxIterations`** | Maximum number of iterations allowed (0 means no limit). | `100000` | | `size_t` | **`tolerance`** | Maximum absolute tolerance to terminate algorithm. | `1e-5` | +| `UpdatePolicyType` | **`updatePolicy`** | Instantiated update policy used to adjust the given parameters. | `UpdatePolicyType()` | +| `DecayPolicyType` | **`decayPolicy`** | Instantiated decay policy used to adjust the step size. | `DecayPolicyType()` | +| `bool` | **`resetPolicy`** | Flag that determines whether update policy parameters are reset before every Optimize call. | `true` | Attributes of the optimizer may also be changed via the member methods -`StepSize()`, `MaxIterations()`, and `Tolerance()`. +`StepSize()`, `MaxIterations()`, `Tolerance()`, `UpdatePolicy()`, +`DecayPolicy()`, and `ResetPolicy()`. + #### Examples: diff --git a/include/ensmallen.hpp b/include/ensmallen.hpp index 5b32ca37f..5c88e055f 100644 --- a/include/ensmallen.hpp +++ b/include/ensmallen.hpp @@ -120,6 +120,7 @@ #include "ensmallen_bits/cd/cd.hpp" #include "ensmallen_bits/cne/cne.hpp" #include "ensmallen_bits/de/de.hpp" +#include "ensmallen_bits/delta_bar_delta/delta_bar_delta.hpp" #include "ensmallen_bits/eve/eve.hpp" #include "ensmallen_bits/fasta/fasta.hpp" #include "ensmallen_bits/fbs/fbs.hpp" diff --git a/include/ensmallen_bits/delta_bar_delta/delta_bar_delta.hpp b/include/ensmallen_bits/delta_bar_delta/delta_bar_delta.hpp new file mode 100644 index 000000000..167e21dfe --- /dev/null +++ b/include/ensmallen_bits/delta_bar_delta/delta_bar_delta.hpp @@ -0,0 +1,184 @@ +/** + * @file delta_bar_delta.hpp + * @author Ranjodh Singh + * + * Class wrapper for the DeltaBarDelta update policy. + * + * ensmallen is free software; you may redistribute it and/or modify it under + * the terms of the 3-clause BSD license. You should have received a copy of + * the 3-clause BSD license along with ensmallen. If not, see + * http://www.opensource.org/licenses/BSD-3-Clause for more information. + */ +#ifndef ENSMALLEN_DELTA_BAR_DELTA_HPP +#define ENSMALLEN_DELTA_BAR_DELTA_HPP + +#include +#include "./delta_bar_delta_update.hpp" + +namespace ens { + +/** + * DeltaBarDelta Optimizer. + * + * A heuristic designed to accelerate convergence by + * adapting the learning rate of each parameter individually. 
+ * + * According to the Delta-Bar-Delta update: + * + * - If the current gradient and the exponential average of + * past gradients corresponding to a parameter have the same + * sign, then the step size for that parameter is incremented by + * \f$\kappa\f$. Otherwise, it is decreased by a proportion \f$\phi\f$ + * of its current value (additive increase, multiplicative decrease). + * + * @note This implementation uses a minStepSize parameter to set a lower + * bound for the learning rate. This prevents the learning rate from + * dropping to zero, which can occur due to floating-point underflow. + * For tasks which require extreme fine-tuning, you may need to lower + * this parameter below its default value (1e-8) in order to allow for + * smaller learning rates. + * + * @code + * @article{jacobs1988increased, + * title = {Increased Rates of Convergence Through Learning Rate + * Adaptation}, + * author = {Jacobs, Robert A.}, journal = {Neural Networks}, + * volume = {1}, + * number = {4}, + * pages = {295--307}, + * year = {1988}, + * publisher = {Pergamon} + * } + * @endcode + */ +class DeltaBarDelta +{ + public: + /** + * Construct the DeltaBarDelta optimizer with the given function and + * parameters. DeltaBarDelta is very sensitive to its parameters (kappa + * and phi) hence a good hyperparameter selection is necessary as its + * default may not fit every case. + * + * @param stepSize Initial step size. + * @param maxIterations Maximum number of iterations allowed (0 means no + * limit). + * @param tolerance Maximum absolute tolerance to terminate algorithm. + * @param kappa Constant increment applied when gradient signs persist. + * @param phi Proportional decrement factor when gradient signs flip. + * @param theta Decay rate for the exponential average (delta-bar). + * @param minStepSize Minimum allowed step size for any parameter + * (default: 1e-8). + * @param resetPolicy If true, parameters are reset before every Optimize + * call; otherwise, their values are retained. + */ + DeltaBarDelta(const double stepSize = 0.01, + const size_t maxIterations = 100000, + const double tolerance = 1e-5, + const double kappa = 0.002, + const double phi = 0.2, + const double theta = 0.8, + const double minStepSize = 1e-8, + const bool resetPolicy = true); + + /** + * Optimize the given function using DeltaBarDelta. + * The given starting point will be modified to store the finishing + * point of the algorithm, and the final objective value is returned. + * + * @tparam SeparableFunctionType Type of the function to optimize. + * @tparam MatType Type of matrix to optimize with. + * @tparam GradType Type of matrix to use to represent function gradients. + * @tparam CallbackTypes Types of callback functions. + * @param function Function to optimize. + * @param iterate Starting point (will be modified). + * @param callbacks Callback functions. + * @return Objective value of the final point. + */ + template + typename std::enable_if::value, + typename MatType::elem_type>::type + Optimize(SeparableFunctionType& function, + MatType& iterate, + CallbackTypes&&... callbacks) + { + return optimizer.Optimize(function, iterate, + std::forward(callbacks)...); + } + + //! Forward the MatType as GradType. + template + typename MatType::elem_type Optimize(SeparableFunctionType& function, + MatType& iterate, + CallbackTypes&&... callbacks) + { + return Optimize(function, iterate, + std::forward(callbacks)...); + } + + //! Get the initial step size. + double StepSize() const { return optimizer.StepSize(); } + //! 
Modify the initial step size. + double& StepSize() { return optimizer.StepSize(); } + + //! Get the maximum number of iterations (0 indicates no limit). + size_t MaxIterations() const { return optimizer.MaxIterations(); } + //! Modify the maximum number of iterations (0 indicates no limit). + size_t& MaxIterations() { return optimizer.MaxIterations(); } + + //! Get the additive increase constant for step size + //! when gradient signs persist. + double Kappa() const { return optimizer.UpdatePolicy().Kappa(); } + //! Modify the additive increase constant for step size + //! when gradient signs persist. + double& Kappa() { return optimizer.UpdatePolicy().Kappa(); } + + //! Get the multiplicative decrease factor for step size + //! when gradient signs flip. + double Phi() const { return optimizer.UpdatePolicy().Phi(); } + //! Get the multiplicative decrease factor for step size + //! when gradient signs flip. + double& Phi() { return optimizer.UpdatePolicy().Phi(); } + + //! Get the decay rate for computing the exponential average + //! of past gradients (delta-bar). + double Theta() const { return optimizer.UpdatePolicy().Theta(); } + //! Modify the decay rate for computing the exponential average + //! of past gradients (delta-bar). + double& Theta() { return optimizer.UpdatePolicy().Theta(); } + + //! Get the minimum allowed step size. + double MinStepSize() const { return optimizer.UpdatePolicy().MinStepSize(); } + //! Modify the minimum allowed step size. + double& MinStepSize() { return optimizer.UpdatePolicy().MinStepSize(); } + + //! Get the tolerance for termination. + double Tolerance() const { return optimizer.Tolerance(); } + //! Modify the tolerance for termination. + double& Tolerance() { return optimizer.Tolerance(); } + + //! Get whether or not the update policy parameters are reset before + //! Optimize call. + bool ResetPolicy() const { return optimizer.ResetPolicy(); } + //! Modify whether or not the update policy parameters + //! are reset before Optimize call. + bool& ResetPolicy() { return optimizer.ResetPolicy(); } + + private: + //! The GradientDescentType object with DeltaBarDelta policy. + GradientDescentType optimizer; +}; + +} // namespace ens + +// Include implementation. +#include "delta_bar_delta_impl.hpp" + +#endif // ENSMALLEN_DELTA_BAR_DELTA_HPP diff --git a/include/ensmallen_bits/delta_bar_delta/delta_bar_delta_impl.hpp b/include/ensmallen_bits/delta_bar_delta/delta_bar_delta_impl.hpp new file mode 100644 index 000000000..9e5db0a09 --- /dev/null +++ b/include/ensmallen_bits/delta_bar_delta/delta_bar_delta_impl.hpp @@ -0,0 +1,41 @@ +/** + * @file delta_bar_delta.hpp + * @author Ranjodh Singh + * + * Implementation of DeltaBarDelta class wrapper. + * + * ensmallen is free software; you may redistribute it and/or modify it under + * the terms of the 3-clause BSD license. You should have received a copy of + * the 3-clause BSD license along with ensmallen. If not, see + * http://www.opensource.org/licenses/BSD-3-Clause for more information. + */ +#ifndef ENSMALLEN_DELTA_BAR_DELTA_IMPL_HPP +#define ENSMALLEN_DELTA_BAR_DELTA_IMPL_HPP + +// In case it hasn't been included yet. 
+#include "./delta_bar_delta.hpp" + +namespace ens { + +inline DeltaBarDelta::DeltaBarDelta( + const double stepSize, + const size_t maxIterations, + const double tolerance, + const double kappa, + const double phi, + const double theta, + const double minStepSize, + const bool resetPolicy) : + optimizer(stepSize, + maxIterations, + tolerance, + DeltaBarDeltaUpdate(stepSize, kappa, phi, theta, minStepSize), + NoDecay(), + resetPolicy) +{ + /* Nothing to do. */ +} + +} // namespace ens + +#endif // ENSMALLEN_DELTA_BAR_DELTA_IMPL_HPP diff --git a/include/ensmallen_bits/delta_bar_delta/delta_bar_delta_update.hpp b/include/ensmallen_bits/delta_bar_delta/delta_bar_delta_update.hpp new file mode 100644 index 000000000..b0c90ca67 --- /dev/null +++ b/include/ensmallen_bits/delta_bar_delta/delta_bar_delta_update.hpp @@ -0,0 +1,201 @@ +/** + * @file delta_bar_delta_update.hpp + * @author Ranjodh Singh + * + * DeltaBarDelta update policy for Gradient Descent. + * + * ensmallen is free software; you may redistribute it and/or modify it under + * the terms of the 3-clause BSD license. You should have received a copy of + * the 3-clause BSD license along with ensmallen. If not, see + * http://www.opensource.org/licenses/BSD-3-Clause for more information. + */ +#ifndef ENSMALLEN_GRADIENT_DESCENT_DELTA_BAR_DELTA_UPDATE_HPP +#define ENSMALLEN_GRADIENT_DESCENT_DELTA_BAR_DELTA_UPDATE_HPP + +namespace ens { + +/** + * DeltaBarDelta Update Policy for Gradient Descent. + * + * A heuristic designed to accelerate convergence by + * adapting the learning rate of each parameter individually. + * + * According to the Delta-Bar-Delta update: + * + * - If the current gradient and the exponential average of + * past gradients corresponding to a parameter have the same + * sign, then the step size for that parameter is incremented by + * \f$\kappa\f$. Otherwise, it is decreased by a proportion \f$\phi\f$ + * of its current value (additive increase, multiplicative decrease). + * + * @note This implementation uses a minStepSize parameter to set a lower + * bound for the learning rate. This prevents the learning rate from + * dropping to zero, which can occur due to floating-point underflow. + * For tasks which require extreme fine-tuning, you may need to lower + * this parameter below its default value (1e-8) in order to allow for + * smaller learning rates. + * + * @code + * @article{jacobs1988increased, + * title = {Increased Rates of Convergence Through Learning Rate + * Adaptation}, + * author = {Jacobs, Robert A.}, journal = {Neural Networks}, + * volume = {1}, + * number = {4}, + * pages = {295--307}, + * year = {1988}, + * publisher = {Pergamon} + * } + * @endcode + */ +class DeltaBarDeltaUpdate +{ + public: + /** + * Construct the DeltaBarDelta update policy with given parameters. + * + * @param initialStepSize Initial Step Size. + * @param kappa Constant increment applied when gradient signs persist. + * @param phi Proportional decrement factor when gradient signs flip. + * @param theta Decay rate for the exponential average (delta-bar). + * @param minStepSize Minimum allowed step size for any parameter + * (default: 1e-8). + */ + DeltaBarDeltaUpdate( + const double initialStepSize, + const double kappa, + const double phi, + const double theta, + const double minStepSize = 1e-8) : + initialStepSize(initialStepSize), + kappa(kappa), + phi(phi), + theta(theta), + minStepSize(minStepSize) + { + /* Do nothing. */ + } + + //! Access the initialStepSize hyperparameter. 
+ double InitialStepSize() const { return initialStepSize; } + //! Modify the initialStepSize hyperparameter. + double& InitialStepSize() { return initialStepSize; } + + //! Access the kappa hyperparameter. + double Kappa() const { return kappa; } + //! Modify the kappa hyperparameter. + double& Kappa() { return kappa; } + + //! Access the phi hyperparameter. + double Phi() const { return phi; } + //! Modify the phi hyperparameter. + double& Phi() { return phi; } + + //! Access the theta hyperparameter. + double Theta() const { return theta; } + //! Modify the theta hyperparameter. + double& Theta() { return theta; } + + //! Access the minStepSize hyperparameter. + double MinStepSize() const { return minStepSize; } + //! Modify the minStepSize hyperparameter. + double& MinStepSize() { return minStepSize; } + + /** + * The UpdatePolicyType policy classes must contain an internal 'Policy' + * template class with two template arguments: MatType and GradType. This is + * instantiated at the start of the optimization, and holds parameters + * specific to an individual optimization. + */ + template + class Policy + { + public: + typedef typename MatType::elem_type ElemType; + + /** + * This is called by the optimizer method before the start of the iteration + * update process. + * + * @param parent Instantiated parent class. + * @param rows Number of rows in the gradient matrix. + * @param cols Number of columns in the gradient matrix. + */ + Policy( + const DeltaBarDeltaUpdate& parent, + const size_t rows, + const size_t cols) : + parent(parent), + kappa(ElemType(parent.kappa)), + phi(ElemType(parent.phi)), + theta(ElemType(parent.theta)), + minStepSize(ElemType(parent.minStepSize)) + { + delta_bar.zeros(rows, cols); + epsilon.set_size(rows, cols); + epsilon.fill(parent.InitialStepSize()); + } + + /** + * Update step for Gradient Descent. + * + * @param iterate Parameters that minimize the function. + * @param stepSize Step size to be used for the given iteration. + * @param delta The gradient matrix. + */ + void Update(MatType& iterate, + const double stepSize, + const GradType& delta) + { + const MatType signMatrix = sign(delta % delta_bar); + const MatType sameSignMask = conv_to::from(signMatrix == +1); + const MatType diffSignMask = conv_to::from(signMatrix == -1); + + epsilon += sameSignMask * kappa; + epsilon -= diffSignMask * phi % epsilon; + epsilon.clamp(minStepSize, + arma::Datum::inf); + + delta_bar *= theta; + delta_bar += (1 - theta) * delta; + + iterate -= epsilon % delta; + } + + private: + //! The instantiated parent class. + const DeltaBarDeltaUpdate& parent; + + //! The exponential average of past gradients. + MatType delta_bar; + + //! Tracks the current step size for each parameter. + MatType epsilon; + + // Parent parameters converted to the element type of the matrix. + ElemType kappa; + ElemType phi; + ElemType theta; + ElemType minStepSize; + }; + + private: + //! The initialStepSize hyperparameter. + double initialStepSize; + + //! The kappa hyperparameter. + double kappa; + + //! The phi hyperparameter. + double phi; + + //! The theta hyperparameter. + double theta; + + //! The minStepSize hyperparameter. 
+ double minStepSize; +}; + +} // namespace ens + +#endif // ENSMALLEN_GRADIENT_DESCENT_DELTA_BAR_DELTA_UPDATE_HPP diff --git a/include/ensmallen_bits/gradient_descent/gradient_descent.hpp b/include/ensmallen_bits/gradient_descent/gradient_descent.hpp index 024c1bdec..d8686611a 100644 --- a/include/ensmallen_bits/gradient_descent/gradient_descent.hpp +++ b/include/ensmallen_bits/gradient_descent/gradient_descent.hpp @@ -39,8 +39,17 @@ namespace ens { * GradientDescent can optimize differentiable functions. For more details, see * the documentation on function types included with this distribution or on the * ensmallen website. + * + * @tparam UpdatePolicyType Update policy used by Gradient Descent during the + * iterative update process. By default vanilla update policy (see + * ens::VanillaUpdate) is used. + * @tparam DecayPolicyType Decay policy used during the iterative update + * process to adjust the step size. By default the step size isn't going to + * be adjusted (i.e. NoDecay is used). */ -class GradientDescent +template +class GradientDescentType { public: /** @@ -54,10 +63,24 @@ class GradientDescent * @param maxIterations Maximum number of iterations allowed (0 means no * limit). * @param tolerance Maximum absolute tolerance to terminate algorithm. + * @param updatePolicy Instantiated update policy used to adjust the given + * parameters. + * @param decayPolicy Instantiated decay policy used to adjust the step size. + * @param resetPolicy Flag that determines whether update policy parameters + * are reset before every Optimize call. + */ + GradientDescentType( + const double stepSize = 0.01, + const size_t maxIterations = 100000, + const double tolerance = 1e-5, + const UpdatePolicyType& updatePolicy = UpdatePolicyType(), + const DecayPolicyType& decayPolicy = DecayPolicyType(), + const bool resetPolicy = true); + + /** + * Clean any memory associated with the GradientDescent object. */ - GradientDescent(const double stepSize = 0.01, - const size_t maxIterations = 100000, - const double tolerance = 1e-5); + ~GradientDescentType(); /** * Optimize the given function using gradient descent. The given starting @@ -160,6 +183,37 @@ class GradientDescent //! Modify the tolerance for termination. double& Tolerance() { return tolerance; } + //! Get whether or not the update policy parameters + //! are reset before Optimize call. + bool ResetPolicy() const { return resetPolicy; } + //! Modify whether or not the update policy parameters + //! are reset before Optimize call. + bool& ResetPolicy() { return resetPolicy; } + + //! Get the update policy. + const UpdatePolicyType& UpdatePolicy() const { return updatePolicy; } + //! Modify the update policy. + UpdatePolicyType& UpdatePolicy() { return updatePolicy; } + + //! Get the instantiated update policy type. + //! Be sure to check its type with Has() before using! + const Any& InstUpdatePolicy() const { return instUpdatePolicy; } + //! Modify the instantiated update policy type. + //! Be sure to check its type with Has() before using! + Any& InstUpdatePolicy() { return instUpdatePolicy; } + + //! Get the step size decay policy. + const DecayPolicyType& DecayPolicy() const { return decayPolicy; } + //! Modify the step size decay policy. + DecayPolicyType& DecayPolicy() { return decayPolicy; } + + //! Get the instantiated decay policy type. + //! Be sure to check its type with Has() before using! + const Any& InstDecayPolicy() const { return instDecayPolicy; } + //! Modify the instantiated decay policy type. + //! 
Be sure to check its type with Has() before using! + Any& InstDecayPolicy() { return instDecayPolicy; } + private: //! The step size for each example. double stepSize; @@ -169,8 +223,30 @@ class GradientDescent //! The tolerance for termination. double tolerance; + + //! The update policy used to update the parameters in each iteration. + UpdatePolicyType updatePolicy; + + //! The decay policy used to update the step size. + DecayPolicyType decayPolicy; + + //! Flag indicating whether update policy + //! should be reset before running optimization. + bool resetPolicy; + + //! Flag indicating whether the update policy + //! parameters have been initialized. + bool isInitialized; + + //! The initialized update policy. + Any instUpdatePolicy; + + //! The initialized decay policy. + Any instDecayPolicy; }; +using GradientDescent = GradientDescentType; + } // namespace ens #include "gradient_descent_impl.hpp" diff --git a/include/ensmallen_bits/gradient_descent/gradient_descent_impl.hpp b/include/ensmallen_bits/gradient_descent/gradient_descent_impl.hpp index 813244c14..99bfddc5e 100644 --- a/include/ensmallen_bits/gradient_descent/gradient_descent_impl.hpp +++ b/include/ensmallen_bits/gradient_descent/gradient_descent_impl.hpp @@ -20,25 +20,43 @@ namespace ens { //! Constructor. -inline GradientDescent::GradientDescent( +template +GradientDescentType::GradientDescentType( const double stepSize, const size_t maxIterations, - const double tolerance) : + const double tolerance, + const UpdatePolicyType& updatePolicy, + const DecayPolicyType& decayPolicy, + const bool resetPolicy) : stepSize(stepSize), maxIterations(maxIterations), - tolerance(tolerance) + tolerance(tolerance), + updatePolicy(updatePolicy), + decayPolicy(decayPolicy), + resetPolicy(resetPolicy), + isInitialized(false) { /* Nothing to do. */ } +template +GradientDescentType::~GradientDescentType() +{ + // Clean decay and update policies, if they were initialized. + instDecayPolicy.Clean(); + instUpdatePolicy.Clean(); +} + //! Optimize the function (minimize). +template template typename std::enable_if::value, typename MatType::elem_type>::type -GradientDescent::Optimize(FunctionType& function, - MatType& iterateIn, - CallbackTypes&&... callbacks) +GradientDescentType::Optimize( + FunctionType& function, + MatType& iterateIn, + CallbackTypes&&... callbacks) { // Convenience typedefs. typedef typename MatType::elem_type ElemType; @@ -49,6 +67,13 @@ GradientDescent::Optimize(FunctionType& function, typedef Function FullFunctionType; FullFunctionType& f(static_cast(function)); + // The update policy and decay policy internally use a templated class so + // that we can know MatType and GradType only when Optimize() is called. + typedef typename UpdatePolicyType::template Policy + InstUpdatePolicyType; + typedef typename DecayPolicyType::template Policy + InstDecayPolicyType; + // Make sure we have the methods that we need. traits::CheckFunctionTypeAPI(); RequireFloatingPointType(); @@ -65,6 +90,24 @@ GradientDescent::Optimize(FunctionType& function, // Controls early termination of the optimization process. bool terminate = false; + // Initialize the decay policy if needed. + if (!isInitialized || !instDecayPolicy.Has()) + { + instDecayPolicy.Clean(); + instDecayPolicy.Set( + new InstDecayPolicyType(decayPolicy)); + } + + // Initialize the update policy. 
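+  // (The instantiated policy is rebuilt when the user requested a reset, on
+  // the first call to Optimize(), or when the stored policy was instantiated
+  // for a different matrix/gradient type than the current call uses.)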
+ if (resetPolicy || !isInitialized || + !instUpdatePolicy.Has()) + { + instUpdatePolicy.Clean(); + instUpdatePolicy.Set(new InstUpdatePolicyType( + updatePolicy, iterate.n_rows, iterate.n_cols)); + isInitialized = true; + } + // Now iterate! Callback::BeginOptimization(*this, f, iterate, callbacks...); for (size_t i = 1; i != maxIterations && !terminate; ++i) @@ -97,28 +140,40 @@ GradientDescent::Optimize(FunctionType& function, return overallObjective; } - // Reset the counter variables. - lastObjective = overallObjective; + // Use the update policy to take a step. + instUpdatePolicy.As().Update(iterate, + stepSize, + gradient); - // And update the iterate. - iterate -= ElemType(stepSize) * gradient; terminate |= Callback::StepTaken(*this, f, iterate, callbacks...); + + // Now update the learning rate if requested by the user. + instDecayPolicy.As().Update(iterate, + stepSize, + gradient); + + // Reset the counter variables. + lastObjective = overallObjective; } - Info << "Gradient Descent: maximum iterations (" << maxIterations - << ") reached; " << "terminating optimization." << std::endl; + if (!terminate) + { + Info << "Gradient Descent: maximum iterations (" << maxIterations + << ") reached; " << "terminating optimization." << std::endl; + } Callback::EndOptimization(*this, f, iterate, callbacks...); return overallObjective; } +template template typename std::enable_if::value, typename MatType::elem_type>::type -GradientDescent::Optimize( +GradientDescentType::Optimize( FunctionType& function, MatType& iterate, const std::vector& categoricalDimensions, @@ -159,4 +214,4 @@ GradientDescent::Optimize( } // namespace ens -#endif +#endif // ENSMALLEN_GRADIENT_DESCENT_GRADIENT_DESCENT_IMPL_HPP diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 487e6e278..9521ed842 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -15,6 +15,7 @@ set(ENSMALLEN_TESTS_SOURCES cmaes_test.cpp cne_test.cpp de_test.cpp + delta_bar_delta_test.cpp demon_adam_test.cpp demon_sgd_test.cpp eve_test.cpp diff --git a/tests/delta_bar_delta_test.cpp b/tests/delta_bar_delta_test.cpp new file mode 100644 index 000000000..1e7d5e9eb --- /dev/null +++ b/tests/delta_bar_delta_test.cpp @@ -0,0 +1,47 @@ +/** + * @file delta_bar_delta_test.cpp + * @author Ranjodh Singh + * + * ensmallen is free software; you may redistribute it and/or modify it under + * the terms of the 3-clause BSD license. You should have received a copy of + * the 3-clause BSD license along with ensmallen. If not, see + * http://www.opensource.org/licenses/BSD-3-Clause for more information. + */ +#if defined(ENS_USE_COOT) + #include + #include +#endif +#include +#include "catch.hpp" +#include "test_function_tools.hpp" +#include "test_types.hpp" + +using namespace ens; +using namespace ens::test; + +TEMPLATE_TEST_CASE("DeltaBarDelta_GDTestFunction", "[DeltaBarDelta]", + ENS_ALL_TEST_TYPES) +{ + DeltaBarDelta s(0.01, 500, 1e-9, 0.001, 0.2, 0.8); + FunctionTest(s, + Tolerances::LargeObj, + Tolerances::LargeCoord); +} + +TEMPLATE_TEST_CASE("DeltaBarDelta_RosenbrockFunction", "[DeltaBarDelta]", + ENS_ALL_TEST_TYPES) +{ + DeltaBarDelta s(0.001, 0, Tolerances::Obj / 100, + 0.0001, 0.2, 0.8); + FunctionTest(s, + 10 * Tolerances::LargeObj, + 10 * Tolerances::LargeCoord); +} + +TEMPLATE_TEST_CASE("DeltaBarDelta_LogisticRegressionFunction", + "[DeltaBarDelta]", ENS_ALL_TEST_TYPES) +{ + DeltaBarDelta s(0.00032, 32, Tolerances::Obj, + 0.000032, 0.2, 0.8); + LogisticRegressionFunctionTest(s); +}
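For an end-to-end check outside the test harness, the following is a minimal sketch. The toy `SquaredDistanceFunction` objective is hypothetical (not part of ensmallen); the only library pieces assumed are `DeltaBarDelta`, `DeltaBarDeltaUpdate`, `NoDecay`, and the `GradientDescentType` template introduced in this patch.

```c++
#include <ensmallen.hpp>

// Toy differentiable objective: f(x) = || x - target ||^2, minimized at
// x = target.  (Hypothetical example class; not part of ensmallen.)
class SquaredDistanceFunction
{
 public:
  SquaredDistanceFunction(const arma::mat& target) : target(target) { }

  double Evaluate(const arma::mat& x) const
  {
    return arma::accu(arma::square(x - target));
  }

  void Gradient(const arma::mat& x, arma::mat& gradient) const
  {
    gradient = 2.0 * (x - target);
  }

 private:
  arma::mat target;
};

int main()
{
  arma::mat target(3, 1, arma::fill::randu);
  SquaredDistanceFunction f(target);

  // The DeltaBarDelta wrapper: (stepSize, maxIterations, tolerance, kappa,
  // phi, theta).
  arma::mat coordinates(3, 1, arma::fill::zeros);
  ens::DeltaBarDelta dbd(0.01, 10000, 1e-10, 0.001, 0.2, 0.8);
  dbd.Optimize(f, coordinates);
  coordinates.print("DeltaBarDelta estimate:");

  // The same optimization expressed through the refactored
  // GradientDescentType, passing the update and decay policies explicitly.
  arma::mat coordinates2(3, 1, arma::fill::zeros);
  ens::GradientDescentType<ens::DeltaBarDeltaUpdate, ens::NoDecay> gd(
      0.01, 10000, 1e-10,
      ens::DeltaBarDeltaUpdate(0.01, 0.001, 0.2, 0.8),
      ens::NoDecay());
  gd.Optimize(f, coordinates2);
  coordinates2.print("GradientDescentType estimate:");

  return 0;
}
```

Both runs use identical hyperparameters and should produce the same estimate: `DeltaBarDelta` is only a convenience wrapper around the `GradientDescentType<DeltaBarDeltaUpdate, NoDecay>` instantiation with a friendlier constructor.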