Skip to content

Instantly share code, notes, and snippets.

@DEKHTIARJonathan
Created August 16, 2018 09:26
Show Gist options
  • Save DEKHTIARJonathan/e2e2e9c577e45325c009a80ead50ca51 to your computer and use it in GitHub Desktop.
Save DEKHTIARJonathan/e2e2e9c577e45325c009a80ead50ca51 to your computer and use it in GitHub Desktop.
Non Sequential API Attempt
import inspect
import time
import wrapt
import pprint
def get_caller(skip=2, target_type=None):
    """Return the nearest enclosing ``self`` of a given type on the call stack.

    Starting ``skip`` frames out from this function, walk outward through the
    calling frames and return the first frame-local ``self`` that is an
    instance of ``target_type``. The pass_through decorator uses this to find
    the ``Network`` whose method (directly or via nested helper functions)
    created a layer.

    Args:
        skip: Number of innermost frames to ignore. ``0`` is this function's
            own frame; the default ``2`` also skips this function's direct
            caller.
        target_type: Class (or tuple of classes) that ``self`` must be an
            instance of. Defaults to ``Network``, looked up in this module's
            globals at call time.

    Returns:
        The matching ``self`` object. Returns ``None`` when no frame matches,
        when the stack is shallower than ``skip``, or when ``target_type`` is
        omitted and ``Network`` is not defined yet.
    """
    if target_type is None:
        # Resolved lazily: Network is defined later in this module, and
        # get_caller can run before that (e.g. a layer built at import time).
        target_type = globals().get('Network')
        if target_type is None:
            return None

    # Follow f_back links instead of calling inspect.stack(), which also
    # loads source-code context for every frame and is much more expensive.
    frame = inspect.currentframe()
    try:
        for _ in range(skip):
            if frame is None:
                break
            frame = frame.f_back
        while frame is not None:
            candidate = frame.f_locals.get('self')
            if isinstance(candidate, target_type):
                return candidate
            frame = frame.f_back
        return None
    finally:
        # Drop the frame reference to avoid a frame <-> locals reference cycle.
        del frame
@wrapt.decorator
def pass_through(wrapped, instance, args, kwargs):
    """Call the wrapped method, then register ``instance`` with an enclosing Network.

    After the wrapped call returns, ``get_caller()`` searches the call stack
    for the nearest frame whose ``self`` is a ``Network``. If one is found,
    ``instance`` is passed to its ``add_layer``. Objects created outside any
    ``Network`` method are simply left unregistered.

    Args:
        wrapped: The wrapped callable (here the bound ``__init__``).
        instance: The object the wrapped method is bound to, i.e. the new layer.
        args: Positional arguments of the call.
        kwargs: Keyword arguments of the call.

    Returns:
        Whatever ``wrapped`` returns (``None`` for ``__init__``).

    Raises:
        Any exception from ``wrapped`` or from ``Network.add_layer``. These
        are no longer silently swallowed.
    """
    result = wrapped(*args, **kwargs)
    # get_caller() yields either a Network or None, including before the
    # Network class exists. Testing for None therefore avoids the NameError
    # that `isinstance(caller, Network)` raises at that point.
    caller = get_caller()
    if caller is not None:
        caller.add_layer(instance)
    return result
class My_obj:
    """Toy layer that announces itself when constructed.

    ``__init__`` is wrapped by ``pass_through``, which hands the finished
    instance to the nearest enclosing ``Network`` on the call stack, if any.
    """

    @pass_through
    def __init__(self, name):
        # Network.add_layer keys registered layers on this attribute.
        self.name = name
        print("[TL] Layer:", name)
# Demo 1: build a layer at module level. No Network method is on the call
# stack here (Network is not even defined yet), so it is not registered.
print("no context")
obj = My_obj(name="thomas")
print("\nwith context")
class Network:
    """Container that collects the layers created while it is constructed.

    Every ``My_obj`` built anywhere inside ``__init__`` -- including inside
    nested helper functions that have no ``self`` of their own -- is attached
    here by the ``pass_through`` decorator. That decorator finds this object
    as ``self`` in the ``__init__`` frame further up the call stack and calls
    ``add_layer``.

    Attributes:
        all_layer: dict mapping each layer's ``name`` to the layer object.
    """

    def __init__(self):
        # Must exist before any layer is built, since registration writes into it.
        self.all_layer = {}

        def define_add_layers():
            def relu_layer():
                # Layers register themselves; no local binding is needed.
                My_obj(name="ReLU")
                My_obj(name="Dropout")
                My_obj(name="LSTM")

            relu_layer()
            My_obj(name="Dense")
            My_obj(name="Conv2d")
            My_obj(name="BatchNorm")

        define_add_layers()

    def add_layer(self, layer):
        """Record *layer* under its ``name`` attribute.

        A later layer with the same name replaces the earlier entry.
        """
        self.all_layer[layer.name] = layer
# Demo 2: build a Network. Each layer created during Network.__init__
# (directly or through its nested helper functions) is registered via
# add_layer, so net.all_layer maps layer names to layer objects.
net = Network()
pprint.pprint(net.all_layer)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment