Created
October 9, 2021 18:43
-
-
Save sadjad/e11b01dc85903e0f07eb5e3f3688e5ae to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <array>
#include <iostream>

using namespace std;
/// Compile-time metafunction yielding the last value of a non-empty pack of
/// unsigned integers, e.g. `last_element<a, b, c>::value == c`.
///
/// Primary template: drops the leading value and recurses on the remaining
/// pack until only one value is left.
template<unsigned int n0, unsigned int... rest>
struct last_element
{
  static constexpr unsigned int value = last_element<rest...>::value;
};

/// Recursion terminator: a pack holding a single value is its own last element.
template<unsigned int n>
struct last_element<n>
{
  static constexpr unsigned int value = n;
};
/// A single network layer whose input and output widths are fixed at compile
/// time.
///
/// @tparam input_size_   number of floats accepted by apply()
/// @tparam output_size_  number of floats held in the output buffer
template<unsigned int input_size_, unsigned int output_size_>
class Layer
{
private:
  // Value-initialized (all zeros) so output() never exposes indeterminate
  // floats, even when it is read before the first apply().
  std::array<float, output_size_> output_ {};

public:
  /// Placeholder forward pass: performs no computation and leaves the output
  /// buffer untouched. The parameter is left unnamed because it is not read.
  void apply( const std::array<float, input_size_>& /* input */ ) {}

  /// @return read-only reference to this layer's output buffer.
  const std::array<float, output_size_>& output() const { return output_; }

  /// @return the compile-time input width (`input_size_`).
  constexpr unsigned int input_size() const noexcept { return input_size_; }

  /// @return the compile-time output width (`output_size_`).
  constexpr unsigned int output_size() const noexcept { return output_size_; }
};
template<unsigned int i0, unsigned int o0, unsigned int... rest> | |
class Network | |
{ | |
public: | |
constexpr static unsigned int input_size = i0; | |
constexpr static unsigned int output_size = last_element<o0, rest...>::value; | |
Layer<i0, o0> layer0; | |
Network<o0, rest...> next; | |
void apply( const array<float, i0>& input ) | |
{ | |
layer0.apply( input ); | |
next.apply( layer0.output() ); | |
} | |
const array<float, output_size>& output() { return next.output(); } | |
}; | |
// BASE CASE | |
template<unsigned int i0, unsigned int o0> | |
class Network<i0, o0> | |
{ | |
public: | |
Layer<i0, o0> layer0; | |
void apply( const array<float, i0>& input ) { layer0.apply( input ); } | |
const array<float, o0>& output() { return layer0.output(); } | |
}; | |
int main() | |
{ | |
Network<20, 10, 5, 2> nn; | |
cout << "input size: " << nn.input_size << endl; | |
cout << nn.layer0.input_size() << " -> " << nn.layer0.output_size() << endl; | |
cout << nn.next.layer0.input_size() << " -> " << nn.next.layer0.output_size() << endl; | |
cout << nn.next.next.layer0.input_size() << " -> " << nn.next.next.layer0.output_size() << endl; | |
cout << "output size: " << nn.output_size << endl; | |
array<float, 20> input; | |
nn.apply( input ); | |
auto& output = nn.output(); | |
return 0; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment