Last active: August 17, 2017 06:52
-
-
Save leVirve/7a2bf775095a40261002e64abcf4268e to your computer and use it in GitHub Desktop.
Override the `forward()` method of torchvision's VGG model (PyTorch) so that it returns early feature maps.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import types | |
import torch | |
from torch.autograd import Variable | |
import torchvision.models as models | |
import torchvision.transforms as transforms | |
from torchvision.datasets.folder import pil_loader | |
def sb_forward(self, x):
    """Replacement ``forward`` that stops after the first block of VGG features.

    Applies ``self.features[0]`` through ``self.features[4]`` to *x* in order
    and returns three of the intermediate results instead of class scores.

    Args:
        self: the VGG module this function is bound to as a method.
        x: input batch, passed to ``self.features[0]``.

    Returns:
        Tuple ``(conv1, conv2, pool1)``: the output after features[0] and [1],
        the output after features[2] and [3], and the output after features[4].
    """
    layers = self.features
    # conv1 -- presumably Conv2d followed by its ReLU in torchvision's VGG16
    # layout; confirm if the model type changes.
    conv1 = layers[1](layers[0](x))
    # conv2
    conv2 = layers[3](layers[2](conv1))
    # pool1
    pool1 = layers[4](conv2)
    return conv1, conv2, pool1
def extract(img=None):
    """Run the module-level ``vgg`` model on *img* and return its output.

    Args:
        img: input tensor. When None, a random tensor of shape
            (1, 3, 512, 512) is generated instead.

    Returns:
        Whatever ``vgg`` returns for the wrapped input; with ``forward``
        replaced by ``sb_forward`` this is a 3-tuple of tensors.
    """
    if img is None:
        img = torch.randn((1, 3, 512, 512))
    # The input itself does not need gradients.
    return vgg(Variable(img, requires_grad=False))
def numpy_feature_maps(feature_maps):
    """Convert feature-map tensors into min-max normalised NumPy arrays.

    Args:
        feature_maps: iterable of rank-4 tensors whose leading (batch) axis
            has size 1 -- presumably (1, C, H, W) activations; TODO confirm.

    Returns:
        List with one NumPy array per input, with axis 1 moved last
        (``(1, C, H, W) -> (H, W, C)``). Each array is rescaled to [0, 1]
        using its own global min and max. A constant map (max == min) cannot
        be rescaled and is returned as all zeros instead of inf/nan.
    """
    def tensor_data(x):
        # Drop the batch axis and move channels last.
        x = x.squeeze(0).permute(1, 2, 0)
        # .data detaches from autograd; .cpu() makes this work for GPU tensors.
        data = x.data.cpu().numpy()
        lo, hi = data.min(), data.max()
        # Shift first so the minimum maps to 0 (the original only divided by
        # the range, which leaves values outside [0, 1] whenever min != 0).
        data = data - lo
        value_range = hi - lo
        if value_range > 0:
            data = data / value_range
        return data

    return [tensor_data(x) for x in feature_maps]
def make_input(img):
    """Turn a PIL image into a normalised single-image batch tensor.

    Applies ``ToTensor`` followed by a per-channel ``Normalize``, then adds a
    leading batch axis of size 1.

    The mean/std constants are presumably the ImageNet statistics expected by
    the pretrained VGG weights; confirm if the model changes.
    """
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )
    pipeline = transforms.Compose([transforms.ToTensor(), normalize])
    batch = pipeline(img).unsqueeze(0)
    return batch
def load_img(path):
    """Load the image file at *path* using torchvision's ``pil_loader``.

    Args:
        path: filesystem path of the image.

    Returns:
        The object returned by ``pil_loader`` (a PIL image).
    """
    image = pil_loader(path)
    return image
# VGG model & replace the behavior of 'forward()'
# NOTE(review): `pretrained=True` is deprecated in newer torchvision in favour
# of `weights=`; kept here for compatibility with older releases.
vgg = models.vgg16(pretrained=True)
# The model is only used for feature extraction, so switch it to inference
# mode. This makes any dropout/batch-norm layers deterministic if sb_forward
# is ever extended beyond the first few `features` layers it uses today.
vgg.eval()
# Bind sb_forward to this instance so that vgg(x) runs it instead of the
# class's own forward().
vgg.forward = types.MethodType(sb_forward, vgg)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import click | |
import numpy as np | |
from skimage.segmentation import slic, mark_boundaries | |
import matplotlib.pyplot as plt | |
from feature import extract, load_img, make_input, numpy_feature_maps | |
@click.command()
@click.option('--path', default='imgs/2white.jpg')
@click.option('-fe', '--feature_extract', is_flag=True)
@click.option('-n', '--n_segments', default=500)
@click.option('-c', '--compactness', default=0.1)
def main(path, feature_extract, n_segments, compactness):
    """Compute SLIC superpixels on VGG feature maps of an image and show them.

    With --feature_extract the feature maps are computed from the image at
    --path and cached in features.npy. Otherwise that cache is loaded.
    """
    image = load_img(path)
    if feature_extract:
        feature_maps = extract(make_input(image))
        features = numpy_feature_maps(feature_maps)
        # Only the first two maps are clustered, so stack them before caching.
        # This stores one plain numeric array. Saving the whole list instead
        # produces an object array when the maps differ in shape (presumably
        # the pooled third map does). Recent NumPy refuses to build such an
        # array, and np.load refuses to read it without allow_pickle=True.
        features = np.dstack(features[:2])
        np.save('features', features)
        print('==> Feature extracted and saved.')
    else:
        try:
            features = np.load('features.npy')
        except (IOError, ValueError) as err:
            # IOError: no cache yet. ValueError: e.g. a cache written in the
            # old pickled-list format. Either way, re-extracting fixes it.
            raise click.ClickException(
                'Cannot load features.npy ({}); rerun with --feature_extract.'
                .format(err))
        print('==> Feature loaded from file.')
    segments = slic(features, n_segments=n_segments, compactness=compactness)
    img_show(image, segments)
def img_show(image, segments):
    """Display *image* with the boundaries of the *segments* label map drawn on it.

    Draws into the matplotlib figure labelled "Superpixels" on a single set
    of axes with axis decorations hidden, then calls ``plt.show()``.
    """
    fig = plt.figure("Superpixels")
    ax = fig.add_subplot(1, 1, 1)
    overlay = mark_boundaries(image, segments)
    ax.imshow(overlay)
    # Same effect as plt.axis("off"): ax is the current axes here.
    ax.axis("off")
    plt.show()
if __name__ == '__main__':
    # Run the click command only when executed as a script, not on import.
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.