This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Configuration library for experiments."""
from typing import Dict, Any
import logging
import pprint
import sys
import argparse

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Shared command-line parser. Its help description is the module docstring
# above; any argument beginning with "@" is treated by argparse as a file
# whose contents are expanded into further command-line arguments.
parser = argparse.ArgumentParser(description=__doc__, fromfile_prefix_chars="@")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import json
import socket

from keras.callbacks import ModelCheckpoint


class StatefulCheckpoint(ModelCheckpoint):
    """Save extra checkpoint data to resume training."""

    def __init__(self, weight_file, state_file=None, **kwargs):
        """Save the state (epoch etc.) alongside weights.

        Args:
            weight_file: Forwarded as the first positional argument of
                ``ModelCheckpoint.__init__`` (its ``filepath``).
            state_file: Optional location for the extra resume state. It is
                only stored here as ``self.state_f``; how it is written is
                not shown in this excerpt.
            **kwargs: Forwarded unchanged to ``ModelCheckpoint.__init__``.
        """
        super().__init__(weight_file, **kwargs)
        # NOTE(review): presumably consumed by a save hook defined elsewhere
        # in this class; confirm against the full source.
        self.state_f = state_file
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""ZeroGRU module.""" | |
import keras.backend as K | |
import keras.layers as L | |
class ZeroGRUCell(L.GRUCell): | |
"""GRU Cell that skips timestep if inputs is zero as well.""" | |
def call(self, inputs, states, training=None): | |
"""Step function of the cell.""" | |
h_tm1 = states[0] # previous output | |
# Check if all inputs are zero for this timestep |