Skip to content

Instantly share code, notes, and snippets.

View crhea93's full-sized avatar
💭
Astrophysics!

Carter Lee Rhea crhea93

💭
Astrophysics!
  • Montreal, QC, CANADA
View GitHub Profile
@crhea93
crhea93 / gaussians2.py
Created October 12, 2022 19:18
RIM Gaussians 2
# Create training, validation, and test sets.
# NOTE(review): the three fractions look like cumulative split boundaries
# (train up to 70%, validation up to 90%, test up to 100%); only the training
# slice is built in this snippet — confirm against the rest of the script.
train_percentage = 0.7
valid_percentage = 0.9
test_percentage = 1.0
len_X = len(gaussians_initial)

# Training: the first int(train_percentage * len_X) entries of every parallel array.
train_end = int(train_percentage * len_X)
X_train = gaussians_initial[:train_end]
Y_train = gaussians_final[:train_end]
A_train = powerlaw_conv[:train_end]
# np.diag applied to each per-sample noise entry (presumably a 1-D vector,
# which np.diag turns into a diagonal matrix — confirm the shape of `noise`).
N_train = list(map(np.diag, noise[:train_end]))
@crhea93
crhea93 / guassians3.py
Last active October 17, 2022 13:44
RIM Gaussians 3
# Load the model and define hyperparameters.
epochs = 100
batch_size = 16
model = RIM(
    rnn_units1=256,
    rnn_units2=256,
    conv_filters=8,
    kernel_size=2,
    input_size=n,
    dimensions=1,
    t_steps=10,
    learning_rate=0.005,
)

# Prepare the training dataset: slice the four parallel arrays into aligned
# (X, Y, A, N) elements, batch them (discarding a final partial batch), and
# keep a prefetch buffer of two batches.
train_dataset = (
    tf.data.Dataset.from_tensor_slices((X_train, Y_train, A_train, N_train))
    .batch(batch_size, drop_remainder=True)
    .prefetch(2)
)
# Validation dataset.
# NOTE(review): `val_dataset` is passed to model.fit below but is not built in
# this snippet; it presumably mirrors `train_dataset` over the validation
# split — confirm it is defined before this point.
ysol_valid, training_loss, valid_loss, learning_rates = model.fit(batch_size, epochs, train_dataset, val_dataset)

# Plot against integer epoch numbers 1..epochs. The previous
# np.linspace(0, epochs, epochs) spaced its points epochs/(epochs-1) apart,
# so they did not coincide with actual epoch numbers.
epoch_axis = np.arange(1, epochs + 1)

plt.plot(epoch_axis, training_loss, label='training')
plt.plot(epoch_axis, valid_loss, label='validation')
plt.xlabel('Epoch')
plt.legend()
plt.show()

# learning_rates must hold epochs + 1 entries for this plot to line up;
# the leading entry is skipped (presumably the initial rate — confirm in RIM.fit).
plt.plot(epoch_axis, learning_rates[1:], label='learning rate')
plt.xlabel('Epoch')
plt.legend()
plt.show()
# Prepare the test dataset. Unlike training, no X (ground-truth) array is fed
# in: only the noisy observations Y together with A and N.
test_dataset = tf.data.Dataset.from_tensor_slices((Y_test, A_test, N_test))
test_dataset = test_dataset.batch(batch_size, drop_remainder=True)
ysol = model(test_dataset)

# Obtain a better format: convert the nested model output into nested lists
# of NumPy arrays.
# NOTE(review): presumably one outer entry per batch, each holding a sequence
# of tensors (e.g. one per RIM time step) — confirm against RIM.__call__.
# The previous loop had lost its indentation (a syntax error) and reused
# `val` as both the outer and the inner loop variable.
ysol_list = [[tensor.numpy() for tensor in batch_out] for batch_out in ysol]
# Compare the last test sample: noisy input, ground truth, and RIM prediction.
# NOTE(review): test_dataset is batched with drop_remainder=True, so when
# len(Y_test) is not a multiple of batch_size the trailing test samples are
# never fed to the model and Y_test[-1] / X_test[-1] may not correspond to
# ysol_list[-1][-1][-1] — confirm.
x_grid = np.linspace(-1, 1, n)

fig = plt.figure(figsize=(16, 8))
plt.plot(x_grid, Y_test[-1], label='Noisy', color='C1')
plt.plot(x_grid, X_test[-1], label='True', color='C2', linewidth=4)
prediction = ysol_list[-1][-1][-1].reshape(n)
plt.plot(x_grid, prediction, label='Predicted', linestyle='dashed', color='C3', linewidth=3)

plt.legend(prop={'size': 20})
plt.ylabel('Normalized y-axis', fontsize=20)
plt.xlabel('X-axis', fontsize=20)
plt.xticks(fontsize=15)
plt.yticks(fontsize=15)
plt.title('RIM Example using a Noisy Gaussian')
@crhea93
crhea93 / benchmark_gpu_pixel_to_pixel.py
Created February 19, 2025 13:52
Benchmarking test for a PyTorch implementation of the pixel_to_pixel algorithm from the reproject package.
import time
import torch
import numpy as np
from astropy.wcs import WCS
from astropy.wcs.wcsapi import BaseHighLevelWCS
def pixel_to_pixel(wcs_in: BaseHighLevelWCS, wcs_out: BaseHighLevelWCS, *inputs):
"""