Created
April 14, 2020 21:40
-
-
Save piEsposito/ff64a84fbe5641ebe0746bba228629c6 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def pred_stock_future(X_test,
                      future_length,
                      sample_nbr=10,
                      *,
                      net=None,
                      window_size=None,
                      X_train=None,
                      Xs=None):
    """Forecast each row of X_test recursively, re-anchoring on real data periodically.

    For each row ``i`` of ``X_test`` the model is called ``sample_nbr`` times on the
    current input window. The mean of those samples is appended to the window, so the
    next step is predicted from the model's own output. Every ``future_length`` steps
    (starting at step 0) the window is replaced by the true window ``X_test[i]``
    *before* step ``i`` is predicted. Each recursive horizon is therefore at most
    ``future_length`` steps long.

    Args:
        X_test: Indexed as ``X_test[i, t, 0]``; only feature 0 of each window is used.
            Presumably shaped (n_samples, window_length, n_features) — confirm
            against the caller.
        future_length: Number of consecutive steps predicted before re-anchoring on
            the true window. Must be >= 1.
        sample_nbr: Number of forward passes per step (useful for stochastic models).
        net: Callable applied to a ``(1, len(window), 1)`` float tensor and returning
            a single-element tensor. Defaults to the module-level ``net``.
        window_size: Maximum window length (deque ``maxlen``). Defaults to the
            module-level ``window_size``.
        X_train, Xs: Only their lengths are used, to build ``idx_pred``. Default to
            the module-level globals of the same names.

    Returns:
        ``(idx_pred, preds_test)``:
            - ``idx_pred`` is ``np.arange(len(X_train), len(Xs))``.
            - ``preds_test`` holds one entry per row of ``X_test``. Each entry is the
              list of ``sample_nbr`` raw float samples for that step.

    Raises:
        ValueError: If ``future_length`` < 1.
    """
    # Backward compatibility: callers that rely on module-level state keep working.
    module_globals = globals()
    if net is None:
        net = module_globals["net"]
    if window_size is None:
        window_size = module_globals["window_size"]
    if X_train is None:
        X_train = module_globals["X_train"]
    if Xs is None:
        Xs = module_globals["Xs"]

    if future_length < 1:
        raise ValueError(f"future_length must be >= 1, got {future_length}")

    idx_pred = np.arange(len(X_train), len(Xs))
    preds_test = []
    window = None
    # Pure inference: no autograd graph is needed for the sampled forward passes.
    with torch.no_grad():
        for i in range(len(X_test)):
            if i % future_length == 0:
                # Re-anchor on the true input window for step i *before* predicting it.
                # Resetting after the prediction would make step i+1 reuse step i's
                # window and shift every forecast one step late.
                window = deque(X_test[i, :, 0].tolist(), maxlen=window_size)
            net_input = torch.tensor(list(window)).unsqueeze(0).unsqueeze(2)
            samples = [net(net_input).cpu().item() for _ in range(sample_nbr)]
            # Feed the mean prediction back in, so later steps build on it.
            window.append(float(np.mean(samples)))
            preds_test.append(samples)
    return idx_pred, preds_test
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment