alecjacobson/58cdbb4b2729265657ac375905a3014f (last active October 15, 2023 15:14)
Use C++17 auto lambdas to pass a templated function to an optimizer that will call it with both double and autodiff::real types.
```cpp
// requires -std=c++17
#include <stdio.h>
#include <type_traits>               // std::is_same_v
#include <autodiff/forward/real.hpp> // https://github.com/autodiff/autodiff/

/// Take one gradient-descent step (unit step size) on a scalar function
///
/// @tparam FuncType should be an auto-lambda (e.g., `[](auto x)->auto{…}`)
/// @param[in] func scalar function
/// @param[in] initial_guess initial input to func
/// @return initial_guess - dfdx where dfdx is the derivative of func at initial_guess
///
template <typename FuncType>
double scalar_descent(FuncType func, double initial_guess) {
  double value = initial_guess;
  // Evaluate with double (not used below; this triggers the double instantiation)
  double func_value = func(value);
  (void)func_value;
  // Evaluate with autodiff::real to get df/dx at initial_guess
  autodiff::real x = initial_guess;
  double dfunc_dvalue = autodiff::derivative(
    [&func](autodiff::real x)->autodiff::real { return func(x); },
    autodiff::wrt(x), autodiff::at(x));
  return value - dfunc_dvalue;
}

/// Templated example function
template <typename Scalar>
Scalar f(Scalar x) {
  // Print revealing which compiled template was called
  if constexpr (std::is_same_v<Scalar, double>) {
    printf("calling f<double>\n");
  }
  if constexpr (std::is_same_v<Scalar, autodiff::real>) {
    printf("calling f<autodiff::real>\n");
  }
  return x * x;
}

int main() {
  double initial = 1.0;
  // wrap f in an auto lambda so it can be called with either double or autodiff::real
  double result = scalar_descent([](auto x)->auto{return f(x);}, initial);
  printf("result = %g\n", result);
}
```
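To try the pattern without installing autodiff, here is a self-contained sketch that mirrors the gist under stated assumptions: a hypothetical, minimal `Dual` type (not part of autodiff) stands in for `autodiff::real`, and counters replace the `printf` calls. It shows the single generic lambda being instantiated once with `double` and once with `Dual`, and notes why `f` cannot be passed directly.

```cpp
#include <cassert>
#include <type_traits>

// Hypothetical minimal forward-mode dual number, a stand-in for autodiff::real
// so this sketch compiles without the autodiff library.
struct Dual {
  double val;  // function value
  double der;  // derivative with respect to the seeded input
};
Dual operator*(Dual a, Dual b) { return {a.val * b.val, a.der * b.val + a.val * b.der}; }

// Counters recording which instantiation of f ran (instead of printf).
int double_calls = 0;
int dual_calls = 0;

template <typename Scalar>
Scalar f(Scalar x) {
  if constexpr (std::is_same_v<Scalar, double>) ++double_calls;
  if constexpr (std::is_same_v<Scalar, Dual>) ++dual_calls;
  return x * x;
}

// One unit-step gradient-descent step: func is called once with double
// (plain evaluation) and once with Dual (to get df/dx).
template <typename FuncType>
double scalar_descent(FuncType func, double initial_guess) {
  double func_value = func(initial_guess);  // double instantiation
  (void)func_value;                         // unused, as in the gist
  Dual x{initial_guess, 1.0};               // seed dx/dx = 1
  return initial_guess - func(x).der;
}

// scalar_descent(f, 1.0) would not compile: f names a function template, so
// FuncType cannot be deduced. The generic lambda defers the choice of Scalar
// until scalar_descent calls it.
double step_from_one() {
  return scalar_descent([](auto x) { return f(x); }, 1.0);
}
```

Here `step_from_one()` returns -1.0: f'(1) = 2 and the step size is 1.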
An alternative that doesn't need the auto-lambda wrapper (and uses the same implementation of `scalar_descent`) is to put `f` into a struct that has a templated `operator()`:
```cpp
// Uses the same includes and scalar_descent as above
struct f {
  template <typename Scalar>
  Scalar operator()(Scalar x) const {
    // Print revealing which compiled template was called
    if constexpr (std::is_same_v<Scalar, double>) {
      printf("calling f::operator()<double>\n");
    }
    if constexpr (std::is_same_v<Scalar, autodiff::real>) {
      printf("calling f::operator()<autodiff::real>\n");
    }
    return x * x;
  }
};

int main() {
  double initial = 1.0;
  // pass a default-constructed f; its templated operator() can be called
  // with either double or autodiff::real
  double result = scalar_descent(f(), initial);
  printf("result = %g\n", result);
}
```
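The two approaches are closer than they look: the closure type of a generic lambda has a member function template `operator()`, which is essentially the `struct f` above, generated by the compiler. The sketch below spells out a hand-written equivalent of the wrapper `[](auto x) { return f(x); }`. The name `FWrapper` is hypothetical, and `float` stands in for a second scalar type so that no autodiff is needed.

```cpp
#include <cassert>
#include <type_traits>

template <typename Scalar>
Scalar f(Scalar x) { return x * x; }

// Roughly what the compiler generates for [](auto x) { return f(x); }:
// an empty struct whose call operator is a member function template.
struct FWrapper {
  template <typename Scalar>
  auto operator()(Scalar x) const { return f(x); }
};

// The generic lambda it corresponds to.
inline constexpr auto lambda_wrapper = [](auto x) { return f(x); };

// Both callables pick the instantiation of f from the argument type.
static_assert(std::is_same_v<decltype(FWrapper{}(1.0)), double>);
static_assert(std::is_same_v<decltype(FWrapper{}(1.0f)), float>);
static_assert(std::is_same_v<decltype(lambda_wrapper(1.0f)), float>);
```

So the choice between the two is mostly about how the call site reads, not about what gets compiled.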
I don't like this as much, because the functor pattern feels unnatural to me, and using either `scalar_descent(f(), initial)` or `f f_instance; double result = scalar_descent(f_instance, initial);` obscures that `f` is supposed to be realizing a function f(x).
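One possible middle ground, assuming C++17 and that `f` never needs overloads or specializations: define `f` itself as an `inline constexpr` generic lambda, so call sites read `scalar_descent(f, initial)` with neither a wrapper nor `f()`. The trade-off is that a lambda object cannot be overloaded or specialized the way a function template can. The consumer `evaluate_twice` below is hypothetical and only stands in for an optimizer that calls its argument with two scalar types.

```cpp
#include <cassert>

// f is a variable holding a generic lambda, so it can be passed by name.
inline constexpr auto f = [](auto x) { return x * x; };

// Hypothetical stand-in for an optimizer: calls func with two scalar types.
template <typename FuncType>
double evaluate_twice(FuncType func, double x) {
  double as_double = func(x);                    // instantiated for double
  float as_float = func(static_cast<float>(x));  // instantiated for float
  return as_double + static_cast<double>(as_float);
}
```

Usage: `evaluate_twice(f, 3.0)` passes `f` exactly like a function name.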
I like the auto-lambda version because I'm used to writing templated functions like `f`. It's a little gross that the call to `scalar_descent` needs the auto-lambda wrapper around `f`, but it's also common for me to use a lambda to capture local variables before calling a generic function.
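As an illustration of capturing locals, here is a sketch, again with a hypothetical minimal `Dual` standing in for `autodiff::real`, where a generic lambda captures `target` and `step` from the enclosing function and is passed repeatedly to a unit-step `scalar_descent`. The function `descend_towards` is hypothetical; scaling the loss by `step` is one simple way to get a step size out of a unit-step optimizer.

```cpp
#include <cassert>

// Hypothetical minimal forward-mode dual number standing in for autodiff::real.
struct Dual {
  double val;  // value
  double der;  // derivative with respect to the seeded input
  Dual(double v, double d = 0.0) : val(v), der(d) {}  // plain doubles act as constants
};
Dual operator-(Dual a, Dual b) { return {a.val - b.val, a.der - b.der}; }
Dual operator*(Dual a, Dual b) { return {a.val * b.val, a.der * b.val + a.val * b.der}; }

// One unit-step gradient-descent step, differentiating func with a seeded Dual.
template <typename FuncType>
double scalar_descent(FuncType func, double initial_guess) {
  Dual x(initial_guess, 1.0);  // seed dx/dx = 1
  return initial_guess - func(x).der;
}

// Minimizes step * (y - target)^2; each iteration moves x by -2 * step * (x - target).
double descend_towards(double target, double start, double step, int iterations) {
  assert(step > 0.0 && iterations >= 0);
  // target and step are locals captured by value; the lambda is generic, so
  // scalar_descent can call it with Dual (and it would also accept double).
  auto loss = [target, step](auto y) { return step * (y - target) * (y - target); };
  double x = start;
  for (int i = 0; i < iterations; ++i) x = scalar_descent(loss, x);
  return x;
}
```

With `step = 0.5` a single step lands exactly on `target`; with `step = 0.25` the distance to `target` halves on every iteration.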