Skip to content

Instantly share code, notes, and snippets.

@jiqiujia
Last active June 5, 2017 08:55
Show Gist options
  • Save jiqiujia/b39b3e6b00bb86f369fa9c5ee027ff3c to your computer and use it in GitHub Desktop.
Save jiqiujia/b39b3e6b00bb86f369fa9c5ee027ff3c to your computer and use it in GitHub Desktop.
python util functions/preprocessing
def create_conf_matrix(expected, predicted, n_classes):
    """Count (predicted, expected) label pairs into an n_classes x n_classes matrix.

    Row index = predicted label, column index = expected label. Note this is
    the transpose of the layout labelled by plot_confusion_matrix ("True
    label" on rows, "Predicted label" on columns).

    Labels are used directly as numpy indices, so they are expected to be
    integers in [0, n_classes). Pairs are formed with zip, so surplus items
    in the longer sequence are ignored.

    Returns a float numpy array of counts.
    """
    counts = np.zeros((n_classes, n_classes))
    label_pairs = zip(predicted, expected)
    for row, col in label_pairs:
        counts[row][col] += 1
    return counts
def calc_accuracy(conf_matrix):
    """Return the fraction of all counts that lie on the diagonal.

    Accepts a numpy array or a list of lists. The diagonal entry is read for
    every row index, so the matrix is expected to be square. With an empty or
    all-zero matrix the final division is by zero: plain Python numbers raise
    ZeroDivisionError, while numpy values give nan with a RuntimeWarning.
    """
    correct = 0
    total = 0
    for idx, row in enumerate(conf_matrix):
        correct += row[idx]
        total += sum(row)
    return correct / total
def getkl(pa, pb):
    """Pairwise KL divergence between the rows of ``pa`` and the rows of ``pb``.

    Computes kl(pa||pb) = sum(pa * log(pa / pb)) for every pair of rows, so
    ``result[i, j] = KL(pa[i] || pb[j])``.

    Args:
        pa: array-like whose first axis indexes distributions; all remaining
            axes are flattened into the support.
        pb: array-like with the same flattened support size as ``pa``.

    Returns:
        Array of shape (len(pa), len(pb)).

    Terms where ``pa == 0`` contribute 0 (the standard 0 * log(0 / q) = 0
    convention) instead of NaN. A positive ``pa`` entry where ``pb == 0``
    still yields +inf, and negative inputs still produce NaN.
    """
    pa = np.asarray(pa)
    pb = np.asarray(pb)
    # Shapes (Na, 1, D) and (1, Nb, D) broadcast so every pair of rows is compared.
    p = pa.reshape((pa.shape[0], 1, -1))
    q = pb.reshape((1, pb.shape[0], -1))
    with np.errstate(divide='ignore', invalid='ignore'):
        # log(0) = -inf and 0 * -inf = nan are expected here and fixed below.
        terms = p * (np.log(p) - np.log(q))
    terms = np.where(p == 0, 0.0, terms)
    return terms.sum(2)
import itertools
def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=None):
    """
    This function prints and plots the confusion matrix.
    Normalization can be applied by setting `normalize=True`.

    Args:
        cm: 2-D array-like of counts. Row i is drawn at y = i (axis labelled
            "True label") and column j at x = j ("Predicted label").
            create_conf_matrix in this file puts *predicted* labels on rows,
            so its output needs transposing (``cm.T``) to match these labels.
        classes: tick labels, one per row/column.
        normalize: if true, divide each row by its sum and round to 3
            decimals. Rows summing to zero are shown as zeros instead of NaN.
        title: plot title.
        cmap: colormap passed to ``plt.imshow``. ``None`` (the default) means
            ``plt.cm.Blues``. It is resolved at call time, so defining this
            function does not require ``plt`` to be bound yet.

    Uses a module-level ``plt`` (presumably ``matplotlib.pyplot``; it is not
    imported in the visible code) and returns None.
    """
    if cmap is None:
        cmap = plt.cm.Blues
    cm = np.asarray(cm)
    if normalize:
        cm = cm.astype('float')
        row_sums = cm.sum(axis=1)[:, np.newaxis]
        # An all-zero row would give 0/0 = NaN; leave such rows as zeros.
        cm = np.divide(cm, row_sums, out=np.zeros_like(cm), where=row_sums != 0)
        cm = np.around(cm, decimals=3)
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')
    print(cm)
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)
    # White text on cells above half the maximum value, black otherwise.
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        # Cells below 0.001 stay blank; the leading "0" of values below 1 is
        # stripped ("0.250" -> ".250") to keep labels compact.
        if cm[i, j] >= 0.001:
            plt.text(j, i, ('%.3f' % cm[i, j]).lstrip('0'),
                     horizontalalignment="center",
                     color="white" if cm[i, j] > thresh else "black")
    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
# Expand each detection box so that, after the expanded region is resized to
# (crop_height, crop_width), the original box fills the interior and a margin
# of padding_h / padding_w pixels remains on each side.
#
# NOTE(review): expects a name `boxes` to exist before this runs (it is not
# defined in this file). Presumably it is a sequence of per-image lists of
# [x1, y1, x2, y2] boxes in inclusive pixel coordinates; confirm with the
# caller. If the per-image containers are integer numpy arrays, the float
# coordinates written back below will be truncated.
import copy

crop_height, crop_width = (128, 64)
padding_h, padding_w = (16, 8)
use_square = False
im_w, im_h = 1280, 960  # image size; only referenced by the clipping sketch below
imh = im_h  # original (misspelled) name, kept for backward compatibility
pad_w = 0  # not used by the code below
pad_h = 0  # not used by the code below
# Ratio of the full crop to its unpadded interior along each axis. The
# explicit 2. keeps this a float division under Python 2 as well.
scale_height = crop_height / (crop_height - padding_h * 2.)
scale_width = crop_width / (crop_width - padding_w * 2.)

cropped_boxes = copy.deepcopy(boxes)  # expand a copy; `boxes` is left untouched
for bboxs in cropped_boxes:
    for j, box in enumerate(bboxs):
        x1, y1, x2, y2 = box
        # Inclusive coordinates: x1..x2 covers x2 - x1 + 1 pixels.
        w, h = x2 - x1 + 1, y2 - y1 + 1
        half_h, half_w = h / 2., w / 2.  # float halves even for int boxes on Python 2
        # Treating the box as the half-open range [x1, x1 + w), its centre is x1 + w/2.
        center_x, center_y = x1 + half_w, y1 + half_h
        if use_square:
            # Grow the shorter side to match the longer one.
            half_h = half_w = max(half_h, half_w)
        # The expanded half-open range is [centre - half*scale, centre + half*scale).
        # Subtract 1 from its end to get the inclusive right/bottom edge. Without
        # the -1, a scale of 1 would widen every box by one pixel, and the result
        # would be one pixel larger than w * scale.
        x1 = center_x - half_w * scale_width
        x2 = center_x + half_w * scale_width - 1
        y1 = center_y - half_h * scale_height
        y2 = center_y + half_h * scale_height - 1
        h, w = y2 - y1 + 1, x2 - x1 + 1
        bboxs[j] = [x1, y1, x2, y2]
        # Unfinished sketch (not valid Python as written) for clipping the
        # expanded box to the image and deriving the crop padding:
        # pad_x1 = max(0, -x1), pad_y1 = max(0, -y1)
        # x1 = max(0, x1), y1 = max(0, y1)
        # x2 = min(im_w, x2), y2 = min(im_h, y2),
        # clipped_h, clipped_w = y2- y1 + 1, x2 - x1 + 1
        #
        # scale_x, scale_y = crop_width / w, crop_height / h
        # crop_w, crop_h = np.round(clipped_w * scale_x), np.round(clipped_h * scale_y)
        # pad_x1, pad_y1 = round(pad_x1 * scale_x), round(pad_y1 * scale_y)
        #
        # if pad_y1 + crop_height > crop_height:
        # crop_height = crop_height - pad_y1
        # if pad_x1 + crop_width > crop_width:
        # crop_width = crop_width - pad_x1
# -*- coding: utf-8 -*-
import skimage
import numpy as np
from skimage.transform import AffineTransform, warp, rotate
# Keyword arguments for random_perturbation_transform, used by perturb by default.
default_augmentation_params = dict(
    zoom_range=(1/1.1, 1.1),    # sampled in log space, so kept multiplicatively symmetric
    rotation_range=(-20, 20),   # degrees (converted with np.deg2rad downstream)
    shear_range=(0, 0),         # degrees; (0, 0) means no shear
    translation_range=(-4, 4),  # drawn independently for the x and y shift
    do_flip=True,               # flip half of the samples
    allow_stretch=True,         # exactly True -> independent x and y zoom factors
)
def fast_warp(img, tf, output_shape=(28, 28), mode='constant', order=1):
    """
    Warp ``img`` with the 3x3 matrix of transform ``tf``.

    This wrapper is faster than skimage.transform.warp because it calls
    skimage's private Cython kernel ``_warp_fast`` directly and skips warp's
    argument handling. As with ``warp(img, inverse_map=tf)``, the matrix is
    presumably interpreted as mapping output coordinates to input
    coordinates; confirm against the installed skimage version.

    The private kernel is not part of skimage's public API. If it cannot be
    found, this falls back to the public ``warp`` with ``preserve_range=True``
    so pixel values are not rescaled.

    NOTE(review): ``_warp_fast`` presumably needs a 2-D float64 image; confirm.
    """
    try:
        warp_kernel = skimage.transform._warps_cy._warp_fast
    except AttributeError:
        return warp(img, tf, output_shape=output_shape, mode=mode, order=order,
                    preserve_range=True)
    matrix = tf.params  # the transform's 3x3 homogeneous matrix
    return warp_kernel(img, matrix, output_shape=output_shape, mode=mode, order=order)
def build_centering_transform(image_shape, target_shape=(50, 50)):
    """Translation that centres a ``target_shape`` window on an ``image_shape`` image.

    Both shapes are (rows, cols). The offset is half the size difference
    along each axis, given in skimage's (x, y) = (col, row) order.
    """
    rows, cols = image_shape
    target_rows, target_cols = target_shape
    offset_x = (cols - target_cols) / 2.0
    offset_y = (rows - target_rows) / 2.0
    offset = (offset_x, offset_y)
    return skimage.transform.SimilarityTransform(translation=offset)
def build_center_uncenter_transforms(image_shape):
    """Return ``(tform_center, tform_uncenter)`` for an image of ``image_shape``.

    ``tform_uncenter`` shifts the image centre to the origin and
    ``tform_center`` shifts it back. Sandwiching a zoom/rotation between them
    makes that operation act around the centre of the image.
    """
    rows, cols = image_shape[0], image_shape[1]
    # skimage transforms use (x, y) = (col, row) order, hence cols first.
    # With pixel centres at integer coordinates, the middle of an N-pixel
    # axis is at N/2 - 0.5.
    center_shift = np.array([cols, rows]) / 2.0 - 0.5
    uncenter = skimage.transform.SimilarityTransform(translation=-center_shift)
    center = skimage.transform.SimilarityTransform(translation=center_shift)
    return center, uncenter
def build_augmentation_transform(zoom=(1.0, 1.0), rotation=0, shear=0, translation=(0, 0), flip=False):
    """Build the skimage AffineTransform describing one augmentation.

    Args:
        zoom: (x, y) zoom factors; the transform's scale is their reciprocal.
        rotation: rotation angle in degrees (converted with np.deg2rad).
        shear: shear angle in degrees (converted with np.deg2rad).
        translation: passed unchanged to AffineTransform.
        flip: if true, 180 degrees is added to both shear and rotation.
    """
    if flip:
        # Author's reasoning: a 180-degree shear equals a 180-degree rotation
        # plus a flip, so another 180-degree rotation leaves only the flip.
        # This depends on skimage's shear convention; confirm on the installed version.
        shear += 180
        rotation += 180
    scale = (1/zoom[0], 1/zoom[1])
    return AffineTransform(
        scale=scale,
        rotation=np.deg2rad(rotation),
        shear=np.deg2rad(shear),
        translation=translation,
    )
def random_perturbation_transform(zoom_range, rotation_range, shear_range, translation_range, do_flip=True, allow_stretch=False, rng=np.random):
    """Sample a random augmentation and build its transform.

    Values are drawn from ``rng`` in a fixed order (x shift, y shift,
    rotation, shear, flip, then zoom/stretch), so a seeded generator
    reproduces the same transform.

    Args:
        zoom_range: (low, high) zoom factors, sampled uniformly in log space.
            It should therefore be multiplicatively symmetric, e.g.
            (1/1.1, 1.1) rather than (0.9, 1.1).
        rotation_range: (low, high) rotation in degrees.
        shear_range: (low, high) shear in degrees.
        translation_range: (low, high), drawn independently for x and y.
        do_flip: if truthy, flip with probability 1/2.
        allow_stretch: a float ``s`` draws one zoom plus a stretch in
            [1/s, s], giving zoom*stretch along x and zoom/stretch along y.
            Exactly ``True`` draws independent x and y zooms. Anything else
            (including ints) uses one zoom for both axes.
        rng: object providing ``uniform`` and ``randint`` like numpy.random.
    """
    translation = (rng.uniform(*translation_range), rng.uniform(*translation_range))
    rotation = rng.uniform(*rotation_range)
    shear = rng.uniform(*shear_range)
    flip = bool(do_flip) and rng.randint(2) > 0

    log_zoom_range = list(map(np.log, zoom_range))
    if isinstance(allow_stretch, float):
        max_log_stretch = np.log(allow_stretch)
        zoom = np.exp(rng.uniform(*log_zoom_range))
        stretch = np.exp(rng.uniform(-max_log_stretch, max_log_stretch))
        zoom_x, zoom_y = zoom * stretch, zoom / stretch
    elif allow_stretch is True:
        # Identity check on purpose: an int such as 1 must not select this branch.
        zoom_x = np.exp(rng.uniform(*log_zoom_range))
        zoom_y = np.exp(rng.uniform(*log_zoom_range))
    else:
        zoom_x = zoom_y = np.exp(rng.uniform(*log_zoom_range))
    return build_augmentation_transform((zoom_x, zoom_y), rotation, shear, translation, flip)
def perturb(img, augmentation_params=default_augmentation_params, target_shape=(28, 28), rng=np.random):
    """Randomly augment a 2-D image and resample it to ``target_shape``.

    Args:
        img: 2-D image; its shape must unpack as (rows, cols).
        augmentation_params: keyword arguments for
            random_perturbation_transform. It is not modified.
        target_shape: (rows, cols) of the output.
        rng: random source forwarded to random_perturbation_transform.

    Returns:
        float32 array produced by fast_warp with ``output_shape=target_shape``.
    """
    shape = img.shape
    centering = build_centering_transform(shape, target_shape)
    center, uncenter = build_center_uncenter_transforms(shape)
    augment = random_perturbation_transform(rng=rng, **augmentation_params)
    # Move the centre to the origin, augment, then move it back, so that
    # rotation/shear/zoom act around the image centre.
    around_center = uncenter + augment + center
    warped = fast_warp(img, centering + around_center, output_shape=target_shape, mode='constant')
    return warped.astype('float32')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment