ResNet50 OFA
import os, sys
import os.path as osp
import math
import itertools

import numpy as np
import torch
import torch.nn as nn
from torchvision import transforms, datasets

from ofa.model_zoo import ofa_net
from ofa.elastic_nn.utils import set_running_statistics
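
# Subnet encoding used throughout this script:
#   ks_list: 20 per-block kernel sizes (5 stages x up to 4 blocks each)
#   ex_list: 20 per-block expansion ratios
#   d_list : 5 per-stage depths; blocks beyond the chosen depth are zeroed out
#            by ArchTool.formalize() so equivalent subnets share one encoding.
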
class ArchTool:
    """Helpers for sampling, enumerating, and (de)serializing OFA subnet configs."""

    def __init__(self, choices_ks=(3, 5, 7), choices_ex=(3, 4, 6), choices_d=(2, 3, 4),
                 max_ks=20, max_ex=20, max_d=5):
        self.choices_ks = choices_ks
        self.choices_ex = choices_ex
        self.choices_d = choices_d
        self.max_ks = max_ks
        self.max_ex = max_ex
        self.max_d = max_d

    def random(self, serialize=False) -> (list, list, list):
        return ArchTool.simple_random(self.choices_ks, self.choices_ex, self.choices_d,
                                      self.max_ks, self.max_ex, self.max_d, serialize)

    @staticmethod
    def simple_random(ks_choices=(3, 5, 7), ex_choices=(3, 4, 6), d_choices=(2, 3, 4),
                      ks_max=20, ex_max=20, d_max=5, serialize=False):
        assert isinstance(ks_choices, (list, tuple))
        assert isinstance(ex_choices, (list, tuple))
        assert isinstance(d_choices, (list, tuple))
        ks_list = [int(np.random.choice(ks_choices)) for _ in range(ks_max)]
        ex_list = [int(np.random.choice(ex_choices)) for _ in range(ex_max)]
        d_list = [int(np.random.choice(d_choices)) for _ in range(d_max)]
        ks_list, ex_list, d_list = ArchTool.formalize(ks_list, ex_list, d_list)
        if not serialize:
            return ks_list, ex_list, d_list
        else:
            return ArchTool.serialize(ks_list, ex_list, d_list)

    def iterate_space(self, serialize=False) -> (list, list, list):
        for arch in ArchTool.simple_iterate_space(
            self.choices_ks, self.choices_ex, self.choices_d,
            self.max_ks, self.max_ex, self.max_d, serialize
        ):
            yield arch

    @staticmethod
    def simple_iterate_space(ks_choices=(3, 5, 7), ex_choices=(3, 4, 6), d_choices=(2, 3, 4),
                             ks_max=20, ex_max=20, d_max=5, serialize=False):
        assert isinstance(ks_choices, (list, tuple))
        assert isinstance(ex_choices, (list, tuple))
        assert isinstance(d_choices, (list, tuple))
        ks_candidate = [ks_choices for _ in range(ks_max)]
        ex_candidate = [ex_choices for _ in range(ex_max)]
        d_candidate = [d_choices for _ in range(d_max)]
        for config in itertools.product(*(ks_candidate + ex_candidate + d_candidate)):
            ks_list = config[:ks_max]
            ex_list = config[ks_max:ks_max + ex_max]
            d_list = config[ks_max + ex_max:]
            ks_list, ex_list, d_list = ArchTool.formalize(ks_list, ex_list, d_list)
            if not serialize:
                yield ks_list, ex_list, d_list
            else:
                yield ArchTool.serialize(ks_list, ex_list, d_list)

    @staticmethod
    def formalize(_ks_list: list, _ex_list: list, _d_list: list) -> (list, list, list):
        ks_list, ex_list, d_list = list(_ks_list), list(_ex_list), list(_d_list)
        # ks and ex entries between (d, max_d) within each stage are meaningless:
        # those blocks are skipped at the chosen depth, so fill them with 0 to avoid redundancy.
        start = 0
        end = 4
        for d in d_list:
            for j in range(start + d, end):
                ks_list[j] = 0
                ex_list[j] = 0
            start += 4
            end += 4
        return ks_list, ex_list, d_list

    @staticmethod
    def serialize(ks_list: list, ex_list: list, d_list: list) -> str:
        assert len(ks_list) == 20, "Kernel size list must contain exactly 20 numbers."
        assert len(ex_list) == 20, "Expansion ratio list must contain exactly 20 numbers."
        assert len(d_list) == 5, "Depth list must contain exactly 5 numbers."
        ks_list, ex_list, d_list = ArchTool.formalize(ks_list, ex_list, d_list)
        ks_str = "%s:%s" % ("ks", ",".join([str(_) for _ in ks_list]))
        ex_str = "%s:%s" % ("ex", ",".join([str(_) for _ in ex_list]))
        d_str = "%s:%s" % ("d", ",".join([str(_) for _ in d_list]))
        return "-".join([ks_str, ex_str, d_str])

    @staticmethod
    def deserialize(cfg_str: str):
        ks_str, ex_str, d_str = cfg_str.strip().split("-")
        ks_list = [int(_) for _ in ks_str.split(":")[-1].split(",")]
        ex_list = [float(_) for _ in ex_str.split(":")[-1].split(",")]
        d_list = [int(_) for _ in d_str.split(":")[-1].split(",")]
        assert len(ks_list) == 20, "Kernel size list must contain exactly 20 numbers."
        assert len(ex_list) == 20, "Expansion ratio list must contain exactly 20 numbers."
        assert len(d_list) == 5, "Depth list must contain exactly 5 numbers."
        return ks_list, ex_list, d_list
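
# ArchTool.serialize() emits "ks:<20 values>-ex:<20 values>-d:<5 values>", with each
# field comma-joined; ArchTool.deserialize() parses such a string back into the three
# lists (ks/d as ints, ex as floats), so the two are round-trip compatible.
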
class Slave:
    def __init__(self, resolution=224):
        self.resolution = resolution
        self.cached_train_loader = None
        self.cached_valid_loader = None

    def get_train_loader(self, path="imagenet", image_size=224, num_images=5000, batch_size=200, num_workers=5,
                         use_cache=False):
        if use_cache and self.cached_train_loader is not None:
            print("Using data cached in memory for train")
            return self.cached_train_loader
        dataset = datasets.ImageFolder(
            osp.join(path, 'train'),
            transforms.Compose([
                transforms.RandomResizedCrop(image_size),
                transforms.RandomHorizontalFlip(),
                transforms.ColorJitter(brightness=32. / 255., saturation=0.5),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225]
                ),
            ])
        )
        # Subsample a small calibration set (default 5000 images) for BN re-estimation.
        chosen_indexes = np.random.choice(list(range(len(dataset))), num_images)
        sub_sampler = torch.utils.data.sampler.SubsetRandomSampler(chosen_indexes)
        data_loader = torch.utils.data.DataLoader(
            dataset,
            sampler=sub_sampler,
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=True,
            drop_last=False,
        )
        print(f"Using local data loader for train {len(dataset)}")
        if use_cache:
            self.cached_train_loader = tuple(_ for _ in data_loader)
        return data_loader

    def get_valid_loader(self, val_path="imagenet/val",
                         image_size=224, batch_size=256, num_workers=10, print_freq=10,
                         use_cache=True):
        if use_cache and self.cached_valid_loader is not None:
            print("Using data cached in memory for val")
            return self.cached_valid_loader
        val_transform = transforms.Compose([
            transforms.Resize(int(math.ceil(image_size / 0.875))),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            ),
        ])
        valset = datasets.ImageFolder(val_path, val_transform)
        val_loader = torch.utils.data.DataLoader(valset,
                                                 batch_size=batch_size, num_workers=num_workers,
                                                 pin_memory=True, drop_last=False)
        print(f"Using local data loader for val {len(valset)}")
        if use_cache:
            self.cached_valid_loader = tuple(_ for _ in val_loader)
        return val_loader

    def run(self, gpu=None, net_id="ofa_mbv3_d234_e346_k357_w1.0", net_cfg=None,
            use_cache=False, batch_size=256):
        sys.path.append("/NFS/home/ligeng/Workspace/mjt-dev")
        from utils import validate

        device = "cpu"
        if torch.cuda.is_available():
            device = "cuda"
            torch.cuda.set_device(gpu)

        # Load the pretrained OFA super-network, activate the requested subnet,
        # and extract it with the shared weights preserved.
        ofa_network = ofa_net(net_id, pretrained=True)
        ks_list, ex_list, d_list = ArchTool.deserialize(net_cfg)
        ofa_network.set_active_subnet(ks=ks_list, e=ex_list, d=d_list)
        manual_subnet = ofa_network.get_active_subnet(preserve_weight=True)
        model = manual_subnet.to(device)

        train_data_loader = self.get_train_loader(path="/dataset/imagenet",
                                                  image_size=self.resolution,
                                                  use_cache=use_cache,
                                                  batch_size=batch_size)
        val_loader = self.get_valid_loader(val_path="/dataset/imagenet/val",
                                           image_size=self.resolution,
                                           use_cache=use_cache,
                                           batch_size=batch_size)

        # Re-estimate BatchNorm running statistics for the extracted subnet before evaluation.
        print("Calibrating BN")
        set_running_statistics(model, train_data_loader)

        print(f"Validating Top-1 acc {net_cfg} {net_id}")
        top1 = validate(val_loader, model, criterion=nn.CrossEntropyLoss(), gpu=gpu, print_freq=None)
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return top1.item()

def random_mbv2_space():
    ks_list = [int(np.random.choice([3, 5, 7])) for i in range(20)]
    ex_list = [float(np.random.choice([3, 4, 6])) for i in range(20)]
    d_list = [int(np.random.choice([2, 3, 4])) for i in range(5)]
    arch_str = ArchTool.serialize(ks_list, ex_list, d_list)
    return ks_list, ex_list, d_list, arch_str


def random_resnet50_space():
    ks_list = [int(np.random.choice([3, ])) for i in range(20)]
    ex_list = [float(np.random.choice([0.2, 0.25, 0.35])) for i in range(20)]
    d_list = [int(np.random.choice([0, 1, 2])) for i in range(5)]
    arch_str = ArchTool.serialize(ks_list, ex_list, d_list)
    return ks_list, ex_list, d_list, arch_str


def main():
    ks_list, ex_list, d_list, cfg_str = random_resnet50_space()
    worker = Slave()
    top1 = worker.run("cuda:0", net_id="ofa_resnet50", net_cfg=cfg_str)
    print(top1)


if __name__ == "__main__":
    main()
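
# A minimal sketch (not used above) of walking a search space exhaustively with
# ArchTool.iterate_space instead of random sampling; the full grid is enormous,
# so in practice you would cap how many configs actually get evaluated:
#
#   arch_tool = ArchTool(choices_ks=(3,), choices_ex=(0.2, 0.25, 0.35), choices_d=(0, 1, 2))
#   for cfg_str in arch_tool.iterate_space(serialize=True):
#       top1 = Slave().run("cuda:0", net_id="ofa_resnet50", net_cfg=cfg_str)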