Skip to content

Instantly share code, notes, and snippets.

@braingineer
Last active May 4, 2016 23:50
Show Gist options
  • Save braingineer/c2aa37c4bdfe7dc7691081ccf0f4941d to your computer and use it in GitHub Desktop.
Save braingineer/c2aa37c4bdfe7dc7691081ccf0f4941d to your computer and use it in GitHub Desktop.
Trying to load shared variables in Keras
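'''
Usage:
    some_data = DataLayer(some_array, input_dtype=whatever, name=...)
    next_layer = Dense(some_number)
    out_tensor = next_layer(some_data.tensor)

DataLayer calls K.variable() on its input itself and reads input.shape, so
pass the raw array (e.g. a numpy array), not an existing backend variable.
'''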
class DataLayer(Layer):
    '''Expose a constant array as a Keras tensor that carries layer history.

    The input is wrapped in a backend variable (`K.variable`) and a `Node`
    is created for it. The resulting tensor, `self.tensor`, can then be
    passed to other layers as if it had been produced by a Keras layer.

    # Arguments
        input_tensor: array-like value with a `.shape` attribute. Presumably
            a numpy array; TODO confirm. It is passed to `K.variable`, so it
            should not already be a backend variable.
        input_dtype: dtype of the created variable. Defaults to `K.floatx()`.
        name: variable/layer name. Defaults to `'static_<uid>'`.
    '''
    def __init__(self, input_tensor, input_dtype=None, name=None):
        # Layer.__init__ is not called here. Every attribute this layer
        # exposes is set by hand below.
        self.input_spec = None
        self.supports_masking = False
        self.uses_learning_phase = False
        self.trainable = False
        self.built = True
        # NOTE(review): `ignore_me` is not read anywhere in this file;
        # presumably consumed by external code. Verify.
        self.ignore_me = True
        self.inbound_nodes = []
        self.outbound_nodes = []
        self.trainable_weights = []
        self.non_trainable_weights = []
        self.regularizers = []
        self.constraints = {}

        if not name:
            prefix = 'static'
            name = prefix + '_' + str(K.get_uid(prefix))
        self.name = name

        if not input_dtype:
            input_dtype = K.floatx()
        # Kept on the instance because get_config() reports it.
        self.input_dtype = input_dtype

        self.tensor = K.variable(input_tensor, dtype=input_dtype, name=name)
        self.batch_input_shape = input_tensor.shape

        # Keras bookkeeping attributes attached to the output tensor.
        # History (self, 0, 0) points back at this layer.
        self.tensor._keras_shape = input_tensor.shape
        self.tensor._keras_history = (self, 0, 0)
        self.tensor._uses_learning_phase = False
        # NOTE(review): `_sideload` is not read in this file; presumably a
        # marker checked by external code. Verify.
        self.tensor._sideload = True

        # The Node object is intentionally not stored. Presumably its
        # constructor registers itself on this layer (Keras 1.x topology
        # behavior). TODO confirm against the Keras version in use.
        Node(self,
             inbound_layers=[],
             node_indices=[],
             tensor_indices=[],
             input_tensors=[self.tensor],
             output_tensors=[self.tensor],
             input_masks=[None],
             output_masks=[None],
             input_shapes=[self.batch_input_shape],
             output_shapes=[self.batch_input_shape])

    def get_config(self):
        '''Return this layer's shape, dtype and name as a dict.

        NOTE(review): 'batch_input_shape' is not an __init__ argument, so
        DataLayer(**config) would raise TypeError. A from_config() would
        have to map it back explicitly.
        '''
        config = {'batch_input_shape': self.batch_input_shape,
                  'input_dtype': self.input_dtype,
                  'name': self.name}
        return config
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment