Skip to content

Instantly share code, notes, and snippets.

@standbyme
Created July 30, 2019 05:25
Show Gist options
  • Save standbyme/8c65a7cabe592e36e0a50abeea5554d2 to your computer and use it in GitHub Desktop.
Save standbyme/8c65a7cabe592e36e0a50abeea5554d2 to your computer and use it in GitHub Desktop.
SimpleRNN
def tf__helper(self, inputs):
    """Run the recurrent loop over ``inputs`` and return the final output.

    The body is written against the ``ag__`` operator helpers
    (``converted_call``, ``for_stmt``, ``if_stmt``), so every TensorFlow
    call and every control-flow statement is routed through them.

    The names ``hidden_size``, ``output_size``, ``w_hh``, ``w_ih``, ``b``
    and ``dense`` are not defined here; they are resolved from the
    enclosing scope. ``self`` is accepted but never read.

    Args:
        inputs: Value that is transposed with permutation ``[1, 0, 2]``
            before being iterated over.
            NOTE(review): presumably (batch, time, features), making the
            loop run over time steps -- confirm against the caller.

    Returns:
        The result of ``ag__.if_stmt``: the ``output`` produced by the
        loop, or ``None`` on the branch taken when the return value is
        still undefined.
    """
    # Every converted call in the original used an identical options
    # expression; build it once and share it.
    call_opts = ag__.ConversionOptions(
        recursive=True,
        force_conversion=False,
        optional_features=(),
        internal_convert_user_code=True,
    )

    do_return = False
    retval_ = ag__.UndefinedReturnValue()

    # Swap the first two axes, then take axis 1 of the result as the
    # batch size.
    inputs = ag__.converted_call(
        'transpose', tf, call_opts, (inputs, [1, 0, 2]), None)
    transposed_shape = ag__.converted_call(
        'shape', tf, call_opts, (inputs,), None)
    batch_size = transposed_shape[1]

    # Zero-filled initial loop state.
    hidden = ag__.converted_call(
        'zeros', tf, call_opts,
        ([batch_size, hidden_size],), {'name': 'hidden'})
    output = ag__.converted_call(
        'zeros', tf, call_opts,
        ([batch_size, output_size],), {'name': 'output'})

    def step(x_t, prev_hidden, prev_output):
        """Loop body: compute the next (hidden, output) pair from ``x_t``.

        ``prev_output`` is part of the loop state but is not read; the
        output is recomputed from the new hidden state on each iteration.
        """
        recurrent_term = ag__.converted_call(
            'matmul', tf, call_opts, (prev_hidden, w_hh), None)
        input_term = ag__.converted_call(
            'matmul', tf, call_opts, (x_t, w_ih), None)
        next_hidden = ag__.converted_call(
            'tanh', tf.math, call_opts,
            (recurrent_term + input_term + b,), None)
        next_output = ag__.converted_call(
            dense, None, call_opts, (next_hidden,), None)
        return next_hidden, next_output

    hidden, output = ag__.for_stmt(inputs, None, step, (hidden, output))

    do_return = True
    retval_ = output

    # Return-value normalisation: an undefined return becomes None,
    # anything else is passed through unchanged.
    cond = ag__.is_undefined_return(retval_)

    def read_state():
        # The conditional carries no extra state.
        return ()

    def write_state(_):
        pass

    def when_undefined():
        return None

    def when_defined():
        # Reads the enclosing function's ``retval_`` via closure.
        return retval_

    retval_ = ag__.if_stmt(
        cond, when_undefined, when_defined, read_state, write_state)
    return retval_
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment