This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Plot one test instance: the inputs fed to the model, the one-step-ahead
# targets, and the model's predictions for those targets.
plt.title("TESTING THE MODEL")
# TRAINING INSTANCE: sin() at every time stamp except the last (the model input).
plt.plot(train_inst[:-1], np.sin(train_inst[:-1]), "bo", markersize=15, alpha=0.5,
         label="TRAINING INST")
# TARGET TO PREDICT: the same series shifted one time stamp ahead.
plt.plot(train_inst[1:], np.sin(train_inst[1:]), "ko", markersize=8, label="TARGET")
# MODEL PREDICTION: y_pred[0, :, 0] takes the first sequence and first output feature.
# NOTE(review): assumes y_pred is laid out like the `y` placeholder
# (batch, num_time_steps, num_outputs) — confirm where y_pred is produced.
plt.plot(train_inst[1:], y_pred[0, :, 0], "r.", markersize=7, label="PREDICTIONS")
# The label= arguments above are only rendered once a legend is drawn.
plt.legend()
plt.show()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Training loop (TF1 graph mode): initialise all variables, then run `num_iter`
# optimisation steps, each on a fresh batch drawn from `ts_data`.
# NOTE(review): this snippet is truncated — the body of the final `if`
# (the every-100-iterations report) is not present here; confirm against the full source.
with tf.Session() as sess: | |
# Run the global-variables initializer before any training step.
sess.run(init) | |
# NOTE(review): `iter` shadows the built-in iter(); consider renaming (e.g. `step`).
for iter in range(num_iter): | |
# Draw `batch_size` sequences of `num_time_steps` steps; the first array feeds `x`,
# the second feeds `y`.
x_batch , y_batch = ts_data.next_batch(batch_size,num_time_steps) | |
# One optimizer step: `train` is the Adam op that minimises the MSE `loss`.
sess.run(train,feed_dict={x:x_batch,y:y_batch}) | |
# Every 100 iterations (the body of this branch is not included in this snippet).
if iter %100==0: | |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Test instance: num_time_steps + 1 evenly spaced time stamps starting at t = 5.
# n + 1 points span n intervals, so the stop is start + resolution * n; this makes
# consecutive points exactly `ts_data.resolution` apart. The previous stop
# (resolution * (n + 1)) stretched the step to resolution * (n + 1) / n.
train_inst = np.linspace(
    5,
    5 + ts_data.resolution * num_time_steps,
    num_time_steps + 1,
)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Mean squared error between the RNN outputs and the targets fed to `y`.
loss = tf.reduce_mean(tf.square(outputs - y))
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
train = optimizer.minimize(loss)
init = tf.global_variables_initializer()
# Model input for the test instance: sin() of every time stamp except the last,
# reshaped to (-1, num_time_steps, num_inputs) to match the `x` placeholder.
# reshape() and np.sin() already yield ndarrays, so no extra np.array() copy is needed.
x_new = np.sin(train_inst[:-1].reshape(-1, num_time_steps, num_inputs))
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Any RNN cell type can be swapped in here. This one is a GRU with ReLU activation,
# wrapped so that its num_neurons-wide output is projected to num_outputs features.
cell = tf.contrib.rnn.OutputProjectionWrapper(
    tf.contrib.rnn.GRUCell(num_units=num_neurons, activation=tf.nn.relu),
    output_size=num_outputs,
)
# Run the wrapped cell over the sequences fed to `x`; `outputs` feeds the MSE loss.
outputs, states = tf.nn.dynamic_rnn(cell, x, dtype=tf.float32)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Model / training hyperparameters.
num_inputs = 1        # features per time step fed to the RNN
num_neurons = 100     # units in the recurrent cell
num_outputs = 1       # features per time step after the output projection
learning_rate = 0.001
num_iter = 5000       # number of training iterations
batch_size = 1

# Graph inputs. The leading dimension is None, so any batch size can be fed.
# num_time_steps is not set in this snippet; it comes from the batch-sampling
# snippet (num_time_steps = 30) and must be defined before this runs.
x = tf.placeholder(tf.float32, [None, num_time_steps, num_inputs])
y = tf.placeholder(tf.float32, [None, num_time_steps, num_outputs])
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Overlay one sampled batch on the full sine curve.
plt.plot(ts_data.x_data, ts_data.y_true)
# y1 is the first array returned by next_batch, the slot the training loop feeds
# to `x`. Model inputs sit at every time stamp except the last (cf. train_inst[:-1]),
# so y1 pairs with ts[:-1]; pairing it with ts[1:] drew each point one step late.
# NOTE(review): assumes ts holds num_time_steps + 1 stamps — confirm against next_batch.
plt.plot(ts.flatten()[:-1], y1.flatten(), "g*")
plt.show()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Sample a single sequence of 30 steps. The third argument presumably asks
# next_batch to also return the time stamps (ts) — confirm against its definition.
num_time_steps = 30
y1, y2, ts = ts_data.next_batch(1, num_time_steps, True)
print(ts.flatten().shape)
# y1 occupies the model-input slot (x_batch in the training loop), i.e. the values
# at every time stamp except the last, so it pairs with ts[:-1] rather than ts[1:].
# NOTE(review): assumes ts holds num_time_steps + 1 stamps — confirm against next_batch.
plt.plot(ts.flatten()[:-1], y1.flatten(), "*")
plt.show()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Build the synthetic series (250 samples of sin(x) for x from 0 to 10) and plot it.
ts_data = TimeSeriesData(num_points=250, xmin=0, xmax=10)
plt.plot(ts_data.x_data, ts_data.y_true)
plt.show()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class TimeSeriesData(): | |
# Store the sampling range and build a sampled sine series over it.
#   num_points: number of samples (passed to np.linspace as `num`).
#   xmin, xmax: start and end of the x range (np.linspace includes both).
def __init__(self,num_points,xmin,xmax): | |
self.xmin=xmin | |
self.xmax=xmax | |
self.num_points=num_points | |
# NOTE(review): this is (xmax - xmin) / num_points, but the spacing of x_data
# below is (xmax - xmin) / (num_points - 1) because np.linspace includes both
# endpoints; the two differ slightly — confirm which step callers expect.
self.resolution=(xmax-xmin)/num_points | |
# num_points evenly spaced x values from xmin to xmax, endpoints included.
self.x_data=np.linspace(xmin,xmax,num_points) | |
# sin evaluated at every x in x_data.
self.y_true=np.sin(self.x_data) | |
def ret_true(self,x_series): |