Skip to content

Instantly share code, notes, and snippets.

@SteveBronder
Created October 14, 2020 14:13
Show Gist options
  • Save SteveBronder/342962c01867e3765f8b2f09afc8999f to your computer and use it in GitHub Desktop.
Save SteveBronder/342962c01867e3765f8b2f09afc8999f to your computer and use it in GitHub Desktop.
#include <benchmark/benchmark.h>
#include <stan/math.hpp>
#include <chrono>
#include <utility>
// Set to false by the first call of toss_me so that its large allocation is
// performed at most once per process.
static bool needs_done = true;
/**
 * Warm-up benchmark that fills up Stan's autodiff stack allocator before the
 * dot-product benchmarks run.
 *
 * On its first call it builds two max_alloc x max_alloc matrices of `var`,
 * multiplies them elementwise and sums the result. It then times repeated
 * reverse passes, zeroing the adjoints after each one. Finally,
 * recover_memory() resets the autodiff stack.
 * NOTE(review): the warm-up effect relies on recover_memory() keeping the
 * grown arena blocks for reuse; confirm against the Stan version in use.
 *
 * google-benchmark may invoke a benchmark function more than once (while it
 * searches for an iteration count, or with --benchmark_repetitions). It
 * aborts if the function returns before the state loop has finished, so every
 * later call still drains the loop, doing no work.
 *
 * @tparam max_alloc side length of the square matrices that are allocated
 * @param state google-benchmark state driving the timing loop
 */
template <int max_alloc>
static void toss_me(benchmark::State& state) {
  using stan::math::elt_multiply;
  using stan::math::sum;
  using stan::math::var;
  if (needs_done) {
    needs_done = false;
    Eigen::Matrix<var, -1, -1> x(Eigen::MatrixXd::Random(max_alloc, max_alloc));
    Eigen::Matrix<var, -1, -1> y(Eigen::MatrixXd::Random(max_alloc, max_alloc));
    Eigen::Matrix<var, -1, -1> z = elt_multiply(x, y);
    var lp = sum(z);
    benchmark::DoNotOptimize(lp.vi_);
    for (auto _ : state) {
      lp.grad();
      stan::math::set_zero_all_adjoints();
    }
    stan::math::recover_memory();
  } else {
    // Nothing left to warm up, but the state loop must still run to
    // completion or google-benchmark reports a fatal error.
    for (auto _ : state) {
    }
  }
}
namespace stan {
namespace math {
/**
 * Returns the dot product of two containers, at least one of which holds
 * `var`s.
 *
 * Both inputs are viewed as column vectors and stored in arena_t storage so
 * that the reverse_pass_callback closure can read them after this function
 * returns. When both inputs are `var`, the reverse pass updates the adjoints
 * in an explicit per-coefficient loop. When only one input is `var`, the
 * constant side is kept as doubles and the adjoint update is a single Eigen
 * expression.
 *
 * @tparam T1 type of the first container
 * @tparam T2 type of the second container
 *
 * @param[in] v1 First vector.
 * @param[in] v2 Second vector.
 * @return Dot product of the vectors.
 * @throw std::invalid_argument if sizes of v1 and v2 do not match.
 */
template <typename T1, typename T2, require_all_container_t<T1, T2>* = nullptr,
          require_any_vt_var<T1, T2>* = nullptr>
inline var dot_product_cf(const T1& v1, const T2& v2) {
  check_matching_sizes("dot_product", "v1", v1, "v2", v2);
  // These are runtime `if`s on compile-time constants, so every branch is
  // instantiated for every (T1, T2). Only the matching branch executes.
  if (!is_constant<T1>::value && !is_constant<T2>::value) {
    arena_t<vector_v> v1_arena = as_column_vector_or_scalar(v1);
    arena_t<vector_v> v2_arena = as_column_vector_or_scalar(v2);
    var res(v1_arena.val().dot(v2_arena.val()));
    reverse_pass_callback([v1_arena, v2_arena, res]() mutable {
      // res owns a vari distinct from every input vari, so its adjoint is
      // not modified by the loop below; read it once.
      const double res_adj = res.adj();
      for (Eigen::Index i = 0; i < v1_arena.size(); ++i) {
        v1_arena.coeffRef(i).adj() += res_adj * v2_arena.coeffRef(i).val();
        v2_arena.coeffRef(i).adj() += res_adj * v1_arena.coeffRef(i).val();
      }
    });
    return res;
  } else if (!is_constant<T2>::value) {
    // Only v2 carries gradients: keep v1 as plain doubles.
    arena_t<vector_v> v2_arena = as_column_vector_or_scalar(v2);
    arena_t<Eigen::VectorXd> v1_val_arena = value_of(as_column_vector_or_scalar(v1));
    var res(v1_val_arena.dot(v2_arena.val()));
    reverse_pass_callback([v1_val_arena, v2_arena, res]() mutable {
      v2_arena.adj() += res.adj() * v1_val_arena;
    });
    return res;
  } else {
    // Only v1 carries gradients: keep v2 as plain doubles.
    arena_t<vector_v> v1_arena = as_column_vector_or_scalar(v1);
    arena_t<Eigen::VectorXd> v2_val_arena = value_of(as_column_vector_or_scalar(v2));
    var res(v1_arena.val().dot(v2_val_arena));
    reverse_pass_callback([v1_arena, v2_val_arena, res]() mutable {
      v1_arena.adj() += res.adj() * v2_val_arena;
    });
    return res;
  }
}
/**
 * Dot product of two `var` column vectors.
 *
 * This is a non-template function. For two `Eigen::Matrix<var, -1, 1>`
 * arguments, overload resolution therefore prefers it over any equally good
 * template `dot_product` overload.
 *
 * The forward pass uses arena copies of the values of both inputs. The reverse
 * pass updates both inputs' adjoints with vectorized Eigen expressions.
 *
 * @param[in] v1 First vector.
 * @param[in] v2 Second vector.
 * @return Dot product of the vectors.
 * @throw std::invalid_argument if sizes of v1 and v2 do not match.
 */
inline var dot_product(const Eigen::Matrix<var, -1, 1>& v1,
                       const Eigen::Matrix<var, -1, 1>& v2) {
  check_matching_sizes("dot_product", "v1", v1, "v2", v2);
  const auto& v1_col = as_column_vector_or_scalar(v1);
  const auto& v2_col = as_column_vector_or_scalar(v2);
  arena_t<Eigen::VectorXd> v1_val_arena = to_arena(value_of(v1_col));
  arena_t<Eigen::VectorXd> v2_val_arena = to_arena(value_of(v2_col));
  double res_val = dot_product(v1_val_arena, v2_val_arena);
  var res(res_val);
  arena_t<Eigen::Matrix<var, Eigen::Dynamic, 1>> v1_arena = to_arena(v1_col);
  arena_t<Eigen::Matrix<var, Eigen::Dynamic, 1>> v2_arena = to_arena(v2_col);
  reverse_pass_callback(
      [v1_arena, v2_arena, res, v1_val_arena, v2_val_arena]() mutable {
        v1_arena.adj() += res.adj() * v2_val_arena;
        v2_arena.adj() += res.adj() * v1_val_arena;
      });
  return res;
}
}  // namespace math
}  // namespace stan
template <typename T1, typename T2>
static void dot_product_test(benchmark::State& state) {
using stan::math::var;
using stan::math::promote_scalar_t;
using stan::math::dot_product;
Eigen::VectorXd x_val = Eigen::VectorXd::Random(state.range(0));
Eigen::VectorXd y_val = Eigen::VectorXd::Random(state.range(0));
for (auto _ : state) {
Eigen::Matrix<T1, -1, 1> x(x_val);
Eigen::Matrix<T2, -1, 1> y(y_val);
auto start = std::chrono::high_resolution_clock::now();
var lp = dot_product(x, y);
lp.grad();
// Clobber here to make sure all the reads and writes are finished
benchmark::ClobberMemory();
auto end = std::chrono::high_resolution_clock::now();
auto elapsed_seconds =
std::chrono::duration_cast<std::chrono::duration<double>>(end - start);
state.SetIterationTime(elapsed_seconds.count());
stan::math::recover_memory();
benchmark::ClobberMemory();
}
}
template <typename T1, typename T2>
static void dot_product_cf_test(benchmark::State& state) {
using stan::math::var;
using stan::math::promote_scalar_t;
using stan::math::dot_product;
Eigen::VectorXd x_val = Eigen::VectorXd::Random(state.range(0));
Eigen::VectorXd y_val = Eigen::VectorXd::Random(state.range(0));
for (auto _ : state) {
Eigen::Matrix<T1, -1, 1> x(x_val);
Eigen::Matrix<T2, -1, 1> y(y_val);
auto start = std::chrono::high_resolution_clock::now();
var lp = dot_product_cf(x, y);
lp.grad();
// Clobber here to make sure all the reads and writes are finished
benchmark::ClobberMemory();
auto end = std::chrono::high_resolution_clock::now();
auto elapsed_seconds =
std::chrono::duration_cast<std::chrono::duration<double>>(end - start);
state.SetIterationTime(elapsed_seconds.count());
stan::math::recover_memory();
benchmark::ClobberMemory();
}
}
using stan::math::var;
// The start and ending sizes for the benchmark
constexpr int start_val = 2;
constexpr int end_val = 4096;
constexpr int mem_size = 4096 * 1.4;
// Allocate a big block of mem to replicate first iteration of Stan program.
BENCHMARK_TEMPLATE(toss_me, mem_size);
BENCHMARK_TEMPLATE(dot_product_test, var, var)->RangeMultiplier(2)->Range(start_val, end_val)->UseManualTime();
BENCHMARK_TEMPLATE(dot_product_cf_test, var, var)->RangeMultiplier(2)->Range(start_val, end_val)->UseManualTime();
BENCHMARK_TEMPLATE(dot_product_test, double, var)->RangeMultiplier(2)->Range(start_val, end_val)->UseManualTime();
BENCHMARK_TEMPLATE(dot_product_cf_test, double, var)->RangeMultiplier(2)->Range(start_val, end_val)->UseManualTime();
BENCHMARK_TEMPLATE(dot_product_test, var, double)->RangeMultiplier(2)->Range(start_val, end_val)->UseManualTime();
BENCHMARK_TEMPLATE(dot_product_cf_test, var, double)->RangeMultiplier(2)->Range(start_val, end_val)->UseManualTime();
BENCHMARK_MAIN();
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment