This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Create training, validation, and test sets.
# The three percentages look like cumulative split boundaries
# (training = first 70%, validation up to 90%, test up to 100%);
# NOTE(review): only the training slice is visible here — confirm how
# valid_percentage / test_percentage are used for the other splits.
train_percentage = 0.7
valid_percentage = 0.9
test_percentage = 1.0
len_X = len(gaussians_initial)

# Training: the leading train_percentage fraction of every parallel array.
# The bound is computed once so all four arrays are sliced identically.
train_end = int(train_percentage * len_X)
X_train = gaussians_initial[:train_end]
Y_train = gaussians_final[:train_end]
A_train = powerlaw_conv[:train_end]
# np.diag turns each noise vector into a diagonal matrix
# (assumes each noise entry is 1-D — TODO confirm; for a 2-D entry np.diag
# would instead extract its diagonal).
N_train = [np.diag(noise_val) for noise_val in noise[:train_end]]
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Load model and define hyper parameters.
epochs = 100
batch_size = 16
model = RIM(
    rnn_units1=256,
    rnn_units2=256,
    conv_filters=8,
    kernel_size=2,
    input_size=n,
    dimensions=1,
    t_steps=10,
    learning_rate=0.005,
)

# Prepare the training dataset: one element per sample built from the four
# parallel training arrays. Elements are grouped into batches of batch_size
# (a trailing partial batch is dropped) and two batches are prefetched.
train_dataset = (
    tf.data.Dataset.from_tensor_slices((X_train, Y_train, A_train, N_train))
    .batch(batch_size, drop_remainder=True)
    .prefetch(2)
)
# Prepare the validation dataset
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Train the model and unpack what fit returns: validation predictions,
# training losses, validation losses and learning rates (the plots below
# treat the losses as one value per epoch).
fit_results = model.fit(batch_size, epochs, train_dataset, val_dataset)
ysol_valid, training_loss, valid_loss, learning_rates = fit_results
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Plot training/validation loss and the learning-rate schedule per epoch.
# X positions are integer epoch numbers 1..len(series). The previous
# np.linspace(0, epochs, epochs) spread the points with a spacing of
# epochs / (epochs - 1), so they did not sit on epoch boundaries.
# Deriving each axis from its own series also keeps the plot valid if fewer
# than `epochs` values were recorded.
plt.plot(np.arange(1, len(training_loss) + 1), training_loss, label='training')
plt.plot(np.arange(1, len(valid_loss) + 1), valid_loss, label='validation')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.show()

# learning_rates carries one extra leading entry (presumably the initial rate
# before the first epoch — TODO confirm against RIM.fit), which is dropped.
lr_per_epoch = learning_rates[1:]
plt.plot(np.arange(1, len(lr_per_epoch) + 1), lr_per_epoch, label='learning rate')
plt.xlabel('Epoch')
plt.ylabel('Learning rate')
plt.legend()
plt.show()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Build the test dataset from the noisy observations (Y_test), A_test and
# N_test. X_test is not given to the model. The dataset is batched like the
# training data, so a trailing partial batch is dropped.
test_dataset = (
    tf.data.Dataset.from_tensor_slices((Y_test, A_test, N_test))
    .batch(batch_size, drop_remainder=True)
)
ysol = model(test_dataset)

# Obtain better format: convert every tensor inside every element of ysol to
# a NumPy array, giving a nested Python list ysol_list[i][j].
ysol_list = [[tensor.numpy() for tensor in entry] for entry in ysol]
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Compare the noisy input, the ground truth and the model prediction for the
# last test sample that the model actually predicted.
#
# test_dataset is batched with drop_remainder=True. When len(Y_test) is not a
# multiple of batch_size, the trailing samples are never fed to the model, so
# Y_test[-1] / X_test[-1] would be plotted against a prediction for a
# different sample. Use the last sample of the last full batch instead; this
# equals index -1 whenever there is no remainder.
# NOTE(review): assumes ysol_list is ordered [batch][step][sample] so that
# ysol_list[-1][-1][-1] is the last sample of the last batch — confirm
# against RIM's output structure.
last_predicted = (len(Y_test) // batch_size) * batch_size - 1

fig = plt.figure(figsize=(16, 8))
x_axis = np.linspace(-1, 1, n)
plt.plot(x_axis, Y_test[last_predicted], label='Noisy', color='C1')
plt.plot(x_axis, X_test[last_predicted], label='True', color='C2', linewidth=4)
plt.plot(
    x_axis,
    ysol_list[-1][-1][-1].reshape(n),
    label='Predicted',
    linestyle='dashed',
    color='C3',
    linewidth=3,
)
plt.legend(prop={'size': 20})
plt.ylabel('Normalized y-axis', fontsize=20)
plt.xlabel('X-axis', fontsize=20)
plt.xticks(fontsize=15)
plt.yticks(fontsize=15)
plt.title('RIM Example using a Noisy Gaussian')
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import time | |
import torch | |
import numpy as np | |
from astropy.wcs import WCS | |
from astropy.wcs.wcsapi import BaseHighLevelWCS | |
def pixel_to_pixel(wcs_in: BaseHighLevelWCS, wcs_out: BaseHighLevelWCS, *inputs): | |
""" |
Older Newer