-
-
Notifications
You must be signed in to change notification settings - Fork 195
Use perfect forwarding for functions that use apply_*_*
functions
#3215
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 9 commits
4eb5681
81bd2d0
561f772
e40bd1e
3c9462e
5dd0f62
589895d
6301f3e
deec65a
8a824c9
5fffe24
80dbe6c
50052d9
939ed72
e63c866
67a15d5
ebb2dbd
92e4076
3e921d4
43a2e0b
2b9d261
1144241
3c42f90
0f1658d
7958b81
77646e6
8dc16ca
2b52498
b033319
1ebbb3e
7862a22
ea2f224
49a5273
60065f1
54b12e1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -18,20 +18,23 @@ namespace math { | |||||||||||||
* autodiff variable. | ||||||||||||||
*/ | ||||||||||||||
template <typename F, typename T> | ||||||||||||||
struct apply_scalar_unary<F, fvar<T> > { | ||||||||||||||
struct apply_scalar_unary<F, T, require_fvar_t<T>> { | ||||||||||||||
/** | ||||||||||||||
* Function return type, which is same as the argument type for | ||||||||||||||
* the function, <code>fvar<T></code>. | ||||||||||||||
*/ | ||||||||||||||
using return_t = fvar<T>; | ||||||||||||||
using return_t = std::decay_t<T>; | ||||||||||||||
|
||||||||||||||
/** | ||||||||||||||
* Apply the function specified by F to the specified argument. | ||||||||||||||
* | ||||||||||||||
* @param x Argument variable. | ||||||||||||||
* @return Function applied to the variable. | ||||||||||||||
*/ | ||||||||||||||
static inline return_t apply(const fvar<T>& x) { return F::fun(x); } | ||||||||||||||
template <typename T2> | ||||||||||||||
static inline auto apply(const T2& x) { | ||||||||||||||
return F::fun(x); | ||||||||||||||
} | ||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Should this also be forwarding? As the downstream calls are forwarding their arguments to |
||||||||||||||
}; | ||||||||||||||
|
||||||||||||||
} // namespace math | ||||||||||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -87,9 +87,11 @@ cholesky_corr_constrain(const EigVec& y, int K, Lp& lp) { | |
* @param K The size of the matrix to return | ||
*/ | ||
template <typename T, require_std_vector_t<T>* = nullptr> | ||
inline auto cholesky_corr_constrain(const T& y, int K) { | ||
return apply_vector_unary<T>::apply( | ||
y, [K](auto&& v) { return cholesky_corr_constrain(v, K); }); | ||
inline auto cholesky_corr_constrain(T&& y, int K) { | ||
return apply_vector_unary<std::decay_t<T>>::apply( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not something you necessarily need to change in this PR, but it might be a bit cleaner to move the |
||
std::forward<T>(y), [K](auto&& v) { | ||
return cholesky_corr_constrain(std::forward<decltype(v)>(v), K); | ||
}); | ||
} | ||
|
||
/** | ||
|
@@ -107,9 +109,11 @@ inline auto cholesky_corr_constrain(const T& y, int K) { | |
*/ | ||
template <typename T, typename Lp, require_std_vector_t<T>* = nullptr, | ||
require_convertible_t<return_type_t<T>, Lp>* = nullptr> | ||
inline auto cholesky_corr_constrain(const T& y, int K, Lp& lp) { | ||
return apply_vector_unary<T>::apply( | ||
y, [&lp, K](auto&& v) { return cholesky_corr_constrain(v, K, lp); }); | ||
inline auto cholesky_corr_constrain(T&& y, int K, Lp& lp) { | ||
return apply_vector_unary<std::decay_t<T>>::apply( | ||
std::forward<T>(y), [&lp, K](auto&& v) { | ||
return cholesky_corr_constrain(std::forward<decltype(v)>(v), K, lp); | ||
}); | ||
} | ||
|
||
/** | ||
|
@@ -132,11 +136,11 @@ inline auto cholesky_corr_constrain(const T& y, int K, Lp& lp) { | |
*/ | ||
template <bool Jacobian, typename T, typename Lp, | ||
require_convertible_t<return_type_t<T>, Lp>* = nullptr> | ||
inline auto cholesky_corr_constrain(const T& y, int K, Lp& lp) { | ||
inline auto cholesky_corr_constrain(T&& y, int K, Lp& lp) { | ||
if constexpr (Jacobian) { | ||
return cholesky_corr_constrain(y, K, lp); | ||
return cholesky_corr_constrain(std::forward<T>(y), K, lp); | ||
} else { | ||
return cholesky_corr_constrain(y, K); | ||
return cholesky_corr_constrain(std::forward<T>(y), K); | ||
} | ||
} | ||
|
||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -24,8 +24,8 @@ namespace math { | |||||
* @return tanh transform | ||||||
*/ | ||||||
template <typename T> | ||||||
inline plain_type_t<T> corr_constrain(const T& x) { | ||||||
return tanh(x); | ||||||
inline plain_type_t<T> corr_constrain(T&& x) { | ||||||
return tanh(std::forward<T>(x)); | ||||||
} | ||||||
|
||||||
/** | ||||||
|
@@ -43,7 +43,7 @@ inline plain_type_t<T> corr_constrain(const T& x) { | |||||
* @param[in,out] lp log density accumulator | ||||||
*/ | ||||||
template <typename T_x, typename T_lp> | ||||||
inline auto corr_constrain(const T_x& x, T_lp& lp) { | ||||||
inline auto corr_constrain(T_x&& x, T_lp& lp) { | ||||||
plain_type_t<T_x> tanh_x = tanh(x); | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
lp += sum(log1m(square(tanh_x))); | ||||||
return tanh_x; | ||||||
|
@@ -65,11 +65,11 @@ inline auto corr_constrain(const T_x& x, T_lp& lp) { | |||||
* @param[in,out] lp log density accumulator | ||||||
*/ | ||||||
template <bool Jacobian, typename T_x, typename T_lp> | ||||||
inline auto corr_constrain(const T_x& x, T_lp& lp) { | ||||||
inline auto corr_constrain(T_x&& x, T_lp& lp) { | ||||||
if constexpr (Jacobian) { | ||||||
return corr_constrain(x, lp); | ||||||
return corr_constrain(std::forward<T_x>(x), lp); | ||||||
} else { | ||||||
return corr_constrain(x); | ||||||
return corr_constrain(std::forward<T_x>(x)); | ||||||
} | ||||||
} | ||||||
|
||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The
apply_vector_unary
functors themselves should probably also perfect-forwarding, since they'll be passing their inputs toapply_*
functions as well.Also probably best to remove the reference-capture default while we're here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you clarify? I'm not seeing in this function how much forwarding can be done. I do the perfect forwarding in the actual code for
apply_vector_unary
etc if that is was you mean.Agree