Last active
June 5, 2017 08:55
-
-
Save jiqiujia/b39b3e6b00bb86f369fa9c5ee027ff3c to your computer and use it in GitHub Desktop.
python util functions/preprocessing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def create_conf_matrix(expected, predicted, n_classes):
    """
    Count (predicted, expected) label pairs into a square matrix.

    expected  -- iterable of integer class labels (ground truth).
    predicted -- iterable of integer class labels, paired with `expected`
                 element by element (pairing stops at the shorter one).
    n_classes -- number of classes; the result has shape
                 (n_classes, n_classes).

    Returns a float numpy array `m` where m[p][e] is the number of pairs
    with predicted label p and expected label e, i.e. rows are indexed by
    the PREDICTED label and columns by the EXPECTED label.

    Raises IndexError for a label outside [0, n_classes).
    """
    m = np.zeros((n_classes, n_classes))
    for pred, exp in zip(predicted, expected):
        # numpy would accept a negative index and silently count the pair
        # against one of the last classes; reject it the same way a
        # too-large label is rejected.
        if pred < 0 or exp < 0:
            raise IndexError(
                'negative class label (predicted=%r, expected=%r)' % (pred, exp))
        m[pred, exp] += 1
    return m
def calc_accuracy(conf_matrix):
    """
    Return the fraction of samples on the diagonal of a confusion matrix.

    conf_matrix -- square matrix of counts (nested sequences or a 2-D
                   numpy array).

    Returns sum(diagonal) / sum(all cells) as a float.  When the matrix
    holds no samples (empty, or every cell is zero) the accuracy is
    undefined and NaN is returned instead of dividing by zero.
    """
    # float() keeps the division a true division even for integer counts.
    total = float(sum(sum(row) for row in conf_matrix))
    if total == 0:
        return float('nan')
    correct = sum(conf_matrix[i][i] for i in range(len(conf_matrix)))
    return correct / total
def getkl(pa, pb):
    """
    Pairwise KL divergence between the rows of `pa` and the rows of `pb`.

    pa -- array of shape (A, D); each row is treated as a distribution.
    pb -- array of shape (B, D); each row is treated as a distribution.

    Returns an (A, B) array whose [i, j] entry is
        KL(pa[i] || pb[j]) = sum_d pa[i, d] * log(pa[i, d] / pb[j, d]).

    Terms with pa[i, d] == 0 contribute 0 (the usual 0 * log 0 = 0
    convention) rather than NaN.  A zero in pb where pa is positive still
    yields +inf, as the divergence is infinite in that case.
    """
    pa = np.asarray(pa)
    pb = np.asarray(pb)
    # Insert singleton axes so the subtraction broadcasts to (A, B, D).
    pa = pa.reshape((pa.shape[0], 1, -1))
    pb = pb.reshape((1, pb.shape[0], -1))
    # log(0) and 0 * inf are expected for zero probabilities; silence the
    # warnings here and repair those terms explicitly below.
    with np.errstate(divide='ignore', invalid='ignore'):
        terms = pa * (np.log(pa) - np.log(pb))
    terms = np.where(pa == 0, 0.0, terms)
    return terms.sum(2)
import itertools | |
def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues):
    """
    Print the confusion matrix and draw it as a heat map with pyplot.

    cm        -- 2-D numpy array of counts; rows are labelled as true
                 classes and columns as predicted classes on the plot.
    classes   -- tick labels, one per class.
    normalize -- when True, each row is divided by its sum and rounded to
                 three decimals before printing and plotting.  Rows whose
                 sum is zero are left as all zeros.
    title     -- plot title.
    cmap      -- colormap handed to plt.imshow.

    Nothing is returned; drawing goes through the pyplot state machine.
    """
    if normalize:
        row_totals = cm.sum(axis=1)[:, np.newaxis]
        # A class with no samples has a zero row total.  Dividing by it
        # would give NaN cells and turn cm.max() (hence the text-colour
        # threshold below) into NaN, so divide such rows by 1 instead.
        safe_totals = np.where(row_totals == 0, 1, row_totals)
        cm = np.around(cm.astype('float') / safe_totals, decimals=3)
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')
    print(cm)

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    # Cells above half the maximum value get white text, the rest black.
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        # Cells below 0.001 are left unlabelled.
        if cm[i, j] >= 0.001:
            plt.text(j, i, ('%.3f' % cm[i, j]).lstrip('0'),
                     horizontalalignment="center",
                     color="white" if cm[i, j] > thresh else "black")

    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Enlarge every bounding box in `boxes` about its centre.
#
# `boxes` (and the `copy` / `np` modules) must already be defined when this
# snippet runs.  NOTE(review): `boxes` looks like a sequence of per-image
# sequences of [x1, y1, x2, y2] boxes -- confirm against the caller.
crop_height, crop_width = (128, 64)
padding_h, padding_w = (16, 8)
use_square = False
# Was `im_w, imh`; the clipping draft below refers to `im_h`.
im_w, im_h = 1280, 960
pad_w = 0
pad_h = 0

# Growth factors: crop size divided by (crop size minus padding on both
# sides).  The trailing `2.` keeps the division in floating point.
scale_height = crop_height / (crop_height - padding_h*2.)
scale_width = crop_width / (crop_width - padding_w*2.)

# Work on a deep copy so the original `boxes` is left untouched.
cropped_boxes = copy.deepcopy(boxes)
for bboxs in cropped_boxes:
    for j, box in enumerate(bboxs):
        x1, y1, x2, y2 = box
        w, h = x2 - x1 + 1, y2 - y1 + 1
        # Float division so integer coordinates do not truncate the half
        # extents (and thereby shift the centre) under Python 2.
        half_h, half_w = h / 2., w / 2.
        center_x, center_y = x1 + half_w, y1 + half_h

        if use_square:
            # Use the larger half extent for both axes.
            half_h = half_w = max(half_h, half_w)

        x1, x2 = (center_x - half_w * scale_width), (center_x + half_w * scale_width)
        y1, y2 = (center_y - half_h * scale_height), (center_y + half_h * scale_height)
        # Size of the enlarged box; only the clipping draft below uses it.
        h, w = y2 - y1 + 1, x2 - x1 + 1

        bboxs[j] = [x1, y1, x2, y2]

        # Unfinished draft (kept for reference): clip the enlarged box to
        # the image and work out the padding needed inside the crop.
        # pad_x1 = max(0, -x1), pad_y1 = max(0, -y1)
        # x1 = max(0, x1), y1 = max(0, y1)
        # x2 = min(im_w, x2), y2 = min(im_h, y2),
        # clipped_h, clipped_w = y2- y1 + 1, x2 - x1 + 1
        #
        # scale_x, scale_y = crop_width / w, crop_height / h
        # crop_w, crop_h = np.round(clipped_w * scale_x), np.round(clipped_h * scale_y)
        # pad_x1, pad_y1 = round(pad_x1 * scale_x), round(pad_y1 * scale_y)
        #
        # if pad_y1 + crop_height > crop_height:
        #     crop_height = crop_height - pad_y1
        # if pad_x1 + crop_width > crop_width:
        #     crop_width = crop_width - pad_x1
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*- | |
import skimage | |
import numpy as np | |
from skimage.transform import AffineTransform, warp, rotate | |
# Default keyword arguments for random_perturbation_transform(); perturb()
# unpacks this dict with ** when no other parameters are given.
default_augmentation_params = dict(
    zoom_range=(1/1.1, 1.1),      # multiplicatively symmetric around 1
    rotation_range=(-20, 20),     # degrees (converted with np.deg2rad later)
    shear_range=(0, 0),           # degrees; (0, 0) disables shearing
    translation_range=(-4, 4),    # same range is used for x and y shifts
    do_flip=True,
    allow_stretch=True,           # True: x and y zoom are drawn independently
)
def fast_warp(img, tf, output_shape=(28, 28), mode='constant', order=1):
    """
    Warp `img` with transform `tf` via skimage's Cython `_warp_fast`.

    The transform's matrix (`tf.params`) is passed straight to the private
    routine `skimage.transform._warps_cy._warp_fast`, together with
    `output_shape`, `mode` and `order`.  The original author intended this
    as a faster alternative to skimage.transform.warp.

    NOTE(review): `_warp_fast` is a private skimage function; its accepted
    arguments may differ between skimage versions -- confirm against the
    installed release.
    """
    warp_fast = skimage.transform._warps_cy._warp_fast
    matrix = tf.params
    return warp_fast(img, matrix, output_shape=output_shape, mode=mode, order=order)
def build_centering_transform(image_shape, target_shape=(50, 50)):
    """
    Build a pure translation by half the size difference of two shapes.

    image_shape  -- (rows, cols) of the source image.
    target_shape -- (rows, cols) of the output image.

    Returns a SimilarityTransform whose translation is, in (x, y) order,
    ((cols - target_cols) / 2, (rows - target_rows) / 2).
    """
    src_rows, src_cols = image_shape
    dst_rows, dst_cols = target_shape
    # Translation is given as (x, y), i.e. columns first, then rows.
    offset = ((src_cols - dst_cols) / 2.0, (src_rows - dst_rows) / 2.0)
    return skimage.transform.SimilarityTransform(translation=offset)
def build_center_uncenter_transforms(image_shape):
    """
    Build the pair of translations used to zoom/rotate about the image centre.

    The shift is (cols, rows) / 2 - 0.5, in (x, y) order.  Returns
    (tform_center, tform_uncenter): a translation by +shift and one by
    -shift.  Sandwich another transform between the two so that it acts
    around the centre of the image instead of the origin.
    """
    n_rows, n_cols = image_shape[0], image_shape[1]
    # Shapes are (rows, cols) but translations are (x, y): swap the order.
    center_shift = np.array([n_cols, n_rows]) / 2.0 - 0.5
    make_translation = skimage.transform.SimilarityTransform
    tform_center = make_translation(translation=center_shift)
    tform_uncenter = make_translation(translation=-center_shift)
    return tform_center, tform_uncenter
def build_augmentation_transform(zoom=(1.0, 1.0), rotation=0, shear=0, translation=(0, 0), flip=False):
    """
    Build an AffineTransform from zoom, rotation, shear, translation and flip.

    zoom        -- (zoom_x, zoom_y); the transform's scale is the reciprocal.
    rotation    -- rotation angle in degrees.
    shear       -- shear angle in degrees.
    translation -- translation passed to AffineTransform unchanged.
    flip        -- when truthy, 180 degrees is added to both shear and
                   rotation.

    NOTE(review): the flip trick relies on how AffineTransform interprets
    `shear`; verify it against the installed skimage version.
    """
    if flip:
        # Per the original author: a 180-degree shear equals a 180-degree
        # rotation plus a flip, so rotating another 180 degrees leaves
        # just the flip.
        rotation, shear = rotation + 180, shear + 180

    inverse_scale = (1/zoom[0], 1/zoom[1])
    return AffineTransform(
        scale=inverse_scale,
        rotation=np.deg2rad(rotation),
        shear=np.deg2rad(shear),
        translation=translation,
    )
def random_perturbation_transform(zoom_range, rotation_range, shear_range, translation_range, do_flip=True, allow_stretch=False, rng=np.random):
    """
    Draw random augmentation parameters and build the matching transform.

    zoom_range        -- (low, high) zoom factors; sampled uniformly in log
                         space, so the range should be multiplicatively
                         symmetric ([1/1.1, 1.1] rather than [0.9, 1.1]).
    rotation_range    -- (low, high) rotation, sampled uniformly.
    shear_range       -- (low, high) shear, sampled uniformly.
    translation_range -- (low, high) used for both the x and the y shift.
    do_flip           -- when truthy, flip with probability one half.
    allow_stretch     -- a float f: one zoom is drawn and multiplied /
                         divided by a stretch drawn log-uniformly from
                         [1/f, f]; exactly True: x and y zoom are drawn
                         independently; anything else: one shared zoom.
    rng               -- object providing uniform() and randint().

    Returns the result of build_augmentation_transform().
    """
    # The order of the rng calls below is kept exactly as in the original
    # so that a seeded rng yields the same sequence of transforms.
    translation = (rng.uniform(*translation_range), rng.uniform(*translation_range))
    rotation = rng.uniform(*rotation_range)
    shear = rng.uniform(*shear_range)
    flip = (rng.randint(2) > 0) if do_flip else False  # flip half of the time

    log_zoom_bounds = [np.log(bound) for bound in zoom_range]

    def draw_zoom():
        # Log-uniform sample from zoom_range.
        return np.exp(rng.uniform(*log_zoom_bounds))

    if isinstance(allow_stretch, float):
        log_stretch = np.log(allow_stretch)
        zoom = draw_zoom()
        stretch = np.exp(rng.uniform(-log_stretch, log_stretch))
        zoom_x, zoom_y = zoom * stretch, zoom / stretch
    elif allow_stretch is True:  # identity test: an integer must not take this branch
        zoom_x = draw_zoom()
        zoom_y = draw_zoom()
    else:
        zoom_x = zoom_y = draw_zoom()

    return build_augmentation_transform((zoom_x, zoom_y), rotation, shear, translation, flip)
def perturb(img, augmentation_params=default_augmentation_params, target_shape=(28, 28), rng=np.random):
    """
    Apply one random augmentation to `img` and return it as float32.

    img                 -- image array; its shape is unpacked as
                           (rows, cols) by build_centering_transform.
    augmentation_params -- keyword arguments for
                           random_perturbation_transform().
    target_shape        -- shape of the returned image.
    rng                 -- random source forwarded to
                           random_perturbation_transform().
    """
    # Debugging aid: set img[0, :], img[-1, :], img[:, 0] and img[:, -1]
    # to 0.5 here to draw a border and see where the image ends up.
    centering = build_centering_transform(img.shape, target_shape)
    to_center, from_center = build_center_uncenter_transforms(img.shape)
    random_tform = random_perturbation_transform(rng=rng, **augmentation_params)

    # Shift to the centre, augment, shift back, so that rotation and
    # shearing act around the image centre.
    about_center = from_center + random_tform + to_center
    full_tform = centering + about_center

    warped = fast_warp(img, full_tform, output_shape=target_shape, mode='constant')
    return warped.astype('float32')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment