Skip to content

Instantly share code, notes, and snippets.

@Lyken17
Created July 18, 2022 12:12
Show Gist options
  • Save Lyken17/2b3d87975f683cc63193b26a1df9e64d to your computer and use it in GitHub Desktop.
Resnet50 ofa
import itertools
import math
import os, sys
import os.path as osp

import numpy as np
import torch
import torch.nn as nn
from torchvision import transforms, datasets

from ofa.model_zoo import ofa_net
from ofa.elastic_nn.utils import set_running_statistics
class ArchTool:
    """Helpers for sampling, enumerating, and (de)serializing OFA sub-network
    architecture configurations.

    A configuration is three lists:
      - ks_list: per-block kernel sizes (length ``max_ks``, default 20),
      - ex_list: per-block expansion ratios (length ``max_ex``, default 20),
      - d_list:  per-stage depths (length ``max_d``, default 5).

    ``formalize`` zeroes the ks/ex entries of blocks beyond each stage's active
    depth so that behaviorally-equivalent configurations have one canonical
    (and serializable) form.  ``formalize`` assumes the fixed layout of
    4 block slots per stage (20 blocks / 5 stages).
    """

    def __init__(self, choices_ks=(3, 5, 7), choices_ex=(3, 4, 6), choices_d=(2, 3, 4),
                 max_ks=20, max_ex=20, max_d=5):
        # Per-dimension candidate values and list lengths for this search space.
        self.choices_ks = choices_ks
        self.choices_ex = choices_ex
        self.choices_d = choices_d
        self.max_ks = max_ks
        self.max_ex = max_ex
        self.max_d = max_d

    def random(self, serialize=False):
        """Sample one random configuration from this instance's space.

        Returns (ks_list, ex_list, d_list), or the serialized string when
        ``serialize`` is True.
        """
        return ArchTool.simple_random(self.choices_ks, self.choices_ex, self.choices_d,
                                      self.max_ks, self.max_ex, self.max_d, serialize)

    @staticmethod
    def simple_random(ks_choices=(3, 5, 7), ex_choices=(3, 4, 6), d_choices=(2, 3, 4),
                      ks_max=20, ex_max=20, d_max=5, serialize=False):
        """Sample a random, formalized configuration from the given choices.

        NOTE: expansion ratios are cast to int here, so fractional ex choices
        (e.g. the ResNet50 space) should not be sampled through this method.
        """
        assert isinstance(ks_choices, (list, tuple))
        assert isinstance(ex_choices, (list, tuple))
        assert isinstance(d_choices, (list, tuple))
        ks_list = [int(np.random.choice(ks_choices)) for _ in range(ks_max)]
        ex_list = [int(np.random.choice(ex_choices)) for _ in range(ex_max)]
        d_list = [int(np.random.choice(d_choices)) for _ in range(d_max)]
        ks_list, ex_list, d_list = ArchTool.formalize(ks_list, ex_list, d_list)
        if serialize:
            return ArchTool.serialize(ks_list, ex_list, d_list)
        return ks_list, ex_list, d_list

    def iterate_space(self, serialize=False):
        """Yield every configuration in this instance's search space."""
        yield from ArchTool.simple_iterate_space(
            self.choices_ks, self.choices_ex, self.choices_d,
            self.max_ks, self.max_ex, self.max_d, serialize
        )

    @staticmethod
    def simple_iterate_space(ks_choices=(3, 5, 7), ex_choices=(3, 4, 6), d_choices=(2, 3, 4),
                             ks_max=20, ex_max=20, d_max=5, serialize=False):
        """Exhaustively enumerate the Cartesian product of all choices.

        Duplicates are possible after formalization (different raw configs can
        canonicalize to the same architecture).
        """
        assert isinstance(ks_choices, (list, tuple))
        assert isinstance(ex_choices, (list, tuple))
        assert isinstance(d_choices, (list, tuple))
        # One candidate set per list position: ks slots, then ex slots, then depths.
        candidates = [ks_choices] * ks_max + [ex_choices] * ex_max + [d_choices] * d_max
        for config in itertools.product(*candidates):
            ks_list = config[:ks_max]
            ex_list = config[ks_max:ks_max + ex_max]
            d_list = config[ks_max + ex_max:]
            ks_list, ex_list, d_list = ArchTool.formalize(ks_list, ex_list, d_list)
            if serialize:
                yield ArchTool.serialize(ks_list, ex_list, d_list)
            else:
                yield ks_list, ex_list, d_list

    @staticmethod
    def formalize(_ks_list, _ex_list, _d_list):
        """Canonicalize a configuration without mutating the inputs.

        ks/ex entries for block indices in [d, 4) of each 4-slot stage are
        inactive (the stage only runs its first ``d`` blocks), so they are
        zeroed to remove redundancy.
        """
        ks_list, ex_list, d_list = list(_ks_list), list(_ex_list), list(_d_list)
        start = 0
        end = 4
        for d in d_list:
            for j in range(start + d, end):
                ks_list[j] = 0
                ex_list[j] = 0
            start += 4
            end += 4
        return ks_list, ex_list, d_list

    @staticmethod
    def serialize(ks_list: list, ex_list: list, d_list: list) -> str:
        """Encode a configuration as 'ks:..-ex:..-d:..' (comma-joined values)."""
        assert len(ks_list) == 20, "Kernel size list can only contain 20 numbers."
        assert len(ex_list) == 20, "Expansion ratio list can only contain 20 numbers."
        assert len(d_list) == 5, "Depth list can only contain 5 numbers."
        # Formalize first so equivalent configs serialize identically.
        ks_list, ex_list, d_list = ArchTool.formalize(ks_list, ex_list, d_list)
        ks_str = "%s:%s" % ("ks", ",".join([str(_) for _ in ks_list]))
        ex_str = "%s:%s" % ("ex", ",".join([str(_) for _ in ex_list]))
        d_str = "%s:%s" % ("d", ",".join([str(_) for _ in d_list]))
        return "-".join([ks_str, ex_str, d_str])

    @staticmethod
    def deserialize(cfg_str: str):
        """Parse a 'ks:..-ex:..-d:..' string back into (ks, ex, d) lists.

        Expansion ratios are parsed as floats to support fractional ratios
        (e.g. the ResNet50 space); ks and d are ints.
        """
        ks_str, ex_str, d_str = cfg_str.strip().split("-")
        ks_list = [int(_) for _ in ks_str.split(":")[-1].split(",")]
        ex_list = [float(_) for _ in ex_str.split(":")[-1].split(",")]
        d_list = [int(_) for _ in d_str.split(":")[-1].split(",")]
        assert len(ks_list) == 20, "Kernel size list can only contain 20 numbers."
        assert len(ex_list) == 20, "Expansion ratio list can only contain 20 numbers."
        assert len(d_list) == 5, "Depth list can only contain 5 numbers."
        return ks_list, ex_list, d_list
class Slave:
    """Worker that evaluates one OFA sub-network end to end: builds ImageNet
    data loaders, re-calibrates BN statistics, and measures validation top-1."""

    def __init__(self, resolution=224):
        # Image resolution used by both the BN-calibration and validation loaders.
        self.resolution = resolution
        # In-memory caches of pre-fetched batches (tuples); None until populated.
        self.cached_train_loader = None
        self.cached_valid_loader = None

    def get_train_loader(self, path="imagenet", image_size=224, num_images=5000, batch_size=200, num_workers=5,
                         use_cache=False):
        """Build a DataLoader over a random num_images subset of the ImageNet
        train split, used only for BN re-calibration.

        NOTE(review): on the first call with use_cache=True this returns the
        live DataLoader while caching a fully materialized tuple of batches;
        later calls return the tuple — both iterate fine, but the types differ.
        """
        if use_cache and self.cached_train_loader is not None:
            print("Using data cached in memory for train")
            return self.cached_train_loader
        dataset = datasets.ImageFolder(
            osp.join(path, 'train'),
            transforms.Compose([
                transforms.RandomResizedCrop(image_size),
                transforms.RandomHorizontalFlip(),
                transforms.ColorJitter(brightness=32. / 255., saturation=0.5),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225]
                ),
            ])
        )
        # np.random.choice samples WITH replacement here, so duplicate indexes
        # are possible — acceptable for BN-statistics calibration.
        chosen_indexes = np.random.choice(list(range(len(dataset))), num_images)
        sub_sampler = torch.utils.data.sampler.SubsetRandomSampler(chosen_indexes)
        data_loader = torch.utils.data.DataLoader(
            dataset,
            sampler=sub_sampler,
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=True,
            drop_last=False,
        )
        print(f"Using local data loader for train {len(dataset)}")
        if use_cache:
            # Materialize every batch now so later calls skip disk I/O.
            self.cached_train_loader = tuple(_ for _ in data_loader)
        return data_loader

    def get_valid_loader(self, val_path="imagenet/val",
                         image_size=224, batch_size=256, num_workers=10, print_freq=10,
                         use_cache=True):
        """Build the standard ImageNet validation loader (resize, center-crop,
        normalize). print_freq is accepted but unused in this method."""
        if use_cache and self.cached_valid_loader is not None:
            print("Using data cached in memory for val")
            return self.cached_valid_loader
        val_transform = transforms.Compose([
            # Resize so the center crop keeps ~87.5% of the shorter side.
            transforms.Resize(int(math.ceil(image_size / 0.875))),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            ),
        ])
        valset = datasets.ImageFolder(val_path, val_transform)
        val_loader = torch.utils.data.DataLoader(valset,
                                                 batch_size=batch_size, num_workers=num_workers, pin_memory=True,
                                                 drop_last=False)
        print(f"Using local data loader for val {len(valset)}")
        if use_cache:
            # Same caching scheme as the train loader.
            self.cached_valid_loader = tuple(_ for _ in val_loader)
        return val_loader

    def run(self, gpu=None, net_id="ofa_mbv3_d234_e346_k357_w1.0", net_cfg=None,
            use_cache=False, batch_size=256):
        """Extract the sub-network described by net_cfg from the OFA super-net
        net_id, calibrate its BN statistics, and return validation top-1 accuracy.

        gpu: device passed to torch.cuda.set_device (e.g. "cuda:0"); ignored on CPU.
        net_cfg: serialized architecture string in ArchTool.serialize format.
        """
        # NOTE(review): machine-specific hard-coded path; `validate` comes from
        # an external repo — signature assumed (loader, model, criterion, ...),
        # returning an object with .item() — confirm against that repo.
        sys.path.append("/NFS/home/ligeng/Workspace/mjt-dev")
        from utils import validate
        device = "cpu"
        if torch.cuda.is_available():
            device = "cuda"
            torch.cuda.set_device(gpu)
        ofa_network = ofa_net(net_id, pretrained=True)
        ks_list, ex_list, d_list = ArchTool.deserialize(net_cfg)
        ofa_network.set_active_subnet(ks=ks_list, e=ex_list, d=d_list)
        # preserve_weight=True: the subnet inherits the super-net's trained weights.
        manual_subnet = ofa_network.get_active_subnet(preserve_weight=True)
        model = manual_subnet.to(device)
        train_data_loader = self.get_train_loader(path="/dataset/imagenet",
                                                  image_size=self.resolution,
                                                  use_cache=use_cache,
                                                  batch_size=batch_size)
        val_loader = self.get_valid_loader(val_path="/dataset/imagenet/val",
                                           image_size=self.resolution,
                                           use_cache=use_cache,
                                           batch_size=batch_size)
        print("Calibrating BN")
        # Re-estimate BN running mean/var on the calibration subset.
        set_running_statistics(model, train_data_loader)
        print(f"Validating Top-1 acc {net_cfg} {net_id}")
        top1 = validate(val_loader, model, criterion=nn.CrossEntropyLoss(), gpu=gpu, print_freq=None)
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return top1.item()
def random_mbv2_space():
    """Draw a random architecture from the OFA MobileNet-style search space
    and return it both as raw lists and as a serialized config string."""
    kernel_sizes = [int(np.random.choice([3, 5, 7])) for _ in range(20)]
    expand_ratios = [float(np.random.choice([3, 4, 6])) for _ in range(20)]
    depths = [int(np.random.choice([2, 3, 4])) for _ in range(5)]
    cfg = ArchTool.serialize(kernel_sizes, expand_ratios, depths)
    return kernel_sizes, expand_ratios, depths, cfg
def random_resnet50_space():
    """Draw a random architecture from the OFA-ResNet50 search space
    (fixed 3x3 kernels, fractional width ratios, shallow depth choices)."""
    kernel_sizes = [int(np.random.choice([3])) for _ in range(20)]
    expand_ratios = [float(np.random.choice([0.2, 0.25, 0.35])) for _ in range(20)]
    depths = [int(np.random.choice([0, 1, 2])) for _ in range(5)]
    cfg = ArchTool.serialize(kernel_sizes, expand_ratios, depths)
    return kernel_sizes, expand_ratios, depths, cfg
def main():
    """Sample one random OFA-ResNet50 sub-network and print its top-1 accuracy."""
    _, _, _, cfg_str = random_resnet50_space()
    evaluator = Slave()
    acc = evaluator.run("cuda:0", net_id="ofa_resnet50", net_cfg=cfg_str)
    print(acc)


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment