Created
February 9, 2023 21:19
-
-
Save nilsleh/5f6c175e12169015e68d4f1b66b0349a to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import timm | |
from torchgeo.datasets import EuroSAT | |
import numpy as np | |
from sklearn.linear_model import LogisticRegression | |
from torch.utils.data import DataLoader | |
import torch | |
from tqdm import tqdm | |
from torchvision.transforms import Normalize | |
from torchvision.models.feature_extraction import create_feature_extractor | |
# Convert a SeCo Lightning checkpoint into a flat state dict whose renamed keys
# follow timm/torchvision ResNet naming, and save it to ``path``.
checkpoint = torch.load("/home/nils/Downloads/seco_resnet18_1m(2).ckpt", map_location="cpu")
state_dict = checkpoint["state_dict"]

# Keys whose target name in a timm ResNet-18 is known. Each renamed key is
# stored ONLY under its new name.
_KEY_RENAMES = {
    "encoder_q.0.weight": "conv1.weight",
    "head_q.2.2.weight": "fc.weight",
}

new_dict = {}
for key, val in state_dict.items():
    if key.startswith("encoder_k"):
        # Skip the second ("k") encoder's parameters entirely.
        continue
    new_dict[_KEY_RENAMES.get(key, key)] = val

# NOTE(review): every other "encoder_q.*" / "head_q.*" key keeps its original
# name, which does not look like timm ResNet naming ("bn1.*", "layer1.*", ...).
# Loading this file with strict=False will presumably ignore those weights.
# TODO: confirm the full index -> layer mapping before relying on it.
path = "seco_resnet18.pth"
torch.save(new_dict, path)
# Backbone to evaluate. The default (False) keeps the previous behaviour:
# timm's pretrained ResNet-18 weights. Set to True to evaluate the converted
# checkpoint saved at ``path`` instead.
USE_SECO_WEIGHTS = False

model = timm.create_model("resnet18", pretrained=not USE_SECO_WEIGHTS)
if USE_SECO_WEIGHTS:
    # strict=False tolerates keys that do not exist in timm's ResNet-18. Print
    # what was skipped so silently dropped weights are visible.
    incompatible = model.load_state_dict(torch.load(path), strict=False)
    print("missing keys:", incompatible.missing_keys)
    print("unexpected keys:", incompatible.unexpected_keys)
model = model.eval()

# Caleb's script
# Presumably the width of ResNet-18's pooled features; not used below.
num_features = 512
# Wrap the network so calling it returns {"global_pool": <pooled features>}.
model = create_feature_extractor(model, return_nodes=["global_pool"])
# Per-band mean and standard deviation for the 13 bands in the full band list.
# Presumably in EuroSAT's band order — TODO confirm against the dataset.
_ALL_BAND_MEANS = [
    1354.40546513,
    1118.24399958,
    1042.92983953,
    947.62620298,
    1199.47283961,
    1999.79090914,
    2369.22292565,
    2296.82608323,
    732.08340178,
    12.11327804,
    1819.01027855,
    1118.92391149,
    2594.14080798,
]
_ALL_BAND_STDS = [
    245.71762908,
    333.00778264,
    395.09249139,
    593.75055589,
    566.4170017,
    861.18399006,
    1086.63139075,
    1117.98170791,
    404.91978886,
    4.77584468,
    1002.58768311,
    761.30323499,
    1231.58581042,
]

# Positions 3, 2, 1 of the full list, in that order. Intended to line up with
# EuroSAT.BAND_SETS["rgb"] used for the datasets below.
_RGB_INDICES = [3, 2, 1]
band_means = torch.tensor(_ALL_BAND_MEANS)[_RGB_INDICES]
band_stds = torch.tensor(_ALL_BAND_STDS)[_RGB_INDICES]

# mean -/+ 2 std per band, shaped (C, 1, 1) so they broadcast over (C, H, W).
min_value = (band_means - 2 * band_stds)[:, None, None]
max_value = (band_means + 2 * band_stds)[:, None, None]
# Standardize each band with the per-band statistics: (x - mean) / std.
norm = Normalize(band_means, band_stds)


def preprocess_minmax(sample):
    """Alternative preprocessing: min/max scaling instead of standardization.

    Rescales each band from [mean - 2*std, mean + 2*std] to [0, 1], clips to
    that range, stores the result back into ``sample["image"]`` and returns
    ``sample``.
    """
    img = sample["image"].float()
    img = (img - min_value) / (max_value - min_value)
    sample["image"] = torch.clip(img, 0, 1)
    return sample


def preprocess(sample):
    """Standardize ``sample["image"]`` per band and return ``sample``.

    The dict is modified in place. The image is cast to float first because
    ``Normalize`` only accepts floating-point tensors.
    """
    sample["image"] = norm(sample["image"].float())
    return sample
_EUROSAT_ROOT = "data/EuroSATallBands/"
_RGB_BANDS = EuroSAT.BAND_SETS["rgb"]


def _build_split(split):
    """Create the EuroSAT dataset for ``split`` and an unshuffled loader over it.

    Both splits use the RGB band set and ``preprocess`` as the transform.
    """
    dataset = EuroSAT(
        root=_EUROSAT_ROOT,
        split=split,
        bands=_RGB_BANDS,
        transforms=preprocess,
    )
    loader = DataLoader(dataset, batch_size=32, shuffle=False, num_workers=6)
    return dataset, loader


train_ds, train_dl = _build_split("train")
test_ds, test_dl = _build_split("test")
def extract_features(model, dataloader, device, *, output_node="global_pool", show_progress=True):
    """Run ``model`` over ``dataloader`` and collect features and labels.

    Args:
        model: Callable taking a batch of images and returning a mapping that
            contains ``output_node``. It is NOT moved to ``device`` here; the
            caller must place it there first.
        dataloader: Iterable of dicts with ``"image"`` and ``"label"`` tensors.
        device: Device each image batch is moved to before the forward pass.
        output_node: Key of the model output to collect.
        show_progress: Wrap the iteration in a ``tqdm`` progress bar.

    Returns:
        ``(features, labels)`` as NumPy arrays concatenated along axis 0.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    feature_batches = []
    label_batches = []
    batches = tqdm(dataloader) if show_progress else dataloader
    # One inference_mode context for the whole pass instead of one per batch.
    with torch.inference_mode():
        for batch in batches:
            images = batch["image"].to(device)
            feature_batches.append(model(images)[output_node].cpu().numpy())
            label_batches.append(batch["label"].cpu().numpy())
    if not feature_batches:
        raise ValueError("dataloader yielded no batches; cannot extract features")
    return np.concatenate(feature_batches, axis=0), np.concatenate(label_batches, axis=0)
# Entry point. The DataLoaders use num_workers=6. Under the "spawn" start method
# (the default on Windows and macOS), each worker re-imports this module, so
# the feature extraction that starts those workers must run only in the parent
# process. The setup code above still re-runs in each worker.
if __name__ == "__main__":
    x_train, y_train = extract_features(model, train_dl, "cpu")
    x_test, y_test = extract_features(model, test_dl, "cpu")

    # Linear probe on the frozen features; prints mean accuracy on the test split.
    linear_model = LogisticRegression(C=50.0, max_iter=1000)
    linear_model.fit(x_train, y_train)
    print(linear_model.score(x_test, y_test))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment