This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from minigrid.wrappers import RGBImgPartialObsWrapper, ImgObsWrapper | |
from stable_baselines3.common.monitor import Monitor | |
from gymnasium.utils.play import play | |
from utils.utils import move_to_pos, get_pos_from_int, turn_and_explore # get_pos_from_int, turn_and_explore | |
from minigrid.core.actions import Actions | |
second_task = gym.make("MiniGrid-BlockedUnlockPickup-v0", render_mode = 'human') | |
# play(second_task, | |
# keys_to_action={ | |
# "w": np.int64(2), | |
# "a": np.int64(0), |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import gymnasium as gym | |
import minigrid | |
import numpy as np | |
import matplotlib.pyplot as plt | |
import sys | |
sys.path.append("/home/mb230/rice_coursework/f24/comp552/comp-552-assignment-backup/assignment5") | |
import wandb | |
run = wandb.init( | |
project="comp552-a5", monitor_gym = True, sync_tensorboard=True | |
) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
sim_mat_dims = (len(dl.dataset), len(dl.dataset)) | |
print("Dimensions of similarity matrix is", sim_mat_dims) | |
print("Making empty matrix to store similarities ......") | |
feat_mat = np.empty(sim_mat_dims, dtype=np.float32) | |
loss_fn = nn.CrossEntropyLoss(reduction='mean').to(self.device) | |
for idx, data in tqdm(enumerate(dl)): | |
loss_val = loss_fn(net(data[0].to(self.device)), data[1].to(self.device)) | |
grad_list = torch.autograd.grad(loss_val, inputs = [p for p in net.parameters() if p.requires_grad]) | |
feats_outer = [t.flatten() for t in grad_list] | |
feats_outer = torch.cat(feats_outer) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
[16:55:34] ----- Submarine logging started ----- | |
[16:55:34] Started new execution run, signed 32-bit element types, signed 64-bit matrix indices .... | |
[16:55:34] Types: set elements signed 32-bit, set sizes signed 32-bit, set indices signed 32-bit, set iters unsigned 32-bit, matrix indices signed 64-bit. | |
[16:55:34] Using 22 threads for general operations. | |
[16:55:34] Command line: smraiz -flfilename /mnt/disks/spinning_scratch0/smrai-container-documentation/src/saved_results/tinyimagenet/tinyimagenet_convnetd4_1_features_ffcv_False_simeuclid_sim_or_dist.npy -sumsize 1 -cloglevel trace -floglevel trace -loglevel trace -clogtimestamps T -nochecks | |
[16:55:34] Loading FL matrix. | |
[16:55:34] RNPHR: Reading numpy file v1.0 with header size 70. | |
[16:55:34] RNPHR: Finished header of v1.0 numpy file, a 100000x100000 32-bit float (f4) matrix, element byte length 4, fortran false, endian little, endian N/A false. | |
[16:55:34] RMNF: allocating and reading 100000x100000 numpy matrix, skip_type_cast=true, fortran_order = false, n |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import time | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from torchvision import datasets, transforms | |
import torchvision | |
from torch.utils.data import Dataset | |
from scipy.ndimage.interpolation import rotate as scipyrotate | |
import sys |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from utils.evaluator_utils import EvaluatorUtils | |
ds_train = datasets.CIFAR10('data', train=True, download=True, transform=transform) | |
ds_test = datasets.CIFAR10('data', train=False, download=True, transform=transform) | |
images_all = [torch.unsqueeze(ds_train[i][0], dim=0) for i in range(len(ds_train))] | |
labels_all = [ds_train[i][1] for i in range(len(ds_train))] | |
class_pos_list = [] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import time | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from torchvision import datasets, transforms | |
import torchvision | |
from torch.utils.data import Dataset | |
from scipy.ndimage.interpolation import rotate as scipyrotate | |
import sys |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import time | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import os | |
import kornia as K | |
import tqdm | |
from torch.utils.data import Dataset | |
from torchvision import datasets, transforms |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from turtle import forward | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import torch | |
from typing import Dict, Iterable, Callable | |
import torchvision | |
from models.pretrained_implementations import resnet18_pret | |
from models.conv_iResNet import conv_iResNet | |
# Acknowledgement to | |
# https://github.com/kuangliu/pytorch-cifar, |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
DATA_DIR='./data/datasets' | |
MAX_DEPTH=15 | |
MAX_NODES=30 | |
SEARCH_METHOD=bfs | |
MODEL=LSTM | |
NUM_EPOCHS_MENTION_ONLY=1 | |
NUM_EPOCHS_WITH_COHERENCE=30 | |
BATCH_SIZE=32 |
NewerOlder