-
-
Save DarienBrito/ec55e35f0c172f451dc86d6aaa200d04 to your computer and use it in GitHub Desktop.
very quick & simple dictionary / json based graph builder for tensorflow
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python2 | |
# -*- coding: utf-8 -*- | |
""" | |
Created on Tue Feb 27 02:46:51 2018 | |
@author: memo | |
very quick & simple dictionary / json based graph builder for tensorflow | |
( inspired by https://github.com/dribnet/discgen/blob/master/discgen/vae.py#L43-L163 ) | |
""" | |
from __future__ import print_function | |
from __future__ import division | |
import tensorflow as tf | |
import numpy as np | |
from pprint import pprint | |
import msa.tf.ops | |
def example():
    """Build a small convolutional autoencoder with ``build_graph``.

    The encoder and decoder are each described as a list of op-info dicts.
    Before the decoder is built, two of its entries are patched from the
    encoder's 'pre_z_conv' tensor (the last conv output before flattening):
    the 'post_z_flat' dense layer's unit count and the 'post_z_conv'
    reshape's target shape. This lets the decoder mirror the encoder.

    Returns:
        (encoder_ops, decoder_ops): the op output lists returned by
        ``build_graph`` inside the 'encoder' and 'decoder' variable scopes.
    """
    tf.reset_default_graph()

    # dict of dicts { <opname>: kwargs, ... }: default kwargs per op type,
    # handed to build_graph as default_op_args.
    default_op_args = {
        'conv2d': {'padding': 'same', 'kernel_size': (3, 3), 'strides': (1, 1)},
        'conv2d_transpose': {'kernel_size': (2, 2), 'strides': (2, 2)},
    }

    # list of dicts [ {'op': <opname>, **kwargs}, ... ], applied in order.
    encoder_ops_info = [
        {'op': 'conv2d', 'filters': 64}, {'op': 'batch_norm'}, {'op': 'relu'},
        {'op': 'conv2d', 'kernel_size': (2, 2), 'strides': (2, 2), 'filters': 64}, {'op': 'batch_norm'}, {'op': 'relu'},
        {'op': 'conv2d', 'filters': 128}, {'op': 'batch_norm'}, {'op': 'relu'},
        {'op': 'conv2d', 'kernel_size': (2, 2), 'strides': (2, 2), 'filters': 128}, {'op': 'batch_norm'}, {'op': 'relu'},
        {'op': 'conv2d', 'filters': 256}, {'op': 'batch_norm'}, {'op': 'relu'},
        {'op': 'conv2d', 'kernel_size': (2, 2), 'strides': (2, 2), 'filters': 256}, {'op': 'batch_norm'}, {'op': 'relu'},
        {'op': 'identity', 'name': 'pre_z_conv'},
        {'op': 'flatten'},
        {'op': 'dense', 'units': 1024}, {'op': 'batch_norm'}, {'op': 'relu'},
    ]

    decoder_ops_info = [
        {'op': 'dense', 'units': 128, 'name': 'z'}, {'op': 'batch_norm'}, {'op': 'relu'},
        # 'units': 0 is a placeholder, overwritten below from the encoder's conv shape.
        {'op': 'dense', 'units': 0, 'name': 'post_z_flat'}, {'op': 'batch_norm'}, {'op': 'relu'},
        # 'shape' is filled in below before the decoder is built.
        {'op': 'tf.reshape', 'name': 'post_z_conv'}, {'op': 'batch_norm'}, {'op': 'relu'},
        {'op': 'conv2d', 'filters': 256}, {'op': 'batch_norm'}, {'op': 'relu'},
        {'op': 'conv2d_transpose', 'filters': 256}, {'op': 'batch_norm'}, {'op': 'relu'},
        {'op': 'conv2d', 'filters': 128}, {'op': 'batch_norm'}, {'op': 'relu'},
        {'op': 'conv2d_transpose', 'filters': 128}, {'op': 'batch_norm'}, {'op': 'relu'},
        {'op': 'conv2d', 'filters': 64}, {'op': 'batch_norm'}, {'op': 'relu'},
        {'op': 'conv2d_transpose', 'filters': 64}, {'op': 'batch_norm'}, {'op': 'relu'},
        {'op': 'conv2d', 'kernel_size': (1, 1), 'filters': 3}, {'op': 'tanh', 'name': 'output'},
    ]

    x = tf.placeholder(tf.float32, [None, 64, 64, 3])

    # build encoder (build_graph already prints any errors, so they are not kept here)
    with tf.variable_scope('encoder'):
        encoder_ops, _ = build_graph(x, encoder_ops_info, default_op_args)

    # TODO: this wiring is still manual; is there a better way of automating it?
    # Find the conv tensor just before flattening by name. list() keeps the
    # [0] working whether the helper returns a list or a lazy filter iterator.
    pre_z_conv = list(get_tensors_by_name(encoder_ops, 'pre_z_conv'))[0]
    # Static per-example shape (all dims except batch) as plain python values.
    conv_shape = pre_z_conv.shape[1:].as_list()

    # Flattened size is the product of all non-batch dims; dense wants a plain int.
    list(get_ops_by_name(decoder_ops_info, 'post_z_flat'))[0]['units'] = int(np.prod(conv_shape))
    # Reshape back to the conv shape. -1 lets TF infer the batch size, so the
    # decoder does not depend on tf.shape() of an encoder tensor (which would
    # tie the reshape to the encoder input x at run time).
    list(get_ops_by_name(decoder_ops_info, 'post_z_conv'))[0]['shape'] = [-1] + conv_shape

    # now build decoder on top of the encoder's last output
    with tf.variable_scope('decoder'):
        decoder_ops, _ = build_graph(encoder_ops[-1], decoder_ops_info, default_op_args)

    return encoder_ops, decoder_ops


'''
Sample console output of example().

NOTE(review): this log was captured from an earlier revision and does not
match the code above exactly. Its shapes correspond to a 256x256x3 input
(the placeholder is now 64x64x3). Its decoder 'dense' line shows
{'units': 1024} instead of the patched 'post_z_flat' layer.

Output:
--------------------------------------------------------------------------------
> msa.tf.ops.conv2d {'filters': 64} + defaults {'padding': 'same', 'strides': (1, 1), 'kernel_size': (3, 3)} --> Tensor("encoder/conv2d/BiasAdd:0", shape=(?, 256, 256, 64), dtype=float32)
> msa.tf.ops.batch_norm {} --> Tensor("encoder/batch_normalization/FusedBatchNorm:0", shape=(?, 256, 256, 64), dtype=float32)
> msa.tf.ops.relu {} --> Tensor("encoder/Relu:0", shape=(?, 256, 256, 64), dtype=float32)
> msa.tf.ops.conv2d {'strides': (2, 2), 'kernel_size': (2, 2), 'filters': 64} + defaults {'padding': 'same'} --> Tensor("encoder/conv2d_2/BiasAdd:0", shape=(?, 128, 128, 64), dtype=float32)
> msa.tf.ops.batch_norm {} --> Tensor("encoder/batch_normalization_2/FusedBatchNorm:0", shape=(?, 128, 128, 64), dtype=float32)
> msa.tf.ops.relu {} --> Tensor("encoder/Relu_1:0", shape=(?, 128, 128, 64), dtype=float32)
> msa.tf.ops.conv2d {'filters': 128} + defaults {'padding': 'same', 'strides': (1, 1), 'kernel_size': (3, 3)} --> Tensor("encoder/conv2d_3/BiasAdd:0", shape=(?, 128, 128, 128), dtype=float32)
> msa.tf.ops.batch_norm {} --> Tensor("encoder/batch_normalization_3/FusedBatchNorm:0", shape=(?, 128, 128, 128), dtype=float32)
> msa.tf.ops.relu {} --> Tensor("encoder/Relu_2:0", shape=(?, 128, 128, 128), dtype=float32)
> msa.tf.ops.conv2d {'strides': (2, 2), 'kernel_size': (2, 2), 'filters': 128} + defaults {'padding': 'same'} --> Tensor("encoder/conv2d_4/BiasAdd:0", shape=(?, 64, 64, 128), dtype=float32)
> msa.tf.ops.batch_norm {} --> Tensor("encoder/batch_normalization_4/FusedBatchNorm:0", shape=(?, 64, 64, 128), dtype=float32)
> msa.tf.ops.relu {} --> Tensor("encoder/Relu_3:0", shape=(?, 64, 64, 128), dtype=float32)
> msa.tf.ops.conv2d {'filters': 256} + defaults {'padding': 'same', 'strides': (1, 1), 'kernel_size': (3, 3)} --> Tensor("encoder/conv2d_5/BiasAdd:0", shape=(?, 64, 64, 256), dtype=float32)
> msa.tf.ops.batch_norm {} --> Tensor("encoder/batch_normalization_5/FusedBatchNorm:0", shape=(?, 64, 64, 256), dtype=float32)
> msa.tf.ops.relu {} --> Tensor("encoder/Relu_4:0", shape=(?, 64, 64, 256), dtype=float32)
> msa.tf.ops.conv2d {'strides': (2, 2), 'kernel_size': (2, 2), 'filters': 256} + defaults {'padding': 'same'} --> Tensor("encoder/conv2d_6/BiasAdd:0", shape=(?, 32, 32, 256), dtype=float32)
> msa.tf.ops.batch_norm {} --> Tensor("encoder/batch_normalization_6/FusedBatchNorm:0", shape=(?, 32, 32, 256), dtype=float32)
> msa.tf.ops.relu {} --> Tensor("encoder/Relu_5:0", shape=(?, 32, 32, 256), dtype=float32)
> msa.tf.ops.identity {'name': 'pre_z_conv'} --> Tensor("encoder/pre_z_conv:0", shape=(?, 32, 32, 256), dtype=float32)
> msa.tf.ops.flatten {} --> Tensor("encoder/Flatten/flatten/Reshape:0", shape=(?, 262144), dtype=float32)
> msa.tf.ops.dense {'units': 1024} --> Tensor("encoder/dense/BiasAdd:0", shape=(?, 1024), dtype=float32)
> msa.tf.ops.batch_norm {} --> Tensor("encoder/batch_normalization_7/batchnorm/add_1:0", shape=(?, 1024), dtype=float32)
> msa.tf.ops.relu {} --> Tensor("encoder/Relu_6:0", shape=(?, 1024), dtype=float32)
--------------------------------------------------------------------------------
23 ops added
--------------------------------------------------------------------------------
> msa.tf.ops.dense {'units': 128, 'name': 'z'} --> Tensor("decoder/z/BiasAdd:0", shape=(?, 128), dtype=float32)
> msa.tf.ops.batch_norm {} --> Tensor("decoder/batch_normalization/batchnorm/add_1:0", shape=(?, 128), dtype=float32)
> msa.tf.ops.relu {} --> Tensor("decoder/Relu:0", shape=(?, 128), dtype=float32)
> msa.tf.ops.dense {'units': 1024} --> Tensor("decoder/dense/BiasAdd:0", shape=(?, 1024), dtype=float32)
> msa.tf.ops.batch_norm {} --> Tensor("decoder/batch_normalization_2/batchnorm/add_1:0", shape=(?, 1024), dtype=float32)
> msa.tf.ops.relu {} --> Tensor("decoder/Relu_1:0", shape=(?, 1024), dtype=float32)
> tf.reshape {'shape': <tf.Tensor 'Shape:0' shape=(4,) dtype=int32>, 'name': 'post_z_conv'} --> Tensor("decoder/post_z_conv:0", shape=(?, 32, 32, 256), dtype=float32)
> msa.tf.ops.batch_norm {} --> Tensor("decoder/batch_normalization_3/FusedBatchNorm:0", shape=(?, 32, 32, 256), dtype=float32)
> msa.tf.ops.relu {} --> Tensor("decoder/Relu_2:0", shape=(?, 32, 32, 256), dtype=float32)
> msa.tf.ops.conv2d {'filters': 256} + defaults {'padding': 'same', 'strides': (1, 1), 'kernel_size': (3, 3)} --> Tensor("decoder/conv2d/BiasAdd:0", shape=(?, 32, 32, 256), dtype=float32)
> msa.tf.ops.batch_norm {} --> Tensor("decoder/batch_normalization_4/FusedBatchNorm:0", shape=(?, 32, 32, 256), dtype=float32)
> msa.tf.ops.relu {} --> Tensor("decoder/Relu_3:0", shape=(?, 32, 32, 256), dtype=float32)
> msa.tf.ops.conv2d_transpose {'filters': 256} + defaults {'strides': (2, 2), 'kernel_size': (2, 2)} --> Tensor("decoder/conv2d_transpose/BiasAdd:0", shape=(?, 64, 64, 256), dtype=float32)
> msa.tf.ops.batch_norm {} --> Tensor("decoder/batch_normalization_5/FusedBatchNorm:0", shape=(?, 64, 64, 256), dtype=float32)
> msa.tf.ops.relu {} --> Tensor("decoder/Relu_4:0", shape=(?, 64, 64, 256), dtype=float32)
> msa.tf.ops.conv2d {'filters': 128} + defaults {'padding': 'same', 'strides': (1, 1), 'kernel_size': (3, 3)} --> Tensor("decoder/conv2d_2/BiasAdd:0", shape=(?, 64, 64, 128), dtype=float32)
> msa.tf.ops.batch_norm {} --> Tensor("decoder/batch_normalization_6/FusedBatchNorm:0", shape=(?, 64, 64, 128), dtype=float32)
> msa.tf.ops.relu {} --> Tensor("decoder/Relu_5:0", shape=(?, 64, 64, 128), dtype=float32)
> msa.tf.ops.conv2d_transpose {'filters': 128} + defaults {'strides': (2, 2), 'kernel_size': (2, 2)} --> Tensor("decoder/conv2d_transpose_2/BiasAdd:0", shape=(?, 128, 128, 128), dtype=float32)
> msa.tf.ops.batch_norm {} --> Tensor("decoder/batch_normalization_7/FusedBatchNorm:0", shape=(?, 128, 128, 128), dtype=float32)
> msa.tf.ops.relu {} --> Tensor("decoder/Relu_6:0", shape=(?, 128, 128, 128), dtype=float32)
> msa.tf.ops.conv2d {'filters': 64} + defaults {'padding': 'same', 'strides': (1, 1), 'kernel_size': (3, 3)} --> Tensor("decoder/conv2d_3/BiasAdd:0", shape=(?, 128, 128, 64), dtype=float32)
> msa.tf.ops.batch_norm {} --> Tensor("decoder/batch_normalization_8/FusedBatchNorm:0", shape=(?, 128, 128, 64), dtype=float32)
> msa.tf.ops.relu {} --> Tensor("decoder/Relu_7:0", shape=(?, 128, 128, 64), dtype=float32)
> msa.tf.ops.conv2d_transpose {'filters': 64} + defaults {'strides': (2, 2), 'kernel_size': (2, 2)} --> Tensor("decoder/conv2d_transpose_3/BiasAdd:0", shape=(?, 256, 256, 64), dtype=float32)
> msa.tf.ops.batch_norm {} --> Tensor("decoder/batch_normalization_9/FusedBatchNorm:0", shape=(?, 256, 256, 64), dtype=float32)
> msa.tf.ops.relu {} --> Tensor("decoder/Relu_8:0", shape=(?, 256, 256, 64), dtype=float32)
> msa.tf.ops.conv2d {'kernel_size': (1, 1), 'filters': 3} + defaults {'padding': 'same', 'strides': (1, 1)} --> Tensor("decoder/conv2d_4/BiasAdd:0", shape=(?, 256, 256, 3), dtype=float32)
> msa.tf.ops.tanh {'name': 'output'} --> Tensor("decoder/output:0", shape=(?, 256, 256, 3), dtype=float32)
--------------------------------------------------------------------------------
29 ops added
'''
#%%
# Prefixes that build_graph tries, in this order, when resolving an op string.
# The empty prefix comes first, so an op string is first tried exactly as
# written (e.g. a fully qualified 'tf.reshape'); otherwise the first prefix
# under which '<prefix>.<op>' resolves is used.
namespaces = [
    '',                   # op string used as-is
    'msa.tf.ops',
    'tf',
    'tf.layers',
    'tf.nn',
    'tf.contrib.layers',
]
def get_tensors_by_name(tensors, name):
    """Return every tensor in ``tensors`` whose ``.name`` contains ``name``.

    Matching is a plain substring test, so 'pre_z_conv' matches
    'encoder/pre_z_conv:0'. The result is a list in input order, so callers
    can index it (``filter`` returns an unindexable iterator on Python 3).
    """
    return [tensor for tensor in tensors if name in tensor.name]
def get_ops_by_name(ops_info, name):
    """Return every op-info dict in ``ops_info`` whose 'name' entry contains ``name``.

    Matching is a substring test. Entries that are not dicts, or that have no
    'name' key, are skipped; build_graph likewise tolerates non-dict entries,
    so a stray string in the list no longer raises here. The matched dicts are
    returned themselves (not copies), as a list in input order, so callers can
    index the result and patch the entries in place.
    """
    return [
        op_info for op_info in ops_info
        if isinstance(op_info, dict) and 'name' in op_info and name in op_info['name']
    ]
def _resolve_op(op_str):
    """Find the callable named by ``op_str``, trying each prefix in ``namespaces``.

    Returns ``(fn, fn_path)`` for the first prefix under which ``op_str``
    evaluates to a callable, or ``(None, None)`` if none does.
    """
    for namespace in namespaces:
        try:
            # The join stays inside the try, so a non-string 'op' value simply
            # fails to resolve instead of raising.
            fn_path = '.'.join([namespace, op_str]) if namespace else op_str
            # SECURITY NOTE(review): eval() executes arbitrary expressions, not
            # just dotted names. Only build from op lists you trust (e.g. not
            # JSON supplied by a third party). Passing globals() explicitly keeps
            # the caller's local variables out of the lookup.
            candidate = eval(fn_path, globals())
        except Exception:
            # Typically NameError/AttributeError: not defined under this prefix.
            continue
        # Skip non-callables (e.g. a module or a module-level list) and keep searching.
        if callable(candidate):
            return candidate, fn_path
    return None, None


def _missing_defaults(op_str, args, default_op_args):
    """Return the defaults registered for ``op_str`` whose keys ``args`` does not set."""
    if not default_op_args or op_str not in default_op_args:
        return {}
    return {k: v for k, v in default_op_args[op_str].items() if k not in args}


def build_graph(input_T, ops_info, default_op_args=None, verbose=True):
    """Apply a list of op-info dicts to ``input_T``, one after another.

    Each entry of ``ops_info`` is a dict ``{'op': <opname>, **kwargs}``. The op
    name is resolved against the prefixes in ``namespaces``. The op is called
    as ``fn(t, **kwargs)``, where ``t`` is the output of the previous successful
    op (initially ``input_T``). Bad entries are reported and skipped rather than
    raised. When an op fails, ``t`` is left unchanged and the next entry is
    applied to it.

    Args:
        input_T: input to the first op (e.g. a tensor).
        ops_info: iterable of op-info dicts (any dict subclass is accepted).
        default_op_args: optional ``{opname: kwargs}``. These kwargs are added
            to an op only where its entry does not already set them.
        verbose: if True, print each op, its args, merged defaults and its
            result. The separator lines, op count and error summary are
            printed regardless.

    Returns:
        (ops, errors): ``ops`` lists the output of every op that ran, in order.
        ``errors`` holds one ``{message: entry}`` dict per skipped entry.
    """
    print('-'*80)
    errors = []

    def handle_error(msg, op_dict):
        print('\n** ERROR', msg, op_dict, '\n')
        errors.append({msg: op_dict})

    t = input_T
    ops = []
    for op_dict in ops_info:
        if not isinstance(op_dict, dict):
            handle_error('unknown entry type', op_dict)
            continue
        if 'op' not in op_dict:
            handle_error('missing op key', op_dict)
            continue

        op_str = op_dict['op']
        op_fn, fn_path = _resolve_op(op_str)
        if op_fn is None:
            handle_error('function not found', op_dict)
            continue

        # op kwargs are every key except the op name itself (a copy; op_dict is not mutated)
        args = {k: v for k, v in op_dict.items() if k != 'op'}
        if verbose:
            print('>', fn_path, args, end=' ')
        extra_args = _missing_defaults(op_str, args, default_op_args)
        if extra_args:
            if verbose:
                print('+ defaults', extra_args, end=' ')
            args.update(extra_args)

        try:
            t = op_fn(t, **args)
        except Exception as e:
            # best-effort: report and carry on with the previous output
            handle_error(fn_path + ' : ' + str(e), op_dict)
            continue
        if verbose:
            print('-->', t)
        ops.append(t)

    print('-'*80)
    print('{} ops added'.format(len(ops)))
    if errors:
        print('{} errors found:'.format(len(errors)))
        pprint(errors)
    return ops, errors
#%%
if __name__ == '__main__':
    # Script entry point: build the demo encoder/decoder graph. The returned op
    # lists stay bound at module level, presumably for interactive inspection
    # when the file is run cell-by-cell (see the #%% markers).
    encoder_ops, decoder_ops = example()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment