This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def create_timestamps_ds(series, | |
timestep_size=window_size): | |
time_stamps = [] | |
labels = [] | |
aux_deque = deque(maxlen=timestep_size) | |
#starting the timestep deque | |
for i in range(timestep_size): | |
aux_deque.append(0) | |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
@variational_estimator | |
class NN(nn.Module): | |
def __init__(self): | |
super(NN, self).__init__() | |
self.lstm_1 = BayesianLSTM(1, 10) | |
self.linear = nn.Linear(10, 1) | |
def forward(self, x): | |
x_, _ = self.lstm_1(x) | |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Turn the close-price series into windowed samples/labels with the
# create_timestamps_ds helper, then split them in their original order.
# With shuffle=False the first 75% of samples become the training set and the
# last 25% the test set; random_state has no effect when nothing is shuffled.
Xs, ys = create_timestamps_ds(close_prices)
(X_train, X_test,
 y_train, y_test) = train_test_split(
    Xs, ys, test_size=0.25, random_state=42, shuffle=False,
)
# Only the training portion is wrapped as a TensorDataset
# (presumably what dataloader_train iterates over — confirm).
ds = torch.utils.data.TensorDataset(X_train, y_train)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
iteration = 0 | |
for epoch in range(10): | |
for i, (datapoints, labels) in enumerate(dataloader_train): | |
optimizer.zero_grad() | |
loss = net.sample_elbo(inputs=datapoints, | |
labels=labels, | |
criterion=criterion, | |
sample_nbr=3) | |
loss.backward() |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Build the frame of actual (unscaled) closing prices used for plotting.
# The [1:] then [window_size:] slicing drops leading points, presumably so the
# rows line up with the model's targets — TODO confirm.
# NOTE(review): assumes close_prices_unscaled is a pandas Series/DataFrame
# with a "Close" column (df_pred.Close is plotted later) that keeps df's
# index. The Date assignment below relies on that index alignment.
original = close_prices_unscaled[1:][window_size:]
df_pred = pd.DataFrame(original)
df_pred["Date"] = df.Date
# Parse df_pred's own Date column. Writing the result into df["Date"] instead
# would leave df_pred's dates unparsed. It would also overwrite df's dates,
# and rows missing from df_pred's index would become NaT.
df_pred["Date"] = pd.to_datetime(df_pred["Date"])
df_pred = df_pred.reset_index()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def pred_stock_future(X_test, | |
future_length, | |
sample_nbr=10): | |
#sorry for that, window_size is a global variable, and so are X_train and Xs | |
global window_size | |
global X_train | |
global Xs | |
global scaler | |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def get_confidence_intervals(preds_test, ci_multiplier): | |
global scaler | |
preds_test = torch.tensor(preds_test) | |
pred_mean = preds_test.mean(1) | |
pred_std = preds_test.std(1).detach().cpu().numpy() | |
pred_std = torch.tensor((pred_std)) | |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Forecast settings. future_length and sample_nbr are passed to
# pred_stock_future; ci_multiplier sets the band width in
# get_confidence_intervals, which derives a per-point std from the samples.
future_length, sample_nbr, ci_multiplier = 7, 4, 10
idx_pred, preds_test = pred_stock_future(
    X_test, future_length=future_length, sample_nbr=sample_nbr
)
(pred_mean_unscaled,
 upper_bound_unscaled,
 lower_bound_unscaled) = get_confidence_intervals(
    preds_test, ci_multiplier=ci_multiplier
)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Report the fraction of actual closing prices that fall strictly inside the
# predicted confidence band.
# NOTE(review): 750 is hard-coded, presumably the number of predicted points.
# Confirm the bounds have shape (750, 1) like `y`: a (750,) array would
# broadcast into a 750x750 comparison.
y = np.array(df.Close[-750:]).reshape(-1, 1)
under_upper = upper_bound_unscaled > y
over_lower = lower_bound_unscaled < y
# A point is inside the band only when BOTH conditions hold. Comparing the
# masks with `==` would also count points where both are False
# (y >= upper and y <= lower, i.e. a degenerate band).
total = under_upper & over_lower
print("{} our predictions are in our confidence interval".format(np.mean(total)))
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Draw ticks, axis labels and axes edges in white, and title the figure.
params = dict.fromkeys(
    ("ytick.color", "xtick.color", "axes.labelcolor", "axes.edgecolor"),
    "w",
)
plt.rcParams.update(params)
plt.title("IBM Stock prices", color="white")
plt.plot(df_pred.index, | |
df_pred.Close, |