
@alecjacobson
Last active October 15, 2023 15:14
Use C++17 auto lambdas to pass a templated function to an optimizer that will call it with both double and autodiff::real types.
// requires -std=c++17 (if constexpr, std::is_same_v)
#include <stdio.h>
#include <type_traits>
#include <autodiff/forward/real.hpp> // https://github.com/autodiff/autodiff/

/// Take a single gradient descent step (unit step size) on a scalar function
///
/// @tparam FuncType should be an auto-lambda (i.e., `[](auto x)->auto{…}`)
/// @param[in] func scalar function
/// @param[in] initial_guess initial input to func
/// @return initial_guess - dfdx where dfdx is the derivative of func at initial_guess
///
template <typename FuncType>
double scalar_descent(FuncType func, double initial_guess) {
  double value = initial_guess;
  // Call func with double (instantiates the double version; the value is not used further)
  double func_value = func(value);
  (void)func_value;
  // Call func with autodiff::real to get the derivative at initial_guess
  autodiff::real x = initial_guess;
  double dfunc_dvalue = autodiff::derivative(
    [&func](autodiff::real xr)->autodiff::real { return func(xr); },
    autodiff::wrt(x), autodiff::at(x));
  return value - dfunc_dvalue;
}

/// Templated example function
template <typename Scalar>
Scalar f(Scalar x) {
  // Print revealing which compiled template was called
  if constexpr (std::is_same_v<Scalar, double>) {
    printf("calling f<double>\n");
  }
  if constexpr (std::is_same_v<Scalar, autodiff::real>) {
    printf("calling f<autodiff::real>\n");
  }
  return x * x;
}

int main() {
  double initial = 1.0;
  // wrap f in an auto lambda so it can be called with either double or autodiff::real
  double result = scalar_descent([](auto x)->auto{return f(x);}, initial);
  printf("result = %g\n", result); // f'(1) = 2, so the step gives 1 - 2 = -1
}
@alecjacobson (Author)

I like this because I'm used to writing templated functions like

template <typename Scalar>
Scalar f(Scalar x) { … }

It's a little gross that the line

double result = scalar_descent([](auto x)->auto{return f(x);}, initial);

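needs the auto-lambda wrapper around f (a function template like f cannot be passed as a single object that still accepts both double and autodiff::real; writing f<double> would pin one instantiation), but it's also common for me to use a lambda to capture local variables before calling a generic function.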

@alecjacobson (Author)

alecjacobson commented Oct 15, 2023

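An alternative that avoids auto lambdas entirely (and keeps the same implementation of scalar_descent) is to put f into a struct that has a templated operator(). The functor pattern itself predates C++17, but the `if constexpr` checks below still require it: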

struct f{
  template <typename Scalar>
  Scalar operator()(Scalar x) const {
    // Print revealing which compiled template was called
    if constexpr (std::is_same_v<Scalar, double>) {
      printf("calling f::operator()<double>\n");
    }
    if constexpr (std::is_same_v<Scalar, autodiff::real>) {
      printf("calling f::operator()<autodiff::real>\n");
    }

    return x * x;
  }
};

int main() {
  double initial = 1.0;
  // pass an instance of the functor f; its templated operator() accepts either double or autodiff::real
  double result = scalar_descent(f(), initial);
}

I don't like this as much because this functor pattern is unnatural to me, and using either scalar_descent(f(), initial); or

f f_instance;
double result = scalar_descent(f_instance, initial);

obscures that f is supposed to be realizing a function f(x).
