Skip to content

Instantly share code, notes, and snippets.

@sadjad
Created October 9, 2021 18:43
Show Gist options
  • Save sadjad/e11b01dc85903e0f07eb5e3f3688e5ae to your computer and use it in GitHub Desktop.
Save sadjad/e11b01dc85903e0f07eb5e3f3688e5ae to your computer and use it in GitHub Desktop.
#include <array>
#include <iostream>
using namespace std;
// last_element<n0, n1, ..., nk>::value is the final template argument nk.
// Each multi-argument instantiation simply inherits from the instantiation with
// its first argument dropped, so `value` is ultimately found in the
// one-argument specialization below.
template<unsigned int n0, unsigned int... rest>
struct last_element : last_element<rest...>
{
};

// One-argument case: the sole argument is also the last one.
template<unsigned int n>
struct last_element<n>
{
  constexpr static unsigned int value = n;
};
/**
 * A single layer mapping `input_size_` inputs to `output_size_` outputs.
 *
 * Currently a skeleton: apply() does not compute anything yet, and output()
 * exposes the layer's output buffer.
 */
template<unsigned int input_size_, unsigned int output_size_>
class Layer
{
private:
  // Value-initialized (all 0.0f) so output() never exposes indeterminate
  // floats, e.g. before apply() has written anything into it.
  std::array<float, output_size_> output_ {};

public:
  /// Placeholder forward pass: currently a no-op that leaves the output
  /// buffer unchanged. The parameter name is commented out to avoid an
  /// unused-parameter warning until the computation is implemented.
  void apply( const std::array<float, input_size_>& /* input */ ) {}

  /// Read-only view of the layer's output buffer.
  const std::array<float, output_size_>& output() const { return output_; }

  /// Number of inputs this layer accepts (the first template argument).
  constexpr unsigned int input_size() const noexcept { return input_size_; }

  /// Number of outputs this layer produces (the second template argument).
  constexpr unsigned int output_size() const noexcept { return output_size_; }
};
/**
 * A chain of Layers described by its sizes: Network<i0, o0, rest...> is a
 * Layer<i0, o0> followed by the sub-network Network<o0, rest...>, so each
 * layer's output size is the next layer's input size.
 *
 * The recursion ends at the two-size specialization Network<i0, o0>.
 */
template<unsigned int i0, unsigned int o0, unsigned int... rest>
class Network
{
public:
  /// Number of inputs the first layer expects.
  constexpr static unsigned int input_size = i0;
  /// Number of outputs of the last layer (the final size argument).
  constexpr static unsigned int output_size = last_element<o0, rest...>::value;

  /// First layer: maps i0 inputs to o0 outputs.
  Layer<i0, o0> layer0;
  /// Remaining layers; the first of them takes o0 inputs.
  Network<o0, rest...> next;

  /// Runs `input` through layer0, then feeds layer0's output into `next`.
  void apply( const std::array<float, i0>& input )
  {
    layer0.apply( input );
    next.apply( layer0.output() );
  }

  /// Output buffer of the last layer in the chain.
  const std::array<float, output_size>& output() { return next.output(); }

  /// Const overload so a `const Network&` can read the result as well. It is
  /// only instantiated when called on a const object, and then requires the
  /// sub-network to provide a const output() too; the non-const overload
  /// above is kept so non-const callers never depend on that.
  const std::array<float, output_size>& output() const { return next.output(); }
};
// BASE CASE: a network with exactly one layer, Layer<i0, o0>. This
// specialization ends the recursion of the general Network<i0, o0, rest...>
// template.
template<unsigned int i0, unsigned int o0>
class Network<i0, o0>
{
public:
  /// Number of inputs the (only) layer expects; same member as the general
  /// template so generic code can query any Network uniformly.
  constexpr static unsigned int input_size = i0;
  /// Number of outputs the (only) layer produces; same member as the general
  /// template so generic code can query any Network uniformly.
  constexpr static unsigned int output_size = o0;

  /// The single layer of this network.
  Layer<i0, o0> layer0;

  /// Forwards `input` to layer0.
  void apply( const std::array<float, i0>& input ) { layer0.apply( input ); }

  /// Output buffer of layer0. Declared const so it also works on a const
  /// Network; non-const callers are unaffected.
  const std::array<float, o0>& output() const { return layer0.output(); }
};
int main()
{
Network<20, 10, 5, 2> nn;
cout << "input size: " << nn.input_size << endl;
cout << nn.layer0.input_size() << " -> " << nn.layer0.output_size() << endl;
cout << nn.next.layer0.input_size() << " -> " << nn.next.layer0.output_size() << endl;
cout << nn.next.next.layer0.input_size() << " -> " << nn.next.next.layer0.output_size() << endl;
cout << "output size: " << nn.output_size << endl;
array<float, 20> input;
nn.apply( input );
auto& output = nn.output();
return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment