Created
August 16, 2018 09:26
-
-
Save DEKHTIARJonathan/e2e2e9c577e45325c009a80ead50ca51 to your computer and use it in GitHub Desktop.
Non Sequential API Attempt
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import inspect
import pprint
import time
import warnings

import wrapt
def get_caller(skip=2, target_type=None):
    """Return the nearest ``self`` up the call stack that is a ``Network``.

    Walks outward through the interpreter stack, starting ``skip`` frames above
    this function's own frame (``skip=0`` is ``get_caller`` itself; the default
    of 2 skips this function and its immediate caller, ``pass_through`` in this
    file). Returns the first local variable named ``self`` that is an instance
    of ``target_type``.

    Args:
        skip: Non-negative number of innermost frames to ignore before searching.
        target_type: Class (or tuple of classes) that ``self`` must be an
            instance of. Defaults to ``Network``. This is resolved at call time
            because ``Network`` is defined further down the module.

    Returns:
        The matching ``self`` object. Returns ``None`` when no frame matches,
        when the stack is shorter than ``skip + 1`` frames, or when ``Network``
        has not been defined yet.
    """
    if target_type is None:
        try:
            target_type = Network
        except NameError:
            # Called before ``class Network`` has executed (e.g. a layer built
            # at import time): nothing can match yet.
            return None

    # Follow f_back links instead of calling inspect.stack(), which builds a
    # FrameInfo (including source-context lines) for every frame on each call.
    frame = inspect.currentframe()
    try:
        for _ in range(skip):
            if frame is None:
                return None
            frame = frame.f_back
        while frame is not None:
            candidate = frame.f_locals.get('self')
            if isinstance(candidate, target_type):
                return candidate
            frame = frame.f_back
        return None
    finally:
        # Holding a frame in a local creates a reference cycle; drop it
        # deterministically, as the inspect documentation recommends.
        del frame
@wrapt.decorator
def pass_through(wrapped, instance, args, kwargs):
    """Call the wrapped ``__init__``, then register the new object with a Network.

    Intended to decorate a layer class's ``__init__`` (see ``My_obj``). After the
    original initializer returns, ``get_caller()`` searches the call stack for a
    ``Network`` whose method is currently running. If one is found, the freshly
    initialized object is handed to its ``add_layer``. Objects created with no
    such Network on the stack are simply not registered.

    Args:
        wrapped: The original (undecorated) callable.
        instance: The object the call is bound to. It is ``None`` e.g. when the
            decorator is applied to a plain function; nothing is registered then.
        args: Positional arguments forwarded unchanged to ``wrapped``.
        kwargs: Keyword arguments forwarded unchanged to ``wrapped``.

    Returns:
        Whatever ``wrapped`` returns (``None`` for an ``__init__``).
    """
    result = wrapped(*args, **kwargs)
    if instance is None:
        return result

    # get_caller() yields a Network or None. Testing for None, rather than
    # isinstance(caller, Network), avoids referencing Network before its class
    # statement has run (e.g. for layers created at import time).
    caller = get_caller()
    if caller is None:
        return result

    try:
        caller.add_layer(instance)
    except Exception as err:
        # Registration is best-effort and must not break layer construction,
        # but report the failure instead of hiding it behind a bare ``except``.
        warnings.warn(
            "could not register %r with %r: %s" % (instance, caller, err),
            RuntimeWarning,
        )
    return result
class My_obj(object):
    """Toy layer that registers itself with an enclosing ``Network``.

    ``__init__`` is wrapped by ``pass_through``. Creating an instance while a
    ``Network`` method is on the call stack therefore passes the instance to
    that network's ``add_layer``.
    """

    @pass_through
    def __init__(self, name):
        self.name = name
        print("[TL] Layer:", name)

    def __repr__(self):
        # Readable output when a Network's layer dict is pretty-printed,
        # instead of the default "<My_obj object at 0x...>".
        return "%s(name=%r)" % (type(self).__name__, self.name)
# Demo 1: build a layer at module level, where no Network is on the call stack.
print("no context")
obj = My_obj(name="thomas")

# Header for the in-Network demo that follows.
print()
print("with context")
class Network(object):
    """Minimal container that collects layers created while it is being built.

    Layers are not passed in explicitly. Each ``My_obj`` whose ``__init__`` is
    decorated with ``pass_through`` searches the call stack (``get_caller``) for
    a local named ``self`` that is a ``Network``. If one is found, it calls that
    network's ``add_layer``. Here that ``self`` is the parameter of ``__init__``
    below, whose frame stays on the stack while the nested helpers run.
    """

    def __init__(self):
        # Layer name -> layer object, filled in by add_layer().
        self.all_layer = dict()

        # NOTE(review): the original indentation was lost in the paste; the
        # nesting of the helpers below is a reconstruction -- confirm against
        # the source gist. The nesting affects the order in which layers are
        # created (and printed), not which layers get registered.
        def define_add_layers():
            # Neither helper takes or references ``self``; registration works
            # because get_caller keeps walking outward to __init__'s frame.
            def relu_layer():
                obj = My_obj(name="ReLU")
                obj = My_obj(name="Dropout")
                obj = My_obj(name="LSTM")
            relu_layer()
            obj = My_obj(name="Dense")
            obj = My_obj(name="Conv2d")
            obj = My_obj(name="BatchNorm")
        define_add_layers()

    def add_layer(self, layer):
        """Register ``layer`` under its ``name`` attribute.

        A later layer with the same name replaces the earlier entry.
        """
        self.all_layer[layer.name] = layer
# Demo 2: layers created during Network() are expected to land in all_layer.
net = Network()
print(pprint.pformat(net.all_layer))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment