Skip to content

Instantly share code, notes, and snippets.

@SreenivasVRao
Created November 6, 2018 04:18
Show Gist options
  • Save SreenivasVRao/d3036982d0ea443e09b10f8a867bac20 to your computer and use it in GitHub Desktop.
Save SreenivasVRao/d3036982d0ea443e09b10f8a867bac20 to your computer and use it in GitHub Desktop.
Affine Transforms
import numpy as np
def apply_transformations(transformations, data):
    """Apply one 2x3 affine matrix per frame to homogeneous 2-D points.

    Args:
        transformations: array of shape (T, 2, 3); ``transformations[t]`` is
            the affine matrix applied in frame ``t``.
        data: homogeneous points ``(x, y, 1)``, either of shape (K, 3)
            (the same point set is used for every frame) or (T, K, 3)
            (a separate point set per frame).

    Returns:
        Float array of shape (T, K, 3): the transformed ``(x', y')`` of each
        point in each frame, re-homogenised with a trailing column of ones.

    Raises:
        ValueError: if either argument has an unsupported shape, or if a
            per-frame ``data`` array cannot broadcast to T frames.
    """
    transformations = np.asarray(transformations)
    data = np.asarray(data)
    if transformations.ndim != 3 or transformations.shape[1:] != (2, 3):
        raise ValueError(
            "transformations must have shape (T, 2, 3), got {}".format(
                transformations.shape))
    if data.ndim not in (2, 3) or data.shape[-1] != 3:
        raise ValueError(
            "data must have shape (K, 3) or (T, K, 3), got {}".format(
                data.shape))

    num_frames = transformations.shape[0]
    num_points = data.shape[-2]
    # A (K, 3) point set is shared by all frames; a (T, K, 3) set is per-frame.
    data = np.broadcast_to(data, (num_frames, num_points, 3))

    # projected[t, k, i] = sum_j transformations[t, i, j] * data[t, k, j],
    # i.e. M_t @ p for every point p of every frame, in a single call.
    projected = np.einsum("tij,tkj->tki", transformations, data)

    # Append w = 1 so the result can be fed back in as homogeneous points.
    ones = np.ones((num_frames, num_points, 1))
    return np.concatenate([projected, ones], axis=2)
T = 30  # number of frames
K = 10  # number of tracked points
translation_radius = 150  # pixels
rotation_radius = np.pi  # radians

# Display colours, one per tracked point, so each point keeps the same
# colour in every frame of the scatter plot.
colour_seq = ['tab:blue', 'tab:orange', 'tab:green',
              'tab:red', 'tab:purple', 'tab:brown',
              'tab:pink', 'tab:gray', 'tab:olive', 'tab:cyan']

# Every point needs its own colour, so K is capped by the palette size.
# (Python 3 raise syntax; ValueError is still an Exception subclass.)
if K > len(colour_seq):
    raise ValueError('Choose more colours')
# K random integer pixel positions in [0, 300), lifted to homogeneous
# coordinates (x, y, 1) so a 2x3 affine matrix can act on them directly.
start_data = np.random.randint(0, 300, (K, 2))
start_data = np.append(start_data, np.ones([K, 1]), axis=1)

# One pure-translation affine matrix [[1, 0, tx], [0, 1, 0]] per frame.
initial_translations = np.zeros([T, 2, 3])
initial_translations[:, 0, 0] = 1
initial_translations[:, 1, 1] = 1
# Horizontal offsets sweep evenly over [-75, 75), one step per frame. For
# T = 30 this is exactly -75, -70, ..., 70; deriving the count from T avoids
# the IndexError (T < 30) / untranslated frames (T > 30) of a fixed range.
# The vertical offset (row 1, column 2) is left at 0.
initial_translations[:, 0, 2] = np.linspace(-75, 75, num=T, endpoint=False)

initial_data = apply_transformations(initial_translations, start_data)
initial_data = initial_data.astype(np.float32)
from matplotlib import pyplot as plt
source = np.copy(initial_data).reshape(T, K, 3)
plt.figure(figsize=(8, 8))
# Use exactly K colours: scatter rejects a colour list whose length differs
# from the number of points, which the full 10-entry palette did for K < 10.
point_colours = colour_seq[:K]
for frame in source:  # one scatter call per frame, same colour per point
    plt.scatter(x=frame[:, 0], y=frame[:, 1], c=point_colours, s=60)
plt.title('Translated Points')
plt.xlim(0, 300)
# Inverted y-limits put y = 0 at the top of the axes.
plt.ylim(300, 0)
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment