Skip to content

Instantly share code, notes, and snippets.

@mcobzarenco
Created December 8, 2014 18:00
Show Gist options
  • Save mcobzarenco/cc23946abb41b1a182f2 to your computer and use it in GitHub Desktop.
Save mcobzarenco/cc23946abb41b1a182f2 to your computer and use it in GitHub Desktop.
template<typename Scalar>
struct ImmutableParams {
std::vector<Eigen::Map<const DynamicMatrix<Scalar>>> W;
std::vector<Eigen::Map<const DynamicVector<Scalar>>> b;
};
template<typename Scalar>
struct MutableParams {
std::vector<Eigen::Map<DynamicMatrix<Scalar>>> W;
std::vector<Eigen::Map<DynamicVector<Scalar>>> b;
};
/// Appends `W`/`b` views over a flat parameter buffer to `params`.
///
/// For each pair of consecutive layer sizes (`layers[i]`, `layers[i + 1]`)
/// this appends, in order:
///   * a `layers[i + 1]` x `layers[i]` matrix map starting at `head`, then
///   * a bias-vector map of length `layers[i + 1]` immediately after it.
/// `head` is advanced past each one, so the buffer is read as
/// W0, b0, W1, b1, ... with nothing between entries.
///
/// @param layers  Sizes of the network's layers, in order.
/// @param params  Destination. Maps are appended to its existing `W` and `b`.
/// @param head    Start of the buffer. Deduced as `const Scalar*` or
///                `Scalar*`, which selects read-only or writable maps. The
///                buffer size is NOT checked here; the caller must ensure it
///                holds every parameter described by `layers`.
/// @return `params`, by reference, so no copy of the two vectors is made.
template<typename ScalarPtr, typename Params>
static inline Params& map_weights_as_params(
    const Layers& layers, Params& params, ScalarPtr head) {
  // With fewer than two layers there is nothing to map. Checking this first
  // also stops `layers.size() - 1` from wrapping around to a huge unsigned
  // value (and indexing `layers` out of bounds) when `layers` is empty.
  if (layers.size() < 2) {
    return params;
  }
  const auto num_transitions = layers.size() - 1;
  for (uint32_t layer = 0; layer < num_transitions; ++layer) {
    const auto rows = layers[layer + 1];
    const auto cols = layers[layer];
    params.W.emplace_back(head, rows, cols);
    head += rows * cols;
    params.b.emplace_back(head, rows);
    head += rows;
  }
  return params;
}
/// Returns read-only maps over `weights`, split into per-transition weight
/// matrices and bias vectors.
///
/// Fails the CHECK if `weights.size()` differs from
/// `FeedForward::num_params(layers)`. The returned maps point into `weights`
/// and must not outlive it.
/// NOTE(review): `weights` binds to `const&`, so passing a temporary
/// DynamicVector compiles. The returned maps then dangle once the calling
/// full-expression ends.
template<typename Scalar>
static inline ImmutableParams<Scalar> weights_as_params(
    const Layers& layers, const DynamicVector<Scalar>& weights) {
  CHECK_EQ(weights.size(), FeedForward::num_params(layers));
  const auto head = weights.data();
  ImmutableParams<Scalar> views;
  map_weights_as_params(layers, views, head);
  return views;
}
/// Returns writable maps over `weights`, split into per-transition weight
/// matrices and bias vectors. Writes through the maps modify `weights` in
/// place.
///
/// Fails the CHECK if `weights.size()` differs from
/// `FeedForward::num_params(layers)`. The returned maps point into `weights`
/// and must not outlive it. Resizing `weights` while the maps are in use
/// would presumably invalidate them too.
template<typename Scalar>
static inline MutableParams<Scalar> weights_as_params(
    const Layers& layers, DynamicVector<Scalar>& weights) {
  CHECK_EQ(weights.size(), FeedForward::num_params(layers));
  const auto head = weights.data();
  MutableParams<Scalar> views;
  map_weights_as_params(layers, views, head);
  return views;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.