Skip to content

Instantly share code, notes, and snippets.

@nilsleh
Created February 9, 2023 21:19
Show Gist options
  • Save nilsleh/5f6c175e12169015e68d4f1b66b0349a to your computer and use it in GitHub Desktop.
Save nilsleh/5f6c175e12169015e68d4f1b66b0349a to your computer and use it in GitHub Desktop.
import torch
import timm
from torchgeo.datasets import EuroSAT
import numpy as np
from sklearn.linear_model import LogisticRegression
from torch.utils.data import DataLoader
import torch
from tqdm import tqdm
from torchvision.transforms import Normalize
from torchvision.models.feature_extraction import create_feature_extractor
# Convert the SeCo training checkpoint into a plain state dict saved at `path`.
# NOTE(review): torch.load unpickles the file; only point it at trusted checkpoints.
checkpoint = torch.load("/home/nils/Downloads/seco_resnet18_1m(2).ckpt", map_location="cpu")
state_dict = checkpoint["state_dict"]
new_dict = {}
for key, val in state_dict.items():
    # A single if/elif chain, so each key lands in exactly one place. Previously
    # "encoder_q.0.weight" was stored as "conv1.weight" AND copied verbatim,
    # because the first `if` was not part of the chain below it.
    if key == "encoder_q.0.weight":
        new_dict["conv1.weight"] = val
    elif key == "head_q.2.2.weight":
        new_dict["fc.weight"] = val
    elif key.startswith("encoder_k"):
        # Drop the "encoder_k" branch; presumably the momentum/key encoder of
        # MoCo-style training — TODO confirm.
        continue
    else:
        # NOTE(review): every other key (e.g. the remaining encoder_q.* entries)
        # is copied under its checkpoint name, which is unlikely to match resnet
        # parameter names; a strict=False load would then skip them — confirm.
        new_dict[key] = val
path = "seco_resnet18.pth"
torch.save(new_dict, path)
# NOTE(review): `pretrained=True` loads timm's own pretrained resnet18 weights,
# and the line that would load the converted checkpoint
# (`model.load_state_dict(torch.load(path), strict=False)`) is disabled, so the
# features extracted below do NOT come from the SeCo weights converted above.
# Confirm this is intended.
model = timm.create_model("resnet18", pretrained=True).eval()

# The evaluation below follows Caleb's script.
# Pooled feature width for resnet18 (not referenced by the code visible here).
num_features = 512
# Re-trace the model so a forward pass returns {"global_pool": pooled features}.
model = create_feature_extractor(model, return_nodes=["global_pool"])
# Per-band mean / std for all 13 EuroSAT bands. Presumably ordered B01, B02,
# B03, B04, ... as stored in the dataset — TODO confirm against the source of
# these numbers.
_ALL_BAND_MEANS = [
    1354.40546513,
    1118.24399958,
    1042.92983953,
    947.62620298,
    1199.47283961,
    1999.79090914,
    2369.22292565,
    2296.82608323,
    732.08340178,
    12.11327804,
    1819.01027855,
    1118.92391149,
    2594.14080798,
]
_ALL_BAND_STDS = [
    245.71762908,
    333.00778264,
    395.09249139,
    593.75055589,
    566.4170017,
    861.18399006,
    1086.63139075,
    1117.98170791,
    404.91978886,
    4.77584468,
    1002.58768311,
    761.30323499,
    1231.58581042,
]
# Channels kept for the RGB model input, in this order (indices into the lists above).
_RGB_BAND_INDICES = [3, 2, 1]

band_means = torch.tensor(_ALL_BAND_MEANS)[_RGB_BAND_INDICES]
band_stds = torch.tensor(_ALL_BAND_STDS)[_RGB_BAND_INDICES]

# mean -/+ 2 std per channel, shaped (C, 1, 1) to broadcast over (C, H, W) images.
min_value = (band_means - 2 * band_stds)[:, None, None]
max_value = (band_means + 2 * band_stds)[:, None, None]
norm = Normalize(band_means, band_stds)


def preprocess(sample):
    """Standardize ``sample["image"]`` with the selected band statistics.

    Used as the EuroSAT ``transforms`` callable. The image is replaced by
    ``(image - band_means) / band_stds`` per channel, and the same (mutated)
    dict is returned.

    Variants tried earlier and kept out of the code path: dividing the raw
    values by 10000, and min/max scaling to ``mean +/- 2 * std`` (the
    ``min_value`` / ``max_value`` tensors) followed by clipping to [0, 1].
    """
    standardized = norm(sample["image"])
    sample["image"] = standardized
    return sample
_EUROSAT_ROOT = "data/EuroSATallBands/"


def _make_split(split):
    """Build the RGB EuroSAT dataset for ``split`` and a loader over it.

    Samples pass through ``preprocess``. The loader yields batches of 32 from
    6 worker processes without shuffling. Features and labels are collected
    together later, so batch order does not affect the probe.

    Returns:
        ``(dataset, loader)``.
    """
    dataset = EuroSAT(
        root=_EUROSAT_ROOT,
        split=split,
        bands=EuroSAT.BAND_SETS["rgb"],
        transforms=preprocess,
    )
    loader = DataLoader(dataset, batch_size=32, shuffle=False, num_workers=6)
    return dataset, loader


train_ds, train_dl = _make_split("train")
test_ds, test_dl = _make_split("test")
def extract_features(model, dataloader, device, output_key="global_pool"):
    """Run ``model`` over ``dataloader`` and stack the features with their labels.

    Args:
        model: Module whose forward returns a mapping containing ``output_key``
            (in this script, the ``create_feature_extractor`` graph module).
            It is moved to ``device`` in place via ``Module.to``.
        dataloader: Iterable of batches with ``"image"`` and ``"label"`` tensors.
        device: Device the model and each image batch are placed on.
        output_key: Entry of the model output used as the feature tensor.

    Returns:
        ``(x_all, y_all)``: numpy arrays of features and labels, each
        concatenated along axis 0 in iteration order.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    # Images are sent to ``device`` below, so the model has to live there too;
    # otherwise any device other than the model's current one fails in forward.
    model = model.to(device)
    x_all = []
    y_all = []
    for batch in tqdm(dataloader):
        images = batch["image"].to(device)
        labels = batch["label"].numpy()
        with torch.inference_mode():
            features = model(images)[output_key].cpu().numpy()
        x_all.append(features)
        y_all.append(labels)
    if not x_all:
        # np.concatenate would otherwise fail with an opaque "need at least one array".
        raise ValueError("dataloader yielded no batches; no features to extract")
    x_all = np.concatenate(x_all, axis=0)
    y_all = np.concatenate(y_all, axis=0)
    return x_all, y_all
# Linear probe: fit logistic regression on the frozen train-split features and
# print its score (mean accuracy) on the test split.
(x_train, y_train), (x_test, y_test) = (
    extract_features(model, loader, "cpu") for loader in (train_dl, test_dl)
)
linear_model = LogisticRegression(C=50.0, max_iter=1000).fit(x_train, y_train)
test_score = linear_model.score(x_test, y_test)
print(test_score)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment