Created September 6, 2016 14:37
-
-
Save nasimrahaman/fbf056757c454648a4d1573674dd3b85 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import yaml | |
import numpy as np | |
import os | |
from theano import config | |
class relay(object):
    """
    Hot-reload theano shared variables from a YAML file.

    Given the path to a YAML file (`ymlfile`) and a dictionary `switches` having the format
    {'name': theano-shared-variable, ...}, this class' read method sets the value of each theano
    shared variable to the one found in the corresponding field of the YAML file. Calling an
    instance (e.g. as a per-iteration callback) triggers a read every `callevery` calls.
    """

    def __init__(self, switches, ymlfile, callevery=1):
        """
        :type switches: dict
        :param switches: Should be {'name': sharedvar, ...}, where sharedvar is a theano shared variable and
                         'name' is the corresponding access key in the YAML file.

        :type ymlfile: str
        :param ymlfile: Path to YAML file where the parameters are stored.

        :type callevery: int
        :param callevery: Calling the instance reads the YAML file only on every `callevery`-th call
                          (the first call always reads). Values below 1 are treated as 1.
        """
        # Meta
        self.switches = switches
        self.ymlfile = ymlfile
        # Clamp rather than raise: this argument used to be ignored entirely, so existing callers
        # passing e.g. 0 must keep working (they get the old read-on-every-call behaviour).
        self.callevery = max(int(callevery), 1)
        # mtime of the YAML file as of the last successful read; 0. forces the first read.
        self.lastmodified = 0.
        # Number of times the instance has been called (drives `callevery`).
        self._numcalls = 0

    def read(self):
        """
        Reload the YAML file if its modification time changed since the last successful read, and
        push the values of matching keys into the corresponding shared variables.

        String values starting with 'np.' are evaluated as Python expressions (e.g. 'np.float32(0.1)');
        every other value is cast with numpy to `theano.config.floatX`. Switches whose name is not a
        key of the file are left untouched; an empty file updates nothing.

        :raises TypeError: if the YAML document is neither empty nor a mapping.
        """
        # Stat once so the timestamp we compare and the one we store are guaranteed identical.
        modified = os.stat(self.ymlfile).st_mtime
        if modified == self.lastmodified:
            return

        # Read from file. safe_load: a bare `yaml.load(f)` (no Loader) can construct arbitrary
        # Python objects on old PyYAML and is a TypeError on PyYAML >= 6.
        with open(self.ymlfile, 'r') as f:
            update = yaml.safe_load(f)

        # An empty YAML document parses to None; treat it as "no updates".
        if update is None:
            update = {}
        if not isinstance(update, dict):
            raise TypeError("Expected a mapping at the top level of {}, got {}."
                            .format(self.ymlfile, type(update).__name__))

        # Only record the timestamp once the file parsed, so a failed read is retried next time.
        self.lastmodified = modified

        # Update switches
        for switchname, switchvar in self.switches.items():
            if switchname not in update:
                continue
            value = update[switchname]
            if isinstance(value, str) and value.startswith('np.'):
                # SECURITY: eval executes arbitrary code read from the YAML file; only point this
                # class at files you trust.
                switchvarval = eval(value)
            else:
                switchvarval = getattr(np, config.floatX)(value)
            # Set switch variable
            switchvar.set_value(switchvarval)

    def __call__(self, *args, **kwargs):
        """
        Callback entry point. All arguments are ignored. Reads the YAML file on every
        `callevery`-th call, starting with the first.
        """
        # Count the call before reading so an exception inside read() does not stall the counter.
        callindex = self._numcalls
        self._numcalls += 1
        if callindex % self.callevery == 0:
            self.read()
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.