Created
June 22, 2018 09:34
-
-
Save Tathagatd96/f106a1478e6549c55aa0ec1fc44abf92 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class TimeSeriesData():
    """Sine-wave time series that hands out random training windows.

    The series is ``sin(x)`` sampled on ``num_points`` evenly spaced values
    between ``xmin`` and ``xmax``. ``next_batch`` draws random contiguous
    windows from that interval, formatted as (input, target) pairs where the
    target is the input shifted one time step ahead.
    """

    def __init__(self, num_points, xmin, xmax):
        """Store the interval and precompute the sampled series.

        Args:
            num_points: number of samples generated for ``x_data``.
            xmin: lower bound of the x interval.
            xmax: upper bound of the x interval.
        """
        self.xmin = xmin
        self.xmax = xmax
        self.num_points = num_points
        # Step size used when building batches. NOTE: this is span/num_points,
        # which is slightly smaller than the spacing of ``x_data`` below
        # (np.linspace includes both endpoints, i.e. span/(num_points - 1)).
        self.resolution = (xmax - xmin) / num_points
        self.x_data = np.linspace(xmin, xmax, num_points)
        self.y_true = np.sin(self.x_data)

    def ret_true(self, x_series):
        """Return the true series value ``sin(x)`` for the given x values."""
        return np.sin(x_series)

    def next_batch(self, batch_size, steps, return_batch_ts=False):
        """Return a random batch of windows shaped for an RNN.

        Args:
            batch_size: number of windows in the batch.
            steps: number of time steps in each input/target sequence.
            return_batch_ts: when true, also return the x values of each window.

        Returns:
            ``(inputs, targets)`` -- two arrays of shape
            ``(batch_size, steps, 1)``; ``targets`` is ``inputs`` shifted one
            time step ahead. When ``return_batch_ts`` is true, a third array
            of shape ``(batch_size, steps + 1)`` holding the x values is
            appended.
        """
        # Random starting point for each window, uniform in [0, 1).
        random_start = np.random.rand(batch_size, 1)

        # Put the random point on the time series: scale it so that the whole
        # window (steps * resolution wide) still fits before xmax, and offset
        # it by xmin so the window lies inside [xmin, xmax] even when the
        # interval does not start at 0.
        window_width = steps * self.resolution
        ts_start = self.xmin + random_start * (self.xmax - self.xmin - window_width)

        # steps + 1 points because we are predicting just one time step ahead:
        # the first `steps` points are the input, the last `steps` the target.
        batch_ts = ts_start + np.arange(0.0, steps + 1) * self.resolution
        y_batch = np.sin(batch_ts)

        # Formatting for RNN: (batch, time steps, features).
        inputs = y_batch[:, :-1].reshape(-1, steps, 1)
        targets = y_batch[:, 1:].reshape(-1, steps, 1)  # shifted over one time step

        if return_batch_ts:
            return inputs, targets, batch_ts
        return inputs, targets
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment