Skip to content

Instantly share code, notes, and snippets.

@buttercutter
Last active May 20, 2020 06:37
Show Gist options
  • Save buttercutter/8ddd9794f242c24ffdaa612bcb0bfa33 to your computer and use it in GitHub Desktop.
Save buttercutter/8ddd9794f242c24ffdaa612bcb0bfa33 to your computer and use it in GitHub Desktop.
YOLOv3 + AdderNet
'''
Copyright (C) 2020. Huawei Technologies Co., Ltd. All rights reserved.
This program is free software; you can redistribute it and/or modify
it under the terms of BSD 3-Clause License.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
BSD 3-Clause License for more details.
'''
import torch
import torch.nn as nn
import numpy as np
from torch.autograd import Function
import math
# https://github.com/pytorch/pytorch/issues/15253#issuecomment-491467128
@torch.jit.script
def my_cdist(x1, x2, p:int):
    """Pairwise distance surrogate between the rows of ``x1`` and ``x2``.

    Evaluates ``sqrt(sum(x1**p) + sum(x2**p) - 2 * x1 @ x2.T)``, with the sums
    taken over the last dimension. For ``p == 2`` this is the Euclidean
    distance. The product is computed with ``torch.addmm``.
    """
    rows_pow = x1.pow(p).sum(dim=-1, keepdim=True)
    cols_pow = x2.pow(p).sum(dim=-1, keepdim=True).transpose(-2, -1)
    # One fused call yields cols_pow - 2 * (x1 @ x2^T).
    dist = torch.addmm(cols_pow, x1, x2.transpose(-2, -1), alpha=-2)
    dist.add_(rows_pow)
    # Clamp to a tiny positive value before taking the root in place.
    dist.clamp_min_(1e-30)
    dist.sqrt_()
    return dist
# https://github.com/pytorch/pytorch/pull/25799#issuecomment-529021810
def fast_cdist(x1, x2, p:int):
    """Matmul-based variant of the pairwise distance surrogate.

    Both operands are first shifted by the mean of ``x1`` along dim -2. The
    value ``sum(x1**p) + sum(x2**p) - 2 * x1 @ x2^T`` is then obtained from a
    single ``matmul`` of two augmented matrices, clamped to at least 1e-30 and
    square-rooted.
    """
    shift = x1.mean(-2, keepdim=True)
    lhs = x1 - shift
    rhs = x2 - shift  # x1 and x2 should be identical in all dims except -2 at this point
    lhs_pow = lhs.pow(p).sum(dim=-1, keepdim=True)
    rhs_pow = rhs.pow(p).sum(dim=-1, keepdim=True)
    # [-2*lhs, lhs_pow, 1] . [rhs, 1, rhs_pow]^T folds all three terms into
    # one matrix product.
    lhs_aug = torch.cat([-2. * lhs, lhs_pow, torch.ones_like(lhs_pow)], dim=-1)
    rhs_aug = torch.cat([rhs, torch.ones_like(rhs_pow), rhs_pow], dim=-1)
    dist = torch.matmul(lhs_aug, rhs_aug.transpose(-2, -1))
    # Zero out negative values
    dist.clamp_min_(1e-30)
    dist.sqrt_()
    return dist
def new_cdist(p, eta): ## https://github.com/huawei-noah/AdderNet/issues/6#issuecomment-594212162
    """Create a custom-gradient distance op for the given ``p`` and ``eta``.

    The returned callable takes ``(W, X)`` and evaluates
    ``-my_cdist(W, X, p)``. Its backward pass uses hand-written surrogates:

    * ``grad_W`` is rescaled so that its norm equals
      ``eta * sqrt(grad_W.numel())``;
    * the factor used for ``grad_X`` is clipped to ``[-1, 1]`` with hardtanh.
    """
    class cdist(torch.autograd.Function):
        @staticmethod
        def forward(ctx, W, X):
            ctx.save_for_backward(W, X)
            return -my_cdist(W, X, p)

        @staticmethod
        def backward(ctx, grad_output):
            W, X = ctx.saved_tensors
            grad_W = grad_X = None
            if ctx.needs_input_grad[0]:
                _temp1 = torch.unsqueeze(X, 2).expand(X.shape[0], X.shape[1], W.shape[0]).permute(1, 0, 2)
                _temp2 = torch.unsqueeze(W.transpose(0, 1), 1)
                # Both operands are 3-D here, so the matmul-based fast_cdist
                # is required: my_cdist relies on torch.addmm, which only
                # accepts 2-D matrices. squeeze(-1) removes just the trailing
                # singleton instead of every size-1 dimension.
                _temp = fast_cdist(_temp1, _temp2, p).squeeze(-1).transpose(0, 1)
                grad_W = torch.matmul(grad_output, _temp)
                # Rescale so that ||grad_W|| == eta * sqrt(numel).
                grad_W = eta * np.sqrt(grad_W.numel()) / torch.norm(grad_W) * grad_W
                print('after norm: ', torch.norm(grad_W))
            if ctx.needs_input_grad[1]:
                _temp1 = torch.unsqueeze(W, 2).expand(W.shape[0], W.shape[1], X.shape[0]).permute(1, 0, 2)
                _temp2 = torch.unsqueeze(X.transpose(0, 1), 1)
                _temp = fast_cdist(_temp1, _temp2, p).squeeze(-1).transpose(0, 1)
                _temp = torch.nn.functional.hardtanh(_temp, min_val=-1., max_val=1.)
                grad_X = torch.matmul(grad_output.transpose(0, 1), _temp)
            return grad_W, grad_X

    # apply is invoked on the class; autograd Functions are not meant to be
    # instantiated.
    return cdist.apply
def adder2d_function(X, W, stride=1, padding=0):
    """Slide the filters ``W`` over ``X`` using the distance op from new_cdist.

    ``X`` and ``W`` are unpacked as 4-D ``(n_x, d_x, h_x, w_x)`` and
    ``(n_filters, d_filter, h_filter, w_filter)``. Patches are extracted with
    ``unfold`` (using ``h_filter`` as the kernel size), compared against the
    flattened filters, and the result is returned as a contiguous
    ``(n_x, n_filters, h_out, w_out)`` tensor.
    """
    n_filters, d_filter, h_filter, w_filter = W.size()
    n_x, d_x, h_x, w_x = X.size()
    h_out = int((h_x - h_filter + 2 * padding) / stride + 1)
    w_out = int((w_x - w_filter + 2 * padding) / stride + 1)

    # The batch is folded into the channel axis before unfolding and split
    # out again afterwards.
    patches = torch.nn.functional.unfold(
        X.view(1, -1, h_x, w_x), h_filter, dilation=1, padding=padding, stride=stride
    ).view(n_x, -1, h_out*w_out)
    X_col = patches.permute(1,2,0).contiguous().view(patches.size(1),-1)
    W_col = W.view(n_filters, -1)

    cdist = new_cdist(1, 0.2) ## https://github.com/huawei-noah/AdderNet/issues/9
    # The op's forward already negates the distance; it is negated again here.
    response = -cdist(W_col,X_col.transpose(0,1))
    response = response.view(n_filters, h_out, w_out, n_x)
    return response.permute(3, 0, 1, 2).contiguous()
class adder2d(nn.Module):
    """Layer that applies ``adder2d_function`` with a learnable filter bank.

    The filters live in ``self.adder`` with shape
    ``(output_channel, input_channel, kernel_size, kernel_size)``. When
    ``bias`` is true, a per-output-channel offset ``self.b`` is added.
    """

    def __init__(self,input_channel,output_channel,kernel_size, stride=1, padding=0, bias = False):
        super().__init__()
        self.stride = stride
        self.padding = padding
        self.input_channel = input_channel
        self.output_channel = output_channel
        self.kernel_size = kernel_size
        filter_shape = (output_channel, input_channel, kernel_size, kernel_size)
        # randn followed by normal_ is kept as-is so the RNG stream is consumed
        # exactly as before.
        self.adder = torch.nn.Parameter(nn.init.normal_(torch.randn(*filter_shape)))
        self.bias = bias
        if bias:
            self.b = torch.nn.Parameter(nn.init.uniform_(torch.zeros(output_channel)))

    def forward(self, x):
        output = adder2d_function(x, self.adder, self.stride, self.padding)
        if not self.bias:
            return output
        # Broadcast the per-channel bias over batch and spatial dimensions.
        output += self.b.view(1, -1, 1, 1)
        return output
'''
Copyright (C) 2020. Huawei Technologies Co., Ltd. All rights reserved.
This program is free software; you can redistribute it and/or modify
it under the terms of BSD 3-Clause License.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
BSD 3-Clause License for more details.
'''
import torch
import torch.nn as nn
import numpy as np
from torch.autograd import Function
import math
# https://github.com/pytorch/pytorch/issues/15253#issuecomment-491467128
@torch.jit.script
def my_cdist(x1, x2, p:int):
    """Distance surrogate between every row of ``x1`` and every row of ``x2``.

    Returns ``sqrt(sum(x1**p) + sum(x2**p) - 2 * x1 @ x2.T)`` (sums over the
    last dimension), which equals the Euclidean distance when ``p == 2``.
    """
    x1_pow = x1.pow(p).sum(dim=-1, keepdim=True)
    x2_pow_t = x2.pow(p).sum(dim=-1, keepdim=True).transpose(-2, -1)
    # addmm computes x2_pow_t + (-2) * (x1 @ x2^T) in one call.
    acc = torch.addmm(x2_pow_t, x1, x2.transpose(-2, -1), alpha=-2)
    acc.add_(x1_pow)
    # Clamp to a tiny positive value, then take the root in place.
    acc.clamp_min_(1e-30)
    acc.sqrt_()
    return acc
# https://github.com/pytorch/pytorch/pull/25799#issuecomment-529021810
def fast_cdist(x1, x2, p:int):
    """Distance surrogate computed with a single (batch-capable) matmul.

    ``x1`` and ``x2`` are shifted by the mean of ``x1`` along dim -2, then
    ``sum(x1**p) + sum(x2**p) - 2 * x1 @ x2^T`` is formed via augmented
    matrices, clamped to at least 1e-30 and square-rooted.
    """
    centre = x1.mean(-2, keepdim=True)
    a = x1 - centre
    b = x2 - centre  # x1 and x2 should be identical in all dims except -2 at this point
    a_pow = a.pow(p).sum(dim=-1, keepdim=True)
    b_pow = b.pow(p).sum(dim=-1, keepdim=True)
    # Appending [a_pow, 1] to -2*a and [1, b_pow] to b lets one product carry
    # the cross term and both power sums.
    a_aug = torch.cat([-2. * a, a_pow, torch.ones_like(a_pow)], dim=-1)
    b_aug = torch.cat([b, torch.ones_like(b_pow), b_pow], dim=-1)
    out = torch.matmul(a_aug, b_aug.transpose(-2, -1))
    # Zero out negative values
    out.clamp_min_(1e-30)
    out.sqrt_()
    return out
def new_cdist(p, eta): ## https://github.com/huawei-noah/AdderNet/issues/6#issuecomment-594212162
    """Create a custom-gradient distance op for the given ``p`` and ``eta``.

    The returned callable takes ``(W, X)`` and evaluates
    ``-fast_cdist(W, X, p)``. Its backward pass uses hand-written surrogates:

    * ``grad_W`` is rescaled so that its norm equals
      ``eta * sqrt(grad_W.numel())``;
    * the factor used for ``grad_X`` is clipped to ``[-1, 1]`` with hardtanh.
    """
    class cdist(torch.autograd.Function):
        @staticmethod
        def forward(ctx, W, X):
            ctx.save_for_backward(W, X)
            return -fast_cdist(W, X, p)

        @staticmethod
        def backward(ctx, grad_output):
            W, X = ctx.saved_tensors
            grad_W = grad_X = None
            if ctx.needs_input_grad[0]:
                _temp1 = torch.unsqueeze(X, 2).expand(X.shape[0], X.shape[1], W.shape[0]).permute(1, 0, 2)
                _temp2 = torch.unsqueeze(W.transpose(0, 1), 1)
                # squeeze(-1) removes only the trailing singleton; a bare
                # squeeze() would also drop any other dimension of size 1.
                _temp = fast_cdist(_temp1, _temp2, p).squeeze(-1).transpose(0, 1)
                grad_W = torch.matmul(grad_output, _temp)
                # Rescale so that ||grad_W|| == eta * sqrt(numel).
                grad_W = eta * np.sqrt(grad_W.numel()) / torch.norm(grad_W) * grad_W
                print('after norm: ', torch.norm(grad_W))
            if ctx.needs_input_grad[1]:
                _temp1 = torch.unsqueeze(W, 2).expand(W.shape[0], W.shape[1], X.shape[0]).permute(1, 0, 2)
                _temp2 = torch.unsqueeze(X.transpose(0, 1), 1)
                _temp = fast_cdist(_temp1, _temp2, p).squeeze(-1).transpose(0, 1)
                _temp = torch.nn.functional.hardtanh(_temp, min_val=-1., max_val=1.)
                grad_X = torch.matmul(grad_output.transpose(0, 1), _temp)
            return grad_W, grad_X

    # apply is invoked on the class; autograd Functions are not meant to be
    # instantiated.
    return cdist.apply
def adder2d_function(X, W, stride=1, padding=0):
    """Apply the filter bank ``W`` to ``X`` through the new_cdist distance op.

    ``X`` is unpacked as ``(n_x, d_x, h_x, w_x)`` and ``W`` as
    ``(n_filters, d_filter, h_filter, w_filter)``. The return value is a
    contiguous tensor viewed as ``(n_x, n_filters, h_out, w_out)``.
    """
    n_filters, d_filter, h_filter, w_filter = W.size()
    n_x, d_x, h_x, w_x = X.size()
    h_out = int((h_x - h_filter + 2 * padding) / stride + 1)
    w_out = int((w_x - w_filter + 2 * padding) / stride + 1)

    # unfold sees the batch folded into channels; the view restores it.
    unfolded = torch.nn.functional.unfold(
        X.view(1, -1, h_x, w_x), h_filter, dilation=1, padding=padding, stride=stride
    )
    per_sample = unfolded.view(n_x, -1, h_out*w_out)
    X_col = per_sample.permute(1,2,0).contiguous().view(per_sample.size(1),-1)
    W_col = W.view(n_filters, -1)

    cdist = new_cdist(1, 0.2) ## https://github.com/huawei-noah/AdderNet/issues/9
    # The op's forward already negates the distance; it is negated again here.
    result = -cdist(W_col,X_col.transpose(0,1))
    result = result.view(n_filters, h_out, w_out, n_x)
    return result.permute(3, 0, 1, 2).contiguous()
class adder2d(nn.Module):
    """Module wrapper around ``adder2d_function``.

    Holds the learnable filters ``self.adder`` of shape
    ``(output_channel, input_channel, kernel_size, kernel_size)`` and, when
    ``bias`` is true, a per-output-channel offset ``self.b``.
    """

    def __init__(self,input_channel,output_channel,kernel_size, stride=1, padding=0, bias = False):
        super().__init__()
        self.stride = stride
        self.padding = padding
        self.input_channel = input_channel
        self.output_channel = output_channel
        self.kernel_size = kernel_size
        # Both the randn draw and the normal_ re-initialisation are kept so
        # seeded runs produce the same filters as before.
        initial = torch.randn(output_channel, input_channel, kernel_size, kernel_size)
        self.adder = torch.nn.Parameter(nn.init.normal_(initial))
        self.bias = bias
        if bias:
            self.b = torch.nn.Parameter(nn.init.uniform_(torch.zeros(output_channel)))

    def forward(self, x):
        output = adder2d_function(x, self.adder, self.stride, self.padding)
        if self.bias:
            # Reshape to (1, C, 1, 1) so the bias broadcasts per channel.
            output += self.b.view(1, -1, 1, 1)
        return output
aeroplane
bicycle
bird
boat
bottle
bus
car
cat
chair
cow
diningtable
dog
horse
motorbike
person
pottedplant
sheep
sofa
train
tvmonitor
#!/bin/bash
# Generates a YOLOv3 network definition for a custom number of classes.
# Usage: pass the class count as the first argument.
# The value is substituted into every `classes=` entry and into the
# `filters=3*(NUM_CLASSES+5)` setting of the three layers that precede the
# [yolo] sections. The echoed text is appended (>>) to yolov3-custom.cfg.
NUM_CLASSES=$1
echo "
[net]
# Testing
#batch=1
#subdivisions=1
# Training
batch=2
subdivisions=2
width=416
height=416
channels=3
momentum=0.9
decay=0.0005
angle=0
saturation = 1.5
exposure = 1.5
hue=.1
learning_rate=0.001
burn_in=1000
max_batches = 500200
policy=steps
steps=400000,450000
scales=.1,.1
[convolutional]
batch_normalize=1
filters=32
size=3
stride=1
pad=1
activation=leaky
# Downsample
[convolutional]
batch_normalize=1
filters=64
size=3
stride=2
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=32
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=64
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
# Downsample
[convolutional]
batch_normalize=1
filters=128
size=3
stride=2
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=64
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=128
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
[convolutional]
batch_normalize=1
filters=64
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=128
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
# Downsample
[convolutional]
batch_normalize=1
filters=256
size=3
stride=2
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=128
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=256
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
[convolutional]
batch_normalize=1
filters=128
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=256
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
[convolutional]
batch_normalize=1
filters=128
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=256
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
[convolutional]
batch_normalize=1
filters=128
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=256
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
[convolutional]
batch_normalize=1
filters=128
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=256
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
[convolutional]
batch_normalize=1
filters=128
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=256
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
[convolutional]
batch_normalize=1
filters=128
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=256
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
[convolutional]
batch_normalize=1
filters=128
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=256
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
# Downsample
[convolutional]
batch_normalize=1
filters=512
size=3
stride=2
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
# Downsample
[convolutional]
batch_normalize=1
filters=1024
size=3
stride=2
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=1024
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
[convolutional]
batch_normalize=1
filters=512
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=1024
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
[convolutional]
batch_normalize=1
filters=512
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=1024
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
[convolutional]
batch_normalize=1
filters=512
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
filters=1024
size=3
stride=1
pad=1
activation=leaky
[shortcut]
from=-3
activation=linear
######################
[convolutional]
batch_normalize=1
filters=512
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
size=3
stride=1
pad=1
filters=1024
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
size=3
stride=1
pad=1
filters=1024
activation=leaky
[convolutional]
batch_normalize=1
filters=512
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
size=3
stride=1
pad=1
filters=1024
activation=leaky
[convolutional]
size=1
stride=1
pad=1
filters=$(expr 3 \* $(expr $NUM_CLASSES \+ 5))
activation=linear
[yolo]
mask = 6,7,8
anchors = 10,13, 16,30, 33,23, 30,61, 62,45, 59,119, 116,90, 156,198, 373,326
classes=$NUM_CLASSES
num=9
jitter=.3
ignore_thresh = .7
truth_thresh = 1
random=1
[route]
layers = -4
[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky
[upsample]
stride=2
[route]
layers = -1, 61
[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
size=3
stride=1
pad=1
filters=512
activation=leaky
[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
size=3
stride=1
pad=1
filters=512
activation=leaky
[convolutional]
batch_normalize=1
filters=256
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
size=3
stride=1
pad=1
filters=512
activation=leaky
[convolutional]
size=1
stride=1
pad=1
filters=$(expr 3 \* $(expr $NUM_CLASSES \+ 5))
activation=linear
[yolo]
mask = 3,4,5
anchors = 10,13, 16,30, 33,23, 30,61, 62,45, 59,119, 116,90, 156,198, 373,326
classes=$NUM_CLASSES
num=9
jitter=.3
ignore_thresh = .7
truth_thresh = 1
random=1
[route]
layers = -4
[convolutional]
batch_normalize=1
filters=128
size=1
stride=1
pad=1
activation=leaky
[upsample]
stride=2
[route]
layers = -1, 36
[convolutional]
batch_normalize=1
filters=128
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
size=3
stride=1
pad=1
filters=256
activation=leaky
[convolutional]
batch_normalize=1
filters=128
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
size=3
stride=1
pad=1
filters=256
activation=leaky
[convolutional]
batch_normalize=1
filters=128
size=1
stride=1
pad=1
activation=leaky
[convolutional]
batch_normalize=1
size=3
stride=1
pad=1
filters=256
activation=leaky
[convolutional]
size=1
stride=1
pad=1
filters=$(expr 3 \* $(expr $NUM_CLASSES \+ 5))
activation=linear
[yolo]
mask = 0,1,2
anchors = 10,13, 16,30, 33,23, 30,61, 62,45, 59,119, 116,90, 156,198, 373,326
classes=$NUM_CLASSES
num=9
jitter=.3
ignore_thresh = .7
truth_thresh = 1
random=1
" >> yolov3-custom.cfg
import os
import argparse
import shutil
import tarfile
import sys
from pathlib import Path
import json
image_types = ('.jpg', '.jpeg', '.jpe', '.img', '.png', '.bmp')
def parser():
    """Build the command-line parser for the dataset-cutting script.

    Returns:
        argparse.ArgumentParser: parser exposing the source archive options,
        ``--output_size``, ``--first_image``, ``--output_archive_dir`` and
        ``--dataset_type``.
    """
    cli = argparse.ArgumentParser(description=' ')

    # The three optional source-archive flags share type and help text.
    source_flags = (
        '--source_archive_dir',
        '--source_images_archive_dir',
        '--source_annotations_archive_dir',
    )
    for flag in source_flags:
        cli.add_argument(flag, type=str, required=False,
                         help='Full path to the source archive')

    cli.add_argument('--output_size', type=int, required=True,
                     help='Number of images in the output dataset')
    cli.add_argument('--first_image', type=int, required=False, default=0,
                     help='Number of the image to start from')
    cli.add_argument('--output_archive_dir', type=str, required=True,
                     help='Full path to the output archive (without the name of the archive)')
    cli.add_argument('--dataset_type', type=str, choices=['imagenet','voc', 'coco'],
                     required=True,
                     help='Dataset format: ImageNet, Pascal VOC, or COCO')
    return cli
def unarchive(source_archive_dir, output_folder_dir):
    """Extract the archive at ``source_archive_dir`` into ``output_folder_dir``."""
    shutil.unpack_archive(filename=source_archive_dir, extract_dir=output_folder_dir)
def is_possible_to_cut(dataset_size, subset_size, first_image):
    """Tell whether ``subset_size`` items starting at ``first_image`` fit.

    The window ``[first_image, first_image + subset_size)`` must lie inside
    ``[0, dataset_size)``. The last admissible start,
    ``dataset_size - subset_size``, is accepted (so the whole dataset can be
    taken from index 0), and negative starts are rejected because they would
    turn the callers' slices into from-the-end slices.
    """
    return 0 <= first_image <= dataset_size - subset_size
def cut_imagenet(output_size, output_folder_dir, first_image):
    """Trim an extracted ImageNet-style folder to a window of annotated images.

    The folder must contain exactly one ``.txt`` annotation file and at least
    one image. The annotation file is rewritten in place so that it keeps only
    lines ``first_image .. first_image + output_size - 1``.

    Returns:
        tuple: ``(file_names, '')`` where ``file_names`` starts with the
        annotation file name, followed by one image name per kept annotation
        line (the line's first token with the extension of the first image
        found in the folder).

    Exits via ``sys.exit`` when the folder layout is wrong or the requested
    window does not fit.
    """
    image_names = []
    annotation_name = None
    for file_name in os.listdir(output_folder_dir):
        lowered = file_name.lower()
        if lowered.endswith('.txt'):
            # More than one text file makes the annotation ambiguous.
            if annotation_name is not None:
                sys.exit('Incorrect dataset format.')
            annotation_name = file_name
        elif lowered.endswith(image_types):
            image_names.append(file_name)

    # Validate before indexing image_names or using annotation_name, so a
    # malformed folder gives the format message instead of a traceback.
    if not image_names or annotation_name is None:
        sys.exit('Incorrect dataset format.')
    image_ext = os.path.splitext(image_names[0])[1]

    if not is_possible_to_cut(len(image_names), output_size, first_image):
        sys.exit('Invalid --first_image value. The number of the starting image should be less than the difference\n'
                 'between the dataset size and the subset size.')

    annotation_path = os.path.join(output_folder_dir, annotation_name)
    with open(annotation_path, 'r') as annotation:
        annotation_text = annotation.readlines()
    new_annotation_text = annotation_text[first_image:output_size+first_image]
    with open(annotation_path, 'w') as new_annotation:
        new_annotation.writelines(new_annotation_text)

    new_file_names = [annotation_name, ]
    for line in new_annotation_text:
        fields = line.split()
        if not fields:
            # Blank lines carry no image name.
            continue
        new_file_names.append('{}{}'.format(os.path.splitext(fields[0])[0], image_ext))
    return (new_file_names, '',)
def cut_voc(output_size, output_folder_dir, first_image):
    """Trim an extracted Pascal VOC tree to a window of images.

    The dataset root is found by descending through the first directory entry
    at each level (one level deeper when the top folder is ``TrainVal``). The
    root must contain ``Annotations``, ``ImageSets/Main`` and a folder whose
    name contains ``Images``. The first of ``test.txt``, ``trainval.txt`` or
    ``val.txt`` found in ``ImageSets/Main`` is overwritten with the kept image
    names.

    Returns:
        tuple: ``(paths, 'VOCdevkit')`` where ``paths`` lists the kept images,
        their ``.xml`` annotations and the rewritten split file.

    Exits via ``sys.exit`` when the layout is wrong or the requested window
    does not fit.
    """
    voc_folder = os.listdir(output_folder_dir)[0]
    if voc_folder == 'TrainVal':
        voc_devkit_folder_dir = os.path.join(output_folder_dir, voc_folder)
        voc_devkit_folder = os.listdir(voc_devkit_folder_dir)[0]
        voc_year_folder_dir = os.path.join(voc_devkit_folder_dir, voc_devkit_folder)
    else:
        voc_year_folder_dir = os.path.join(output_folder_dir, voc_folder)
    voc_year_folder = os.listdir(voc_year_folder_dir)[0]
    voc_root_dir = os.path.join(voc_year_folder_dir, voc_year_folder)

    annotation_dir = os.path.join(voc_root_dir, 'Annotations')
    main_dir = os.path.join(voc_root_dir, 'ImageSets', 'Main')
    images_dir = None
    for element in os.listdir(voc_root_dir):
        path_to_element = os.path.join(voc_root_dir, element)
        if os.path.isdir(path_to_element) and 'Images' in element:
            images_dir = path_to_element

    # Check the layout before touching images_dir, so a missing image folder
    # yields the format message instead of an unbound-name error.
    if (images_dir is None or not os.path.isdir(annotation_dir)
            or not os.path.isdir(main_dir)):
        sys.exit('Incorrect dataset format.')

    images_files = os.listdir(images_dir)
    if not is_possible_to_cut(len(images_files), output_size, first_image):
        sys.exit('Invalid --first_image value. The number of the starting image should be less than the difference\n'
                 'between the dataset and subset sizes.')
    images_files = images_files[first_image:first_image+output_size]

    names = []
    files_directories = []
    for images_file in images_files:
        img_name = os.path.splitext(images_file)[0]
        annotation = '{}.xml'.format(os.path.join(annotation_dir, img_name))
        # Keep only images that have a matching annotation file.
        if images_file.lower().endswith(image_types) and os.path.isfile(annotation):
            names.append(img_name)
            files_directories.append(os.path.join(images_dir, images_file))
    if not names:
        sys.exit('Incorrect dataset format.')
    for name in names:
        files_directories.append('{}.xml'.format(os.path.join(annotation_dir, name)))

    possible_names = ('test.txt', 'trainval.txt', 'val.txt')
    main_txt_dir = None
    for name in possible_names:
        candidate = os.path.join(main_dir, name)
        if os.path.isfile(candidate):
            main_txt_dir = candidate
            break
    # os.path.isfile(None) would raise TypeError, so test for None directly.
    if main_txt_dir is None:
        sys.exit('Incorrect dataset format')
    with open(main_txt_dir, 'w') as main:
        main.write('\n'.join(names))
    files_directories.append(main_txt_dir)
    return (files_directories, 'VOCdevkit',)
def cut_coco(output_size, output_folder_dir, first_image):
    """Trim an extracted COCO validation set to a window of images.

    ``output_folder_dir`` must hold exactly two entries: an ``annotations``
    folder containing ``instances_val*.json`` and an image folder matching
    ``val*[0-9]``. An ``instances_train*.json`` file, if present, is deleted.
    The validation annotation file is rewritten in place, keeping images
    ``first_image .. first_image + output_size - 1`` and only the annotations
    that refer to them.

    Returns:
        tuple: ``(paths, 'subset_folder')`` where ``paths`` lists the kept
        image files followed by the annotations folder.

    Exits via ``sys.exit`` when the layout is wrong or the requested window
    does not fit.
    """
    num_of_folders = 2
    if len(os.listdir(output_folder_dir)) != num_of_folders:
        sys.exit('Incorrect dataset format.')

    root = Path(output_folder_dir)
    # glob() already yields paths that include the root, so they are used
    # directly; joining them onto the root again breaks relative roots.
    annotations_path = next(root.glob('annotations'), None)
    images_path = next(root.glob('val*[0-9]'), None)
    if annotations_path is None or images_path is None:
        sys.exit('Incorrect dataset format.')
    annotations_folder = str(annotations_path)
    images_folder_dir = str(images_path)
    images_folder = os.listdir(images_folder_dir)

    annotation_file = next(annotations_path.glob('instances_val*[0-9].json'), None)
    # Train annotations are optional: next() needs a default, otherwise their
    # absence raises StopIteration.
    train_annotation_file = next(annotations_path.glob('instances_train*[0-9].json'), None)
    if train_annotation_file is not None:
        os.remove(str(train_annotation_file))

    if not images_folder or annotation_file is None:
        sys.exit('Incorrect dataset format.')
    if not is_possible_to_cut(len(images_folder), output_size, first_image):
        sys.exit('Invalid --first_image value. The number of the starting image should be less than the difference '
                 'between the dataset size and the subset size.')

    with open(str(annotation_file)) as json_file:
        json_data = json.load(json_file)
    json_data['images'] = json_data['images'][first_image:output_size+first_image]
    image_filenames = [image['file_name'] for image in json_data['images']]
    # A set keeps the per-annotation membership test constant-time.
    image_ids = {image['id'] for image in json_data['images']}
    json_data['annotations'] = [
        annotation for annotation in json_data['annotations']
        if annotation['image_id'] in image_ids
    ]
    with open(str(annotation_file), 'w') as outfile:
        json.dump(json_data, outfile)

    files_to_archive = [os.path.join(images_folder_dir, image) for image in image_filenames]
    files_to_archive.append(annotations_folder)
    return (files_to_archive, 'subset_folder',)
def archive(new_file_names, source_path, output_archive_name, output_folder_dir, rel_path_finder):
    """Pack the given files into ``<source_path>/<output_archive_name>.tar.gz``.

    Each entry of ``new_file_names`` is resolved against ``output_folder_dir``.
    Its name inside the archive is the part of the entry starting at the first
    occurrence of ``rel_path_finder``. When the marker does not occur, the
    path relative to ``output_folder_dir`` is used instead.
    """
    archive_path = os.path.join(source_path, '{}.tar.gz'.format(output_archive_name))
    with tarfile.open(archive_path, 'w:gz') as tar:
        for file_name in new_file_names:
            full_path = os.path.join(output_folder_dir, file_name)
            marker_at = file_name.find(rel_path_finder)
            if marker_at >= 0:
                relative_path = file_name[marker_at:]
            else:
                # find() returned -1; slicing from -1 would leave only the
                # final character of the name.
                relative_path = os.path.relpath(full_path, output_folder_dir)
            tar.add(full_path, arcname=relative_path)
def clean_up(path):
    """Recursively delete the directory tree rooted at ``path``."""
    shutil.rmtree(path, ignore_errors=False)
def is_imagenet(dataset_type):
    """Tell whether ``dataset_type`` selects the ImageNet format."""
    expected = 'imagenet'
    return dataset_type == expected
def is_voc(dataset_type):
    """Tell whether ``dataset_type`` selects the Pascal VOC format."""
    expected = 'voc'
    return dataset_type == expected
def is_coco(dataset_type):
    """Tell whether ``dataset_type`` selects the COCO format."""
    expected = 'coco'
    return dataset_type == expected
if __name__ == '__main__':
    # Script entry point: extract the source archive(s) into a working folder,
    # trim the dataset, pack the result, then remove the working folder.
    args = parser().parse_args()
    output_folder_dir = os.path.join(args.output_archive_dir, 'subset_folder')
    output_archive_name = '{}_subset_{}_{}'.format(args.dataset_type, args.first_image, args.first_image + args.output_size - 1)

    # ImageNet and VOC come as one archive; COCO needs images and annotations.
    if (is_imagenet(args.dataset_type) or is_voc(args.dataset_type)) and not args.source_archive_dir:
        sys.exit('--source_archive_dir is required for the selected dataset type.')
    if is_coco(args.dataset_type) and (not args.source_images_archive_dir or not args.source_annotations_archive_dir):
        sys.exit('Both --source_images_archive_dir and --source_annotations_archive_dir are required for the selected dataset type.')

    try:
        if is_imagenet(args.dataset_type):
            unarchive(args.source_archive_dir, output_folder_dir)
            new_file_names, rel_path_finder = cut_imagenet(args.output_size, output_folder_dir, args.first_image)
        elif is_voc(args.dataset_type):
            unarchive(args.source_archive_dir, output_folder_dir)
            new_file_names, rel_path_finder = cut_voc(args.output_size, output_folder_dir, args.first_image)
        else:
            unarchive(args.source_images_archive_dir, output_folder_dir)
            unarchive(args.source_annotations_archive_dir, output_folder_dir)
            new_file_names, rel_path_finder = cut_coco(args.output_size, output_folder_dir, args.first_image)
        archive(new_file_names, args.output_archive_dir, output_archive_name, output_folder_dir, rel_path_finder)
    finally:
        # sys.exit() raises SystemExit, so the working folder is also removed
        # when a cut aborts; skip removal if extraction never created it.
        if os.path.isdir(output_folder_dir):
            clean_up(output_folder_dir)
import glob
import random
import os
import sys
import numpy as np
from PIL import Image
import torch
import torch.nn.functional as F
from utils.augmentations import horisontal_flip
from torch.utils.data import Dataset
import torchvision.transforms as transforms
def pad_to_square(img, pad_value):
    """Pad the shorter spatial side of ``img`` so height equals width.

    ``img`` is unpacked as ``(c, h, w)``. The missing rows or columns are
    split between both sides, with the extra one (for odd differences) going
    to the bottom/right.

    Returns:
        tuple: ``(padded_img, pad)`` where ``pad`` is the 4-tuple handed to
        ``F.pad``.
    """
    _, height, width = img.shape
    gap = np.abs(height - width)
    # (upper / left) padding and (lower / right) padding
    before = gap // 2
    after = gap - gap // 2
    if height <= width:
        # Wider than tall (or square): grow the height.
        pad = (0, 0, before, after)
    else:
        # Taller than wide: grow the width.
        pad = (before, after, 0, 0)
    padded = F.pad(img, pad, "constant", value=pad_value)
    return padded, pad
def resize(image, size):
    """Resize a single image tensor to ``size`` with nearest-neighbour sampling.

    A leading batch dimension is added for ``F.interpolate`` and removed again
    before returning.
    """
    batched = image.unsqueeze(0)
    scaled = F.interpolate(batched, size=size, mode="nearest")
    return scaled.squeeze(0)
def random_resize(images, min_size=288, max_size=448):
    """Resize a batch to a random size drawn from ``min_size..max_size``.

    Candidate sizes advance in steps of 32 starting at ``min_size``. The
    chosen size is passed to ``F.interpolate`` with nearest-neighbour
    sampling.
    """
    candidates = list(range(min_size, max_size + 1, 32))
    (new_size,) = random.sample(candidates, 1)
    return F.interpolate(images, size=new_size, mode="nearest")
class ImageFolder(Dataset):
    """Dataset over every ``*.*`` file in a folder, in sorted order.

    Each item is ``(path, image)``, where the image has been converted to a
    tensor, padded to a square with zeros and resized to ``img_size``.
    """

    def __init__(self, folder_path, img_size=416):
        pattern = "%s/*.*" % folder_path
        self.files = sorted(glob.glob(pattern))
        self.img_size = img_size

    def __getitem__(self, index):
        # Indices wrap around the file list.
        img_path = self.files[index % len(self.files)]
        # Extract image as PyTorch tensor
        tensor = transforms.ToTensor()(Image.open(img_path))
        # Pad to square resolution
        squared, _ = pad_to_square(tensor, 0)
        # Resize
        return img_path, resize(squared, self.img_size)

    def __len__(self):
        return len(self.files)
class ListDataset(Dataset):
    """Detection dataset driven by a text file listing one image path per line.

    Label paths are derived from image paths by replacing ``JPEGImages`` with
    ``labels`` and the ``.png``/``.jpg`` extension with ``.txt``. Label files
    are read as rows of five numbers; column 0 is carried through unchanged
    and columns 1-4 are treated as box centre x, centre y, width and height.
    """

    def __init__(self, list_path, img_size=416, augment=True, multiscale=True, normalized_labels=True):
        with open(list_path, "r") as file:
            self.img_files = file.readlines()

        self.label_files = [
            path.replace("JPEGImages", "labels").replace(".png", ".txt").replace(".jpg", ".txt")
            for path in self.img_files
        ]
        self.img_size = img_size
        self.max_objects = 100
        self.augment = augment
        self.multiscale = multiscale
        self.normalized_labels = normalized_labels
        # Multiscale training picks sizes within +/- 3 * 32 of img_size.
        self.min_size = self.img_size - 3 * 32
        self.max_size = self.img_size + 3 * 32
        self.batch_count = 0

    def __getitem__(self, index):
        """Return ``(img_path, img, targets)``; ``targets`` is None without a label file."""
        # ---------
        #  Image
        # ---------
        img_path = self.img_files[index % len(self.img_files)].rstrip()

        # Extract image as PyTorch tensor
        img = transforms.ToTensor()(Image.open(img_path).convert('RGB'))

        # Handle images with less than three channels
        if len(img.shape) != 3:
            img = img.unsqueeze(0)
            # expand() takes the sizes as separate ints, not a nested tuple.
            img = img.expand(3, *img.shape[1:])

        _, h, w = img.shape
        h_factor, w_factor = (h, w) if self.normalized_labels else (1, 1)
        # Pad to square resolution
        img, pad = pad_to_square(img, 0)
        _, padded_h, padded_w = img.shape

        # ---------
        #  Label
        # ---------
        label_path = self.label_files[index % len(self.img_files)].rstrip()

        targets = None
        if os.path.exists(label_path):
            boxes = torch.from_numpy(np.loadtxt(label_path).reshape(-1, 5))
            # Extract coordinates for unpadded + unscaled image
            x1 = w_factor * (boxes[:, 1] - boxes[:, 3] / 2)
            y1 = h_factor * (boxes[:, 2] - boxes[:, 4] / 2)
            x2 = w_factor * (boxes[:, 1] + boxes[:, 3] / 2)
            y2 = h_factor * (boxes[:, 2] + boxes[:, 4] / 2)
            # Adjust for added padding
            x1 += pad[0]
            y1 += pad[2]
            x2 += pad[1]
            y2 += pad[3]
            # Returns (x, y, w, h)
            boxes[:, 1] = ((x1 + x2) / 2) / padded_w
            boxes[:, 2] = ((y1 + y2) / 2) / padded_h
            boxes[:, 3] *= w_factor / padded_w
            boxes[:, 4] *= h_factor / padded_h

            # Column 0 is reserved for the sample index set in collate_fn.
            targets = torch.zeros((len(boxes), 6))
            targets[:, 1:] = boxes

        # Apply augmentations
        if self.augment:
            # The random draw always happens; the flip is skipped for samples
            # without labels because horisontal_flip is given the targets.
            if np.random.random() < 0.5 and targets is not None:
                img, targets = horisontal_flip(img, targets)

        return img_path, img, targets

    def collate_fn(self, batch):
        """Merge samples into ``(paths, imgs, targets)`` for one batch.

        ``targets[:, 0]`` holds the position of the owning image within the
        batch. An empty ``(0, 6)`` tensor is returned when no sample has
        labels.
        """
        paths, imgs, targets = list(zip(*batch))
        # Number the samples BEFORE dropping the label-less ones, so the index
        # keeps pointing at the right image in the stacked batch.
        labelled = []
        for sample_idx, boxes in enumerate(targets):
            if boxes is None:
                continue
            boxes[:, 0] = sample_idx
            labelled.append(boxes)
        # torch.cat rejects an empty list.
        targets = torch.cat(labelled, 0) if labelled else torch.zeros((0, 6))
        # Selects new image size every tenth batch
        if self.multiscale and self.batch_count % 10 == 0:
            self.img_size = random.choice(range(self.min_size, self.max_size + 1, 32))
        # Resize images to input shape
        imgs = torch.stack([resize(img, self.img_size) for img in imgs])
        self.batch_count += 1
        return paths, imgs, targets

    def __len__(self):
        return len(self.img_files)
import tensorflow as tf
class Logger(object):
    """Scalar logger backed by a TensorFlow summary file writer."""

    def __init__(self, log_dir):
        """Create a summary writer logging to log_dir."""
        self.writer = tf.summary.create_file_writer(log_dir)

    def scalar_summary(self, tag, value, step):
        """Log a scalar variable."""
        writer = self.writer
        with writer.as_default():
            tf.summary.scalar(tag, value, step=step)
            writer.flush()

    def list_of_scalars_summary(self, tag_value_pairs, step):
        """Log scalar variables."""
        writer = self.writer
        with writer.as_default():
            # All pairs share the same step; flush once after the last one.
            for tag, value in tag_value_pairs:
                tf.summary.scalar(tag, value, step=step)
            writer.flush()
from __future__ import division
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
from utils.parse_config import *
from utils.utils import build_targets, to_cpu, non_max_suppression
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import adder ###########
def conv2d(in_channels, out_channels, kernel_size, stride, padding, bias):
    """Build an ``adder.adder2d`` layer from convolution-style arguments."""
    layer_options = dict(kernel_size=kernel_size, stride=stride,
                         padding=padding, bias=bias)
    return adder.adder2d(in_channels, out_channels, **layer_options)
def create_modules(module_defs):
    """
    Build one ``nn.Sequential`` per config block and collect them in an ``nn.ModuleList``.

    The first entry of ``module_defs`` is popped off (the caller's list is mutated) and
    returned as the hyperparameter dict; every remaining entry describes one block.

    Returns:
        Tuple ``(hyperparams, module_list)``.
    """
    hyperparams = module_defs.pop(0)
    # channel_history[0] is the network input depth; entry k + 1 is the output depth of block k.
    channel_history = [int(hyperparams["channels"])]
    module_list = nn.ModuleList()

    for module_i, module_def in enumerate(module_defs):
        block = nn.Sequential()
        layer_type = module_def["type"]

        if layer_type == "convolutional":
            use_bn = int(module_def["batch_normalize"])
            filters = int(module_def["filters"])
            ksize = int(module_def["size"])
            same_pad = (ksize - 1) // 2
            conv_layer = conv2d(
                in_channels=channel_history[-1],
                out_channels=filters,
                kernel_size=ksize,
                stride=int(module_def["stride"]),
                padding=same_pad,
                # Batch norm supplies its own shift, so the layer bias is only used without it.
                bias=not use_bn,
            )
            block.add_module(f"conv_{module_i}", conv_layer)
            if use_bn:
                block.add_module(f"batch_norm_{module_i}", nn.BatchNorm2d(filters, momentum=0.9, eps=1e-5))
            if module_def["activation"] == "leaky":
                block.add_module(f"leaky_{module_i}", nn.LeakyReLU(0.1))

        elif layer_type == "maxpool":
            ksize = int(module_def["size"])
            pool_stride = int(module_def["stride"])
            if (ksize, pool_stride) == (2, 1):
                # Pad right/bottom by one before the 2x2 stride-1 pooling.
                block.add_module(f"_debug_padding_{module_i}", nn.ZeroPad2d((0, 1, 0, 1)))
            block.add_module(
                f"maxpool_{module_i}",
                nn.MaxPool2d(kernel_size=ksize, stride=pool_stride, padding=int((ksize - 1) // 2)),
            )

        elif layer_type == "upsample":
            block.add_module(
                f"upsample_{module_i}",
                Upsample(scale_factor=int(module_def["stride"]), mode="nearest"),
            )

        elif layer_type == "route":
            source_layers = [int(idx) for idx in module_def["layers"].split(",")]
            # Skip the input entry so indices line up with block indices.
            block_channels = channel_history[1:]
            filters = sum(block_channels[idx] for idx in source_layers)
            block.add_module(f"route_{module_i}", EmptyLayer())

        elif layer_type == "shortcut":
            filters = channel_history[1:][int(module_def["from"])]
            block.add_module(f"shortcut_{module_i}", EmptyLayer())

        elif layer_type == "yolo":
            mask = [int(idx) for idx in module_def["mask"].split(",")]
            # The config stores anchors as a flat "w,h,w,h,..." list; pair them up,
            # then keep only the anchors selected by this layer's mask.
            flat = [int(val) for val in module_def["anchors"].split(",")]
            anchor_pairs = [(flat[k], flat[k + 1]) for k in range(0, len(flat), 2)]
            selected = [anchor_pairs[idx] for idx in mask]
            num_classes = int(module_def["classes"])
            img_size = int(hyperparams["height"])
            block.add_module(f"yolo_{module_i}", YOLOLayer(selected, num_classes, img_size))

        # Blocks that do not set `filters` (maxpool, upsample, yolo) record the
        # value left over from the most recent block that did.
        module_list.append(block)
        channel_history.append(filters)

    return hyperparams, module_list
class Upsample(nn.Module):
    """Module wrapper that resizes its input with ``F.interpolate``."""

    def __init__(self, scale_factor, mode="nearest"):
        super().__init__()
        self.mode = mode
        self.scale_factor = scale_factor

    def forward(self, x):
        """Return ``x`` interpolated by ``self.scale_factor`` using ``self.mode``."""
        return F.interpolate(x, scale_factor=self.scale_factor, mode=self.mode)
class EmptyLayer(nn.Module):
    """Parameter-free stand-in registered for 'route' and 'shortcut' blocks.

    It only occupies a slot in the module list; ``Darknet.forward`` performs the
    actual concatenation / addition for those block types itself.
    """

    def __init__(self):
        super().__init__()
class YOLOLayer(nn.Module):
    """Detection layer.

    Decodes a raw feature map into box / confidence / class predictions and,
    when targets are supplied, also computes the layer loss and fills
    ``self.metrics``.
    """

    def __init__(self, anchors, num_classes, img_dim=416):
        """
        Args:
            anchors: sequence of ``(width, height)`` anchor pairs used by this layer.
            num_classes: number of class scores predicted per anchor.
            img_dim: input image size used to derive the stride until ``forward``
                is given an explicit ``img_dim``.
        """
        super(YOLOLayer, self).__init__()
        self.anchors = anchors
        self.num_anchors = len(anchors)
        self.num_classes = num_classes
        self.ignore_thres = 0.5
        self.mse_loss = nn.MSELoss()
        self.bce_loss = nn.BCELoss()
        # Weights of the object / no-object terms of the confidence loss.
        self.obj_scale = 1
        self.noobj_scale = 100
        self.metrics = {}
        self.img_dim = img_dim
        self.grid_size = 0  # 0 = no grid offsets computed yet

    def compute_grid_offsets(self, grid_size, cuda=True):
        """Cache per-cell offsets and stride-scaled anchors for ``grid_size``.

        Args:
            grid_size: side length of the (square) prediction grid.
            cuda: create the cached tensors as CUDA tensors when true.
        """
        self.grid_size = grid_size
        g = self.grid_size
        FloatTensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
        self.stride = self.img_dim / self.grid_size
        # grid_x varies along the last axis, grid_y along the second-to-last.
        self.grid_x = torch.arange(g).repeat(g, 1).view([1, 1, g, g]).type(FloatTensor)
        self.grid_y = torch.arange(g).repeat(g, 1).t().view([1, 1, g, g]).type(FloatTensor)
        # Anchors are divided by the stride so they are expressed in grid cells.
        self.scaled_anchors = FloatTensor([(a_w / self.stride, a_h / self.stride) for a_w, a_h in self.anchors])
        self.anchor_w = self.scaled_anchors[:, 0:1].view((1, self.num_anchors, 1, 1))
        self.anchor_h = self.scaled_anchors[:, 1:2].view((1, self.num_anchors, 1, 1))

    def forward(self, x, targets=None, img_dim=None):
        """Decode predictions and optionally compute the loss.

        Args:
            x: feature map viewable as
                ``(batch, num_anchors * (num_classes + 5), grid, grid)``.
            targets: ground-truth tensor forwarded to ``build_targets``; when
                ``None`` no loss is computed.
            img_dim: input image size. When ``None`` the value given to the
                constructor (or to a previous call) is kept.

        Returns:
            ``(output, loss)`` where ``output`` has shape
            ``(batch, num_anchors * grid * grid, 5 + num_classes)`` and ``loss``
            is ``0`` when ``targets`` is ``None``.
        """
        is_cuda = x.is_cuda
        FloatTensor = torch.cuda.FloatTensor if is_cuda else torch.FloatTensor

        # Only override the stored size when one is actually supplied; assigning
        # the None default would break the stride computation below.
        if img_dim is not None:
            self.img_dim = img_dim
        num_samples = x.size(0)
        grid_size = x.size(2)

        prediction = (
            x.view(num_samples, self.num_anchors, self.num_classes + 5, grid_size, grid_size)
            .permute(0, 1, 3, 4, 2)
            .contiguous()
        )

        # Split the last axis into its components.
        cx = torch.sigmoid(prediction[..., 0])  # Center x
        cy = torch.sigmoid(prediction[..., 1])  # Center y
        w = prediction[..., 2]  # Width
        h = prediction[..., 3]  # Height
        pred_conf = torch.sigmoid(prediction[..., 4])  # Conf
        pred_cls = torch.sigmoid(prediction[..., 5:])  # Cls pred.

        # Recompute the cached offsets when the grid size changes, or when the
        # cache lives on a different device type than the current input.
        if grid_size != self.grid_size or self.grid_x.is_cuda != is_cuda:
            self.compute_grid_offsets(grid_size, cuda=is_cuda)

        # Add cell offsets and scale by anchors; detached so that the decoded
        # boxes carry no gradient (the loss uses cx/cy/w/h directly).
        pred_boxes = FloatTensor(prediction[..., :4].shape)
        pred_boxes[..., 0] = cx.detach() + self.grid_x
        pred_boxes[..., 1] = cy.detach() + self.grid_y
        pred_boxes[..., 2] = torch.exp(w.detach()) * self.anchor_w
        pred_boxes[..., 3] = torch.exp(h.detach()) * self.anchor_h

        output = torch.cat(
            (
                pred_boxes.view(num_samples, -1, 4) * self.stride,
                pred_conf.view(num_samples, -1, 1),
                pred_cls.view(num_samples, -1, self.num_classes),
            ),
            -1,
        )

        if targets is None:
            return output, 0

        iou_scores, class_mask, obj_mask, noobj_mask, tx, ty, tw, th, tcls, tconf = build_targets(
            pred_boxes=pred_boxes,
            pred_cls=pred_cls,
            target=targets,
            anchors=self.scaled_anchors,
            ignore_thres=self.ignore_thres,
        )

        # Loss : Mask outputs to ignore non-existing objects (except with conf. loss)
        loss_x = self.mse_loss(cx[obj_mask], tx[obj_mask])
        loss_y = self.mse_loss(cy[obj_mask], ty[obj_mask])
        loss_w = self.mse_loss(w[obj_mask], tw[obj_mask])
        loss_h = self.mse_loss(h[obj_mask], th[obj_mask])
        loss_conf_obj = self.bce_loss(pred_conf[obj_mask], tconf[obj_mask])
        loss_conf_noobj = self.bce_loss(pred_conf[noobj_mask], tconf[noobj_mask])
        loss_conf = self.obj_scale * loss_conf_obj + self.noobj_scale * loss_conf_noobj
        loss_cls = self.bce_loss(pred_cls[obj_mask], tcls[obj_mask])
        total_loss = loss_x + loss_y + loss_w + loss_h + loss_conf + loss_cls

        # Metrics
        cls_acc = 100 * class_mask[obj_mask].mean()
        conf_obj = pred_conf[obj_mask].mean()
        conf_noobj = pred_conf[noobj_mask].mean()
        conf50 = (pred_conf > 0.5).float()
        iou50 = (iou_scores > 0.5).float()
        iou75 = (iou_scores > 0.75).float()
        detected_mask = conf50 * class_mask * tconf
        # The 1e-16 terms guard against division by zero.
        precision = torch.sum(iou50 * detected_mask) / (conf50.sum() + 1e-16)
        recall50 = torch.sum(iou50 * detected_mask) / (obj_mask.sum() + 1e-16)
        recall75 = torch.sum(iou75 * detected_mask) / (obj_mask.sum() + 1e-16)

        self.metrics = {
            "loss": to_cpu(total_loss).item(),
            "x": to_cpu(loss_x).item(),
            "y": to_cpu(loss_y).item(),
            "w": to_cpu(loss_w).item(),
            "h": to_cpu(loss_h).item(),
            "conf": to_cpu(loss_conf).item(),
            "cls": to_cpu(loss_cls).item(),
            "cls_acc": to_cpu(cls_acc).item(),
            "recall50": to_cpu(recall50).item(),
            "recall75": to_cpu(recall75).item(),
            "precision": to_cpu(precision).item(),
            "conf_obj": to_cpu(conf_obj).item(),
            "conf_noobj": to_cpu(conf_noobj).item(),
            "grid_size": grid_size,
        }

        return output, total_loss
class Darknet(nn.Module):
    """YOLOv3 object detection model"""

    def __init__(self, config_path, img_size=416):
        """
        Args:
            config_path: path of the model definition passed to ``parse_model_config``.
            img_size: stored on the instance as ``self.img_size``.
        """
        super(Darknet, self).__init__()
        self.module_defs = parse_model_config(config_path)
        # create_modules pops the hyperparameter entry, so module_defs and
        # module_list line up one-to-one afterwards.
        self.hyperparams, self.module_list = create_modules(self.module_defs)
        # Detection layers are recognised by their `metrics` attribute.
        self.yolo_layers = [layer[0] for layer in self.module_list if hasattr(layer[0], "metrics")]
        self.img_size = img_size
        self.seen = 0
        self.header_info = np.array([0, 0, 0, self.seen, 0], dtype=np.int32)

    def forward(self, x, targets=None):
        """Run the network.

        Returns:
            The concatenated outputs of all yolo layers (passed through
            ``to_cpu``) when ``targets`` is ``None``, otherwise the tuple
            ``(loss, yolo_outputs)`` with the loss summed over yolo layers.
        """
        img_dim = x.shape[2]
        loss = 0
        layer_outputs, yolo_outputs = [], []
        for module_def, module in zip(self.module_defs, self.module_list):
            layer_type = module_def["type"]
            if layer_type in ["convolutional", "upsample", "maxpool"]:
                x = module(x)
            elif layer_type == "route":
                # Concatenate the referenced earlier outputs along the channel axis.
                x = torch.cat([layer_outputs[int(layer_i)] for layer_i in module_def["layers"].split(",")], 1)
            elif layer_type == "shortcut":
                layer_i = int(module_def["from"])
                x = layer_outputs[-1] + layer_outputs[layer_i]
            elif layer_type == "yolo":
                x, layer_loss = module[0](x, targets, img_dim)
                loss += layer_loss
                yolo_outputs.append(x)
            layer_outputs.append(x)
        yolo_outputs = to_cpu(torch.cat(yolo_outputs, 1))
        return yolo_outputs if targets is None else (loss, yolo_outputs)

    def load_darknet_weights(self, weights_path):
        """Parses and loads the weights stored in 'weights_path'"""
        # NOTE(review): this reads `.weight` / `.bias` on module[0]; confirm the
        # layer type returned by conv2d() exposes parameters under those names.
        with open(weights_path, "rb") as f:
            header = np.fromfile(f, dtype=np.int32, count=5)  # First five are header values
            self.header_info = header  # Needed to write header when saving weights
            self.seen = header[3]  # number of images seen during training
            weights = np.fromfile(f, dtype=np.float32)  # The rest are weights

        # Establish cutoff for loading backbone weights
        cutoff = None
        if "darknet53.conv.74" in weights_path:
            cutoff = 75

        ptr = 0  # read position inside the flat `weights` array
        for i, (module_def, module) in enumerate(zip(self.module_defs, self.module_list)):
            if i == cutoff:
                break
            if module_def["type"] != "convolutional":
                continue
            conv_layer = module[0]
            # int() keeps this in line with create_modules: the string "0" must
            # count as "no batch norm".
            if int(module_def["batch_normalize"]):
                # Load BN bias, weights, running mean and running variance
                bn_layer = module[1]
                num_b = bn_layer.bias.numel()  # Number of biases
                for target in (
                    bn_layer.bias,
                    bn_layer.weight,
                    bn_layer.running_mean,
                    bn_layer.running_var,
                ):
                    chunk = torch.from_numpy(weights[ptr : ptr + num_b]).view_as(target)
                    target.data.copy_(chunk)
                    ptr += num_b
            else:
                # Load conv. bias
                num_b = conv_layer.bias.numel()
                conv_b = torch.from_numpy(weights[ptr : ptr + num_b]).view_as(conv_layer.bias)
                conv_layer.bias.data.copy_(conv_b)
                ptr += num_b
            # Load conv. weights
            num_w = conv_layer.weight.numel()
            conv_w = torch.from_numpy(weights[ptr : ptr + num_w]).view_as(conv_layer.weight)
            conv_layer.weight.data.copy_(conv_w)
            ptr += num_w

    def save_darknet_weights(self, path, cutoff=-1):
        """
        @:param path    - path of the new weights file
        @:param cutoff  - save layers between 0 and cutoff (cutoff = -1 -> all are saved)
        """
        # A literal [:-1] slice would silently drop the last layer, so map the
        # documented "all layers" sentinel to an open-ended slice.
        end = None if cutoff == -1 else cutoff
        self.header_info[3] = self.seen
        # The context manager closes the file even if a layer fails to serialise.
        with open(path, "wb") as fp:
            self.header_info.tofile(fp)
            for module_def, module in zip(self.module_defs[:end], self.module_list[:end]):
                if module_def["type"] != "convolutional":
                    continue
                conv_layer = module[0]
                if int(module_def["batch_normalize"]):
                    # With batch norm, its four buffers are written first
                    # (same order load_darknet_weights reads them).
                    bn_layer = module[1]
                    bn_layer.bias.data.cpu().numpy().tofile(fp)
                    bn_layer.weight.data.cpu().numpy().tofile(fp)
                    bn_layer.running_mean.data.cpu().numpy().tofile(fp)
                    bn_layer.running_var.data.cpu().numpy().tofile(fp)
                else:
                    # Save conv bias
                    conv_layer.bias.data.cpu().numpy().tofile(fp)
                # Save conv weights
                conv_layer.weight.data.cpu().numpy().tofile(fp)
from __future__ import division
from models import *
from utils.utils import *
from utils.datasets import *
from utils.parse_config import *
import os
import sys
import time
import datetime
import argparse
import tqdm
import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from torch.autograd import Variable
import torch.optim as optim
def evaluate(model, path, iou_thres, conf_thres, nms_thres, img_size, batch_size):
    """Run ``model`` over the dataset listed at ``path`` and compute per-class metrics.

    Args:
        model: network to evaluate; it is switched to eval mode.
        path: dataset list path handed to ``ListDataset``.
        iou_thres: IoU threshold passed to ``get_batch_statistics``.
        conf_thres: confidence threshold passed to ``non_max_suppression``.
        nms_thres: IoU threshold passed to ``non_max_suppression``.
        img_size: image size used by the dataset and to rescale the targets.
        batch_size: evaluation batch size.

    Returns:
        Tuple ``(precision, recall, AP, f1, ap_class)``. When no sample
        statistics were collected at all, the four metric arrays are zeros with
        one entry per class present in the labels.
    """
    model.eval()

    # Get dataloader
    dataset = ListDataset(path, img_size=img_size, augment=False, multiscale=False)
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=False, num_workers=0, collate_fn=dataset.collate_fn
    )

    Tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor

    labels = []
    sample_metrics = []  # List of tuples (TP, confs, pred)
    for _, imgs, targets in tqdm.tqdm(dataloader, desc="Detecting objects"):
        # Extract labels
        labels += targets[:, 1].tolist()
        # Rescale target: xywh -> xyxy, then scale by the image size
        targets[:, 2:] = xywh2xyxy(targets[:, 2:])
        targets[:, 2:] *= img_size

        imgs = imgs.type(Tensor)

        with torch.no_grad():
            outputs = model(imgs)
            outputs = non_max_suppression(outputs, conf_thres=conf_thres, nms_thres=nms_thres)

        sample_metrics += get_batch_statistics(outputs, targets, iou_threshold=iou_thres)

    if not sample_metrics:
        # Nothing to concatenate (e.g. an empty loader; presumably also when no
        # image produced a detection - verify against get_batch_statistics).
        # Unpacking zip(*[]) into three names would raise ValueError, so report
        # zero metrics for every labelled class instead.
        ap_class = np.unique(labels).astype("int32")
        zeros = np.zeros(len(ap_class))
        return zeros, zeros.copy(), zeros.copy(), zeros.copy(), ap_class

    # Concatenate sample statistics
    true_positives, pred_scores, pred_labels = [np.concatenate(x, 0) for x in zip(*sample_metrics)]
    precision, recall, AP, f1, ap_class = ap_per_class(true_positives, pred_scores, pred_labels, labels)

    return precision, recall, AP, f1, ap_class
# Command-line entry point: load a model and print its per-class AP and mAP
# on the validation set named in the data config.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", type=int, default=2, help="size of each image batch")
    parser.add_argument("--model_def", type=str, default="config/yolov3.cfg", help="path to model definition file")
    parser.add_argument("--data_config", type=str, default="config/coco.data", help="path to data config file")
    parser.add_argument("--weights_path", type=str, default="weights/yolov3.weights", help="path to weights file")
    parser.add_argument("--class_path", type=str, default="data/coco.names", help="path to class label file")
    parser.add_argument("--iou_thres", type=float, default=0.5, help="iou threshold required to qualify as detected")
    parser.add_argument("--conf_thres", type=float, default=0.001, help="object confidence threshold")
    parser.add_argument("--nms_thres", type=float, default=0.5, help="iou thresshold for non-maximum suppression")
    parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
    parser.add_argument("--img_size", type=int, default=416, help="size of each image dimension")
    opt = parser.parse_args()
    print(opt)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    data_config = parse_data_config(opt.data_config)
    valid_path = data_config["valid"]
    class_names = load_classes(data_config["names"])

    # Initiate model
    model = Darknet(opt.model_def).to(device)
    if opt.weights_path.endswith(".weights"):
        # Load darknet weights
        model.load_darknet_weights(opt.weights_path)
    else:
        # Load checkpoint weights; map_location lets a checkpoint saved on GPU
        # be restored on a CPU-only machine.
        model.load_state_dict(torch.load(opt.weights_path, map_location=device))

    print("Compute mAP...")

    precision, recall, AP, f1, ap_class = evaluate(
        model,
        path=valid_path,
        iou_thres=opt.iou_thres,
        conf_thres=opt.conf_thres,
        nms_thres=opt.nms_thres,
        img_size=opt.img_size,
        # Honour --batch_size (its default of 2 matches the old hard-coded value).
        batch_size=opt.batch_size,
    )

    print("Average Precisions:")
    for i, c in enumerate(ap_class):
        print(f"+ Class '{c}' ({class_names[c]}) - AP: {AP[i]}")

    print(f"mAP: {AP.mean()}")
from __future__ import division
from models import *
from utils.logger import *
from utils.utils import *
from utils.datasets import *
from utils.parse_config import *
from test import evaluate
from terminaltables import AsciiTable
import os
import sys
import time
import datetime
import argparse
import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from torch.autograd import Variable
import torch.optim as optim
# Command-line entry point: train the model, print per-batch metrics, and
# periodically evaluate and checkpoint.
if __name__ == "__main__":

    def _str2bool(value):
        """Parse a command-line boolean.

        Without an explicit ``type`` argparse keeps the raw string, and any
        non-empty string (including "False") is truthy.
        """
        return str(value).strip().lower() in ("1", "true", "t", "yes", "y")

    parser = argparse.ArgumentParser()
    parser.add_argument("--epochs", type=int, default=100, help="number of epochs")
    parser.add_argument("--batch_size", type=int, default=2, help="size of each image batch")
    parser.add_argument("--gradient_accumulations", type=int, default=2, help="number of gradient accums before step")
    parser.add_argument("--model_def", type=str, default="config/yolov3.cfg", help="path to model definition file")
    parser.add_argument("--data_config", type=str, default="config/coco.data", help="path to data config file")
    parser.add_argument("--pretrained_weights", type=str, help="if specified starts from checkpoint model")
    parser.add_argument("--n_cpu", type=int, default=0, help="number of cpu threads to use during batch generation")
    parser.add_argument("--img_size", type=int, default=416, help="size of each image dimension")
    parser.add_argument("--checkpoint_interval", type=int, default=1, help="interval between saving model weights")
    parser.add_argument("--evaluation_interval", type=int, default=1, help="interval evaluations on validation set")
    parser.add_argument("--compute_map", type=_str2bool, default=False, help="if True computes mAP every tenth batch")
    parser.add_argument("--multiscale_training", type=_str2bool, default=True, help="allow for multi-scale training")
    opt = parser.parse_args()
    print(opt)

    # TensorBoard logging is disabled in this script; the hooks are kept as comments.
    ## logger = Logger("logs")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    os.makedirs("output", exist_ok=True)
    os.makedirs("checkpoints", exist_ok=True)

    # Get data configuration
    data_config = parse_data_config(opt.data_config)
    train_path = data_config["train"]
    valid_path = data_config["valid"]
    class_names = load_classes(data_config["names"])

    # Initiate model
    model = Darknet(opt.model_def).to(device)
    model.apply(weights_init_normal)

    # Loading of --pretrained_weights is intentionally disabled here:
    # if opt.pretrained_weights:
    #     if opt.pretrained_weights.endswith(".pth"):
    #         model.load_state_dict(torch.load(opt.pretrained_weights))
    #     else:
    #         model.load_darknet_weights(opt.pretrained_weights)

    # Get dataloader
    dataset = ListDataset(train_path, augment=True, multiscale=opt.multiscale_training)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=opt.batch_size,
        shuffle=True,
        num_workers=opt.n_cpu,
        pin_memory=True,
        collate_fn=dataset.collate_fn,
    )

    optimizer = torch.optim.Adam(model.parameters())

    metrics = [
        "grid_size",
        "loss",
        "x",
        "y",
        "w",
        "h",
        "conf",
        "cls",
        "cls_acc",
        "recall50",
        "recall75",
        "precision",
        "conf_obj",
        "conf_noobj",
    ]
    # Per-metric print formats; built once instead of once per metric per batch.
    formats = {m: "%.6f" for m in metrics}
    formats["grid_size"] = "%2d"
    formats["cls_acc"] = "%.2f%%"

    for epoch in range(opt.epochs):
        model.train()
        start_time = time.time()
        for batch_i, (_, imgs, targets) in enumerate(dataloader):
            batches_done = len(dataloader) * epoch + batch_i

            imgs = imgs.to(device)
            targets = targets.to(device)

            loss, outputs = model(imgs, targets)
            loss.backward()

            # Step once every `gradient_accumulations` batches. Testing the raw
            # remainder for truthiness never steps when the value is 1 and steps
            # on the wrong batches when it is larger than 2.
            if (batches_done + 1) % opt.gradient_accumulations == 0:
                optimizer.step()
                optimizer.zero_grad()

            # ----------------
            #   Log progress
            # ----------------

            log_str = "\n---- [Epoch %d/%d, Batch %d/%d] ----\n" % (epoch, opt.epochs, batch_i, len(dataloader))

            metric_table = [["Metrics", *[f"YOLO Layer {i}" for i in range(len(model.yolo_layers))]]]

            # Log metrics at each YOLO layer
            for metric in metrics:
                row_metrics = [formats[metric] % yolo.metrics.get(metric, 0) for yolo in model.yolo_layers]
                metric_table += [[metric, *row_metrics]]

            # Tensorboard logging (payload kept ready for the disabled logger)
            tensorboard_log = []
            for j, yolo in enumerate(model.yolo_layers):
                for name, value in yolo.metrics.items():
                    if name != "grid_size":
                        tensorboard_log += [(f"{name}_{j+1}", value)]
            tensorboard_log += [("loss", loss.item())]
            ## logger.list_of_scalars_summary(tensorboard_log, batches_done)

            log_str += AsciiTable(metric_table).table
            log_str += f"\nTotal loss {loss.item()}"

            # Determine approximate time left for epoch
            epoch_batches_left = len(dataloader) - (batch_i + 1)
            time_left = datetime.timedelta(seconds=epoch_batches_left * (time.time() - start_time) / (batch_i + 1))
            log_str += f"\n---- ETA {time_left}"

            print(log_str)

            model.seen += imgs.size(0)

        if epoch % opt.evaluation_interval == 0:
            print("\n---- Evaluating Model ----")
            # Evaluate the model on the validation set
            precision, recall, AP, f1, ap_class = evaluate(
                model,
                path=valid_path,
                iou_thres=0.5,
                conf_thres=0.5,
                nms_thres=0.5,
                img_size=opt.img_size,
                batch_size=2,
            )
            evaluation_metrics = [
                ("val_precision", precision.mean()),
                ("val_recall", recall.mean()),
                ("val_mAP", AP.mean()),
                ("val_f1", f1.mean()),
            ]
            ## logger.list_of_scalars_summary(evaluation_metrics, epoch)

            # Print class APs and mAP
            ap_table = [["Index", "Class name", "AP"]]
            for i, c in enumerate(ap_class):
                ap_table += [[c, class_names[c], "%.5f" % AP[i]]]
            print(AsciiTable(ap_table).table)
            print(f"---- mAP {AP.mean()}")

        if epoch % opt.checkpoint_interval == 0:
            torch.save(model.state_dict(), f"checkpoints/yolov3_ckpt_{epoch}.pth")
from __future__ import division
from models import *
from utils.logger import *
from utils.utils import *
from utils.datasets import *
from utils.parse_config import *
from test import evaluate
from terminaltables import AsciiTable
import os
import sys
import time
import datetime
import argparse
import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from torch.autograd import Variable
import torch.optim as optim
import apex
from apex import amp
# Command-line entry point: same training loop as the plain script, but the
# model/optimizer are wrapped by apex AMP and the loss is scaled before backward.
if __name__ == "__main__":

    def _str2bool(value):
        """Parse a command-line boolean.

        Without an explicit ``type`` argparse keeps the raw string, and any
        non-empty string (including "False") is truthy.
        """
        return str(value).strip().lower() in ("1", "true", "t", "yes", "y")

    parser = argparse.ArgumentParser()
    parser.add_argument("--epochs", type=int, default=100, help="number of epochs")
    parser.add_argument("--batch_size", type=int, default=2, help="size of each image batch")
    parser.add_argument("--gradient_accumulations", type=int, default=2, help="number of gradient accums before step")
    parser.add_argument("--model_def", type=str, default="config/yolov3.cfg", help="path to model definition file")
    parser.add_argument("--data_config", type=str, default="config/coco.data", help="path to data config file")
    parser.add_argument("--pretrained_weights", type=str, help="if specified starts from checkpoint model")
    parser.add_argument("--n_cpu", type=int, default=0, help="number of cpu threads to use during batch generation")
    parser.add_argument("--img_size", type=int, default=416, help="size of each image dimension")
    parser.add_argument("--checkpoint_interval", type=int, default=1, help="interval between saving model weights")
    parser.add_argument("--evaluation_interval", type=int, default=1, help="interval evaluations on validation set")
    parser.add_argument("--compute_map", type=_str2bool, default=False, help="if True computes mAP every tenth batch")
    parser.add_argument("--multiscale_training", type=_str2bool, default=True, help="allow for multi-scale training")
    opt = parser.parse_args()
    print(opt)

    # TensorBoard logging is disabled in this script; the hooks are kept as comments.
    ## logger = Logger("logs")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    os.makedirs("output", exist_ok=True)
    os.makedirs("checkpoints", exist_ok=True)

    # Get data configuration
    data_config = parse_data_config(opt.data_config)
    train_path = data_config["train"]
    valid_path = data_config["valid"]
    class_names = load_classes(data_config["names"])

    # Initiate model
    model = Darknet(opt.model_def).to(device)
    model.apply(weights_init_normal)

    # Loading of --pretrained_weights is intentionally disabled here:
    # if opt.pretrained_weights:
    #     if opt.pretrained_weights.endswith(".pth"):
    #         model.load_state_dict(torch.load(opt.pretrained_weights))
    #     else:
    #         model.load_darknet_weights(opt.pretrained_weights)

    # Get dataloader
    dataset = ListDataset(train_path, augment=True, multiscale=opt.multiscale_training)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=opt.batch_size,
        shuffle=True,
        num_workers=opt.n_cpu,
        pin_memory=True,
        collate_fn=dataset.collate_fn,
    )

    optimizer = torch.optim.Adam(model.parameters())

    # Hand model and optimizer to apex AMP at optimisation level O3.
    model, optimizer = amp.initialize(model, optimizer, opt_level="O3")

    metrics = [
        "grid_size",
        "loss",
        "x",
        "y",
        "w",
        "h",
        "conf",
        "cls",
        "cls_acc",
        "recall50",
        "recall75",
        "precision",
        "conf_obj",
        "conf_noobj",
    ]
    # Per-metric print formats; built once instead of once per metric per batch.
    formats = {m: "%.6f" for m in metrics}
    formats["grid_size"] = "%2d"
    formats["cls_acc"] = "%.2f%%"

    for epoch in range(opt.epochs):
        model.train()
        start_time = time.time()
        for batch_i, (_, imgs, targets) in enumerate(dataloader):
            batches_done = len(dataloader) * epoch + batch_i

            imgs = imgs.to(device)
            targets = targets.to(device)

            loss, outputs = model(imgs, targets)
            # Backpropagate through the AMP-scaled loss.
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()

            # Step once every `gradient_accumulations` batches. Testing the raw
            # remainder for truthiness never steps when the value is 1 and steps
            # on the wrong batches when it is larger than 2.
            if (batches_done + 1) % opt.gradient_accumulations == 0:
                optimizer.step()
                optimizer.zero_grad()

            # ----------------
            #   Log progress
            # ----------------

            log_str = "\n---- [Epoch %d/%d, Batch %d/%d] ----\n" % (epoch, opt.epochs, batch_i, len(dataloader))

            metric_table = [["Metrics", *[f"YOLO Layer {i}" for i in range(len(model.yolo_layers))]]]

            # Log metrics at each YOLO layer
            for metric in metrics:
                row_metrics = [formats[metric] % yolo.metrics.get(metric, 0) for yolo in model.yolo_layers]
                metric_table += [[metric, *row_metrics]]

            # Tensorboard logging (payload kept ready for the disabled logger)
            tensorboard_log = []
            for j, yolo in enumerate(model.yolo_layers):
                for name, value in yolo.metrics.items():
                    if name != "grid_size":
                        tensorboard_log += [(f"{name}_{j+1}", value)]
            tensorboard_log += [("loss", loss.item())]
            ## logger.list_of_scalars_summary(tensorboard_log, batches_done)

            log_str += AsciiTable(metric_table).table
            log_str += f"\nTotal loss {loss.item()}"

            # Determine approximate time left for epoch
            epoch_batches_left = len(dataloader) - (batch_i + 1)
            time_left = datetime.timedelta(seconds=epoch_batches_left * (time.time() - start_time) / (batch_i + 1))
            log_str += f"\n---- ETA {time_left}"

            print(log_str)

            model.seen += imgs.size(0)

        if epoch % opt.evaluation_interval == 0:
            print("\n---- Evaluating Model ----")
            # Evaluate the model on the validation set
            precision, recall, AP, f1, ap_class = evaluate(
                model,
                path=valid_path,
                iou_thres=0.5,
                conf_thres=0.5,
                nms_thres=0.5,
                img_size=opt.img_size,
                batch_size=2,
            )
            evaluation_metrics = [
                ("val_precision", precision.mean()),
                ("val_recall", recall.mean()),
                ("val_mAP", AP.mean()),
                ("val_f1", f1.mean()),
            ]
            ## logger.list_of_scalars_summary(evaluation_metrics, epoch)

            # Print class APs and mAP
            ap_table = [["Index", "Class name", "AP"]]
            for i, c in enumerate(ap_class):
                ap_table += [[c, class_names[c], "%.5f" % AP[i]]]
            print(AsciiTable(ap_table).table)
            print(f"---- mAP {AP.mean()}")

        if epoch % opt.checkpoint_interval == 0:
            torch.save(model.state_dict(), f"checkpoints/yolov3_ckpt_{epoch}.pth")
import xml.etree.ElementTree as ET
import pickle
import os
from os import listdir, getcwd
from os.path import join
# (year, image_set) pairs to convert; each one is read from
# VOCdevkit/VOC<year>/ImageSets/Main/<image_set>.txt by the script below.
sets = [
    ('2007', 'train'),
    ('2007', 'val'),
    ('2007', 'test'),
]

# Class names; a name's position in this list is the class id written to the label files.
classes = [
    "aeroplane", "bicycle", "bird", "boat", "bottle",
    "bus", "car", "cat", "chair", "cow",
    "diningtable", "dog", "horse", "motorbike", "person",
    "pottedplant", "sheep", "sofa", "train", "tvmonitor",
]
def convert(size, box):
    """Turn a corner-style box into a normalised centre/size box.

    Args:
        size: ``(width, height)`` of the image.
        box: ``(xmin, xmax, ymin, ymax)`` in pixels.

    Returns:
        Tuple ``(x, y, w, h)``: box centre (shifted by -1 pixel before scaling)
        and box extent, each multiplied by the reciprocal of the image
        width / height.
    """
    inv_width = 1. / (size[0])
    inv_height = 1. / (size[1])
    center_x = ((box[0] + box[1]) / 2.0 - 1) * inv_width
    center_y = ((box[2] + box[3]) / 2.0 - 1) * inv_height
    box_w = (box[1] - box[0]) * inv_width
    box_h = (box[3] - box[2]) * inv_height
    return (center_x, center_y, box_w, box_h)
def convert_annotation(year, image_id):
    """Convert one VOC XML annotation into a text label file.

    Reads ``VOCdevkit/VOC<year>/Annotations/<image_id>.xml`` and writes
    ``VOCdevkit/VOC<year>/labels/<image_id>.txt`` with one
    ``"<class_id> <x> <y> <w> <h>"`` line per kept object. Objects whose name
    is not in ``classes`` or that are flagged difficult are skipped.

    Args:
        year: dataset year used in the directory name.
        image_id: annotation / image base name.
    """
    xml_path = 'VOCdevkit/VOC%s/Annotations/%s.xml' % (year, image_id)
    label_path = 'VOCdevkit/VOC%s/labels/%s.txt' % (year, image_id)

    # Context managers make sure both files are closed (and the label file is
    # flushed) as soon as the conversion is done.
    with open(xml_path) as in_file:
        root = ET.parse(in_file).getroot()

    size = root.find('size')
    w = int(size.find('width').text)
    h = int(size.find('height').text)

    with open(label_path, 'w') as out_file:
        for obj in root.iter('object'):
            difficult_node = obj.find('difficult')
            # An annotation without a <difficult> tag is treated as not difficult.
            difficult = int(difficult_node.text) if difficult_node is not None else 0
            cls = obj.find('name').text
            if cls not in classes or difficult == 1:
                continue
            cls_id = classes.index(cls)
            xmlbox = obj.find('bndbox')
            # Order expected by convert(): xmin, xmax, ymin, ymax.
            b = (
                float(xmlbox.find('xmin').text),
                float(xmlbox.find('xmax').text),
                float(xmlbox.find('ymin').text),
                float(xmlbox.find('ymax').text),
            )
            bb = convert((w, h), b)
            out_file.write(str(cls_id) + " " + " ".join([str(a) for a in bb]) + '\n')
# Script body: for every (year, image_set) write a list of absolute image paths
# and convert each image's annotation, then merge the per-set lists.
wd = getcwd()

for year, image_set in sets:
    if not os.path.exists('VOCdevkit/VOC%s/labels/' % (year)):
        os.makedirs('VOCdevkit/VOC%s/labels/' % (year))
    with open('VOCdevkit/VOC%s/ImageSets/Main/%s.txt' % (year, image_set)) as ids_file:
        image_ids = ids_file.read().strip().split()
    with open('%s_%s.txt' % (year, image_set), 'w') as list_file:
        for image_id in image_ids:
            list_file.write('%s/VOCdevkit/VOC%s/JPEGImages/%s.jpg\n' % (wd, year, image_id))
            convert_annotation(year, image_id)

# Concatenate the per-set lists in pure Python (byte-for-byte, like `cat`) so
# the script does not depend on a POSIX shell being available.
for merged_name, part_names in (
    ("train.txt", ("2007_train.txt", "2007_val.txt")),
    ("train.all.txt", ("2007_train.txt", "2007_val.txt", "2007_test.txt")),
):
    with open(merged_name, "wb") as merged_file:
        for part_name in part_names:
            with open(part_name, "rb") as part_file:
                merged_file.write(part_file.read())
Display the source blob
Display the rendered blob
Raw
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
@wtliao
Copy link

wtliao commented Apr 13, 2020

It is solved. Thanks a lot

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment