Skip to content

Instantly share code, notes, and snippets.

@freedomtowin
Created June 6, 2018 00:35
Show Gist options
  • Save freedomtowin/5a79f4df83914b8196d952979cfb4b50 to your computer and use it in GitHub Desktop.
Save freedomtowin/5a79f4df83914b8196d952979cfb4b50 to your computer and use it in GitHub Desktop.
from scipy import stats
class NNShapeHelper():
    """Compute layer sizes and weight counts for a fully connected network.

    The network is taken to be
    ``num_inputs -> layer_shape[0] -> ... -> layer_shape[-1] -> num_outputs``.
    Only products of adjacent layer widths are counted, i.e. the sizes of
    the weight matrices; no bias terms are included.
    """

    def __init__(self, layer_shape, num_inputs, num_outputs):
        """Store the network dimensions.

        Args:
            layer_shape: sequence of hidden-layer widths (may be empty).
            num_inputs: width of the input layer.
            num_outputs: width of the output layer.
        """
        self.N_inputs = num_inputs
        self.N_outputs = num_outputs
        self.layer_shape = layer_shape
        self.N_layers = len(layer_shape)
        # Both lists are (re)built by get_N_parameters().
        self.model_shape = []
        self.parameter_shape = []

    def get_N_parameters(self):
        """Return the total weight count and refresh the shape bookkeeping.

        Sets ``self.parameter_shape`` to the size of each weight matrix:
        ``num_inputs * layer_shape[0]``, then ``layer_shape[i-1] *
        layer_shape[i]`` for each hidden-to-hidden link, then
        ``layer_shape[-1] * num_outputs``. Also sets ``self.model_shape`` and
        ``self.N_parameters``. The lists are rebuilt on every call, so calling
        this more than once does not duplicate entries. An empty
        ``layer_shape`` is treated as a direct input-to-output connection.

        Returns:
            The sum of ``self.parameter_shape``.
        """
        sizes = [self.N_inputs] + list(self.layer_shape) + [self.N_outputs]
        # One weight matrix between every pair of adjacent layers.
        self.parameter_shape = [
            n_in * n_out for n_in, n_out in zip(sizes[:-1], sizes[1:])
        ]
        # NOTE(review): this omits layer_shape[0] (the original loop started at
        # index 1), so it has one entry per weight matrix rather than one per
        # layer. Kept unchanged for existing callers -- confirm it is intended.
        self.model_shape = (
            [self.N_inputs] + list(self.layer_shape[1:]) + [self.N_outputs]
        )
        self.N_parameters = sum(self.parameter_shape)
        return self.N_parameters
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment