Created
September 29, 2015 23:38
-
-
Save JamesPHoughton/c096773921c82d73cbe5 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class MinModel(object):
    """Minimal hand-written pysd-style model with a two-dimensional subscripted stock.

    The first section is generic pysd boilerplate: stock discovery, state
    initialisation, and conversion between the dict-based ``self.state`` and
    the flat state vector an ODE integrator works with.  The last section holds
    the model-specific components (subscript ranges, the ``stock`` stock, its
    ``flow`` and the ``constant`` table).
    """

    ########## boilerplate stuff from the existing pysd #########

    def __init__(self):
        # A stock named X is any attribute with a matching ``X_init`` method; its
        # derivative must be provided by a ``dX_dt`` method.  Sorting fixes the
        # order in which stocks appear in the flat state vector.
        self._stocknames = sorted(name[:-len('_init')] for name in dir(self)
                                  if name.endswith('_init'))
        self._dfuncs = [getattr(self, 'd%s_dt' % name) for name in self._stocknames]
        self.state = dict.fromkeys(self._stocknames)  # every stock starts as None
        self.reset_state()
        # NOTE(review): ``functions`` is not imported anywhere visible in this file;
        # presumably pysd's functions module must be in scope — confirm.
        self.functions = functions.Functions(self)

    def reset_state(self):
        """Sets the model state to the state described in the model file.

        Sets ``self.t`` to ``initial_time()`` and every stock to the value returned
        by its ``<name>_init()`` method.  All stocks are first cleared to ``None``
        so a reset never reads values left over from a previous run.  An
        initialiser that raises ``TypeError`` is assumed to depend on a stock that
        is still ``None``; all initialisers are then re-run in another pass, until
        a pass completes with no failures.

        Raises:
            TypeError: if a pass fails for at least as many stocks as the previous
                pass (no progress). This happens, for example, with a circular
                initial-value dependency or a genuine error in an initialiser.
        """
        self.t = self.initial_time()  # set the initial time
        for key in self.state:
            self.state[key] = None
        previous_failure_count = None
        while True:
            failures = []
            for key in self.state:
                try:
                    self.state[key] = getattr(self, key + '_init')()  # set the initial state
                except TypeError as err:
                    failures.append((key, err))
            if not failures:
                return
            # A successful initialiser keeps succeeding once its inputs are set, so
            # the failure count can only shrink; if it does not, retrying is futile.
            if previous_failure_count is not None and len(failures) >= previous_failure_count:
                key, err = failures[-1]
                raise TypeError('could not initialise stock(s) %s; last error (%s): %s'
                                % ([name for name, _ in failures], key, err))
            previous_failure_count = len(failures)

    @staticmethod
    def _flatten(value):
        """Returns ``value`` as a flat list.

        Numpy arrays are flattened in row-major order, lists/tuples are copied,
        and anything else (a scalar) is wrapped in a one-element list.
        """
        if isinstance(value, np.ndarray):
            return list(value.flatten())
        if isinstance(value, (list, tuple)):
            return list(value)
        return [value]

    ########### Stuff we have to modify to make subscripts work #########

    def d_dt(self, state_vector, t):
        """The primary purpose of this function is to interact with the integrator.

        It takes a state vector, sets the state of the system based on that vector,
        and returns the derivative of the state vector as a flat list.

        Args:
            state_vector: flat sequence of stock values in ``self._stocknames``
                order (the layout produced by ``get_state``).
            t: the time at which the derivative is evaluated; stored in ``self.t``.
        """
        self.set_state(state_vector)
        self.t = t
        derivative_vector = []
        for func in self._dfuncs:
            # Derivatives may be scalars, lists or arrays; flatten them all so the
            # result lines up element-for-element with the state vector.
            derivative_vector.extend(self._flatten(func()))
        return derivative_vector

    def set_state(self, state_vector):
        """Unpacks a flat state vector into ``self.state``.

        Stocks are consumed in ``self._stocknames`` order.  A stock whose current
        value is a numpy array takes as many consecutive elements as the array
        holds and is reshaped to that array's shape; any other stock takes a
        single element.
        """
        i = 0
        for key in self._stocknames:
            current = self.state[key]
            if isinstance(current, np.ndarray):
                size = current.size
                self.state[key] = np.array(state_vector[i:i + size]).reshape(current.shape)
                i += size
            else:
                self.state[key] = state_vector[i]
                i += 1

    def get_state(self):
        """Returns the model state as a flat list, the inverse of ``set_state``.

        Array stocks contribute their elements in row-major order; scalar stocks
        contribute one element.
        """
        state_vector = []
        for key in self._stocknames:
            state_vector.extend(self._flatten(self.state[key]))
        return state_vector

    ######### model specific components (that go in the model file) #########

    suba_list = ['suba1', 'suba2', 'suba3']
    subb_list = ['subb1', 'subb2']
    # element name -> position along its subscript range
    suba_index = {name: i for i, name in enumerate(suba_list)}
    subb_index = {name: i for i, name in enumerate(subb_list)}

    def stock(self, suba, subb):
        """Returns the element of the ``stock`` stock at subscripts (suba, subb)."""
        return self.state['stock'][self.suba_index[suba]][self.subb_index[subb]]

    def stock_init(self):
        """Initial value of ``stock``: a 3x2 array of ones (suba rows, subb columns)."""
        return np.array([[1, 1], [1, 1], [1, 1]])

    def dstock_dt(self):
        """Derivative of ``stock`` as a flat list of flows.

        ``itertools.product`` varies ``subb`` fastest, which matches the row-major
        order that ``set_state``/``get_state`` use for the 2-D ``stock`` array.
        """
        return [self.flow(suba, subb)
                for suba, subb in itertools.product(self.suba_list, self.subb_list)]

    def constant(self, suba, subb):
        """Returns the element of the ``constant`` table at subscripts (suba, subb)."""
        # The table is stored as an attribute of the function object (assigned
        # just below); the bound method forwards the ``values`` lookup to it.
        return self.constant.values[self.suba_index[suba]][self.subb_index[subb]]

    constant.values = np.array([[1, 2], [3, 4], [5, 6]])

    def flow(self, suba, subb):
        """Flow into ``stock`` at (suba, subb): ``constant * stock``."""
        return self.constant(suba, subb) * self.stock(suba, subb)

    def initial_time(self):
        """Returns the simulation start time."""
        return 0
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment