Hands-on Demo - Code (from Data Driven Montreal Talk)
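The demo trains a small convnet to find a printed logo in a live webcam feed. Two files follow: data.py, which synthesizes endless training images by pasting a randomly rotated, scaled and recolored logo onto random background textures (and also wraps the webcam as a tf.data input pipeline), and a training/demo script built on the TensorFlow 1.x Estimator API that regresses the logo's normalized (x, y) center and then tracks the logo live on camera.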
data.py
import functools
import os
import random

import cv2
import numpy as np
import PIL.Image
import PIL.ImageDraw
import PIL.ImageEnhance
import PIL.ImageOps
import tensorflow as tf
def get_dataset(image_size, images_root, logo_jpg_filename):
    # NOTE: _data_generator renders at its default size of (320, 240) in PIL
    # (width, height) order, so `image_size` here must be the matching
    # (height, width, channels) shape, e.g. (240, 320, 3).
    generator = functools.partial(_data_generator,
                                  textures_dir=images_root,
                                  logo_filename=logo_jpg_filename)
    dataset = tf.data.Dataset.from_generator(generator,
                                             (tf.float32, tf.float32),
                                             (tf.TensorShape(image_size),
                                              tf.TensorShape([2])))
    return dataset
def get_webcam_input_fn(image_size):
    # The capture handle is returned alongside the input_fn so the caller can
    # release the camera once the demo loop ends
    cap = cv2.VideoCapture(0)

    def webcam_input_fn(mode):
        def get_webcam_feed():
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                # Resize frame to the model's expected (width, height)
                frame_resized = cv2.resize(frame, dsize=(image_size[1], image_size[0]))
                # Convert BGR -> RGB and scale pixels to [-1, 1]
                frame_preprocessed = (2.0 * frame_resized[:, :, ::-1] / 255.0 - 1.0).astype(np.float32)
                yield frame_preprocessed

        webcam_dataset = tf.data.Dataset.from_generator(get_webcam_feed,
                                                        tf.float32,
                                                        tf.TensorShape(image_size))
        webcam_dataset = webcam_dataset.batch(1)
        return webcam_dataset

    return cap, webcam_input_fn
def _data_generator(random_seed=None,
                    border=0,
                    logo_size_min=70,
                    image_size=(320, 240),
                    textures_dir=None,
                    logo_filename=None):
    # Thread-safe random
    local_random = random.Random()
    local_random.seed(random_seed)

    # Collect the background texture filenames
    textures_filenames = [os.path.join(textures_dir, f) for f in os.listdir(textures_dir)]

    # Load the logo and flatten it onto a white background
    logo_image_orig = PIL.ImageOps.expand(PIL.Image.open(logo_filename), border=border, fill='white').convert('RGBA')
    logo_image = PIL.Image.new('RGBA', logo_image_orig.size, color=(255, 255, 255))
    logo_image.alpha_composite(logo_image_orig)

    # Radial alpha mask so the pasted logo fades out towards its edges
    logo_circle = PIL.Image.new('L', logo_image.size, color=0)
    draw = PIL.ImageDraw.Draw(logo_circle)
    for i in range(0, 40):
        draw.ellipse((i, i, logo_image.size[0] - i, logo_image.size[1] - i), fill=int(255 / 40 * i), outline=0)
    logo_image.putalpha(logo_circle)

    while True:
        # Create a background from a random crop of a random texture
        texture_filename = local_random.choice(textures_filenames)
        texture_image = PIL.Image.open(texture_filename).convert("RGBA")
        w, h = (0.5 + local_random.random() * 0.5, 0.5 + local_random.random() * 0.5)
        x1 = local_random.random() * (1 - w)
        x2 = x1 + w
        y1 = local_random.random() * (1 - h)
        y2 = y1 + h
        background_image = texture_image \
            .crop((x1 * texture_image.size[0], y1 * texture_image.size[1],
                   x2 * texture_image.size[0], y2 * texture_image.size[1])) \
            .resize(image_size, resample=PIL.Image.BILINEAR)
        brightness_factor = local_random.random() * 1.6 + 0.2
        brightness = PIL.ImageEnhance.Brightness(background_image)
        background_image = brightness.enhance(brightness_factor)

        # Create a randomly rotated, scaled, squeezed and recolored logo to paste
        brightness_factor = local_random.random() * 1.4 + 0.2
        color_factor = local_random.random()
        angle = local_random.randrange(0, 360)
        logo_size = local_random.randrange(logo_size_min, int(min(image_size) * 0.5))  # randrange needs int bounds
        logo_origin = [local_random.randrange(0, image_size[0] - logo_size),
                       local_random.randrange(0, image_size[1] - logo_size)]
        logo_to_paste = logo_image \
            .rotate(angle) \
            .resize((logo_size, logo_size), resample=PIL.Image.BILINEAR)
        brightness = PIL.ImageEnhance.Brightness(logo_to_paste)
        logo_to_paste = brightness.enhance(brightness_factor)
        color = PIL.ImageEnhance.Color(logo_to_paste)
        logo_to_paste = color.enhance(color_factor)
        squeeze_factors = (local_random.random() * 0.3 + 0.7,
                           local_random.random() * 0.3 + 0.7)
        logo_shift = (int(logo_to_paste.size[0] * (1.0 - squeeze_factors[0]) / 2),
                      int(logo_to_paste.size[1] * (1.0 - squeeze_factors[1]) / 2))
        logo_to_paste = logo_to_paste.resize((int(logo_to_paste.size[0] * squeeze_factors[0]),
                                              int(logo_to_paste.size[1] * squeeze_factors[1])))
        logo_origin[0] += logo_shift[0]
        logo_origin[1] += logo_shift[1]
        background_image.paste(logo_to_paste, tuple(logo_origin), logo_to_paste)
        background_image = np.array(background_image.convert('RGB'))

        # Target: the pasted logo's center, normalized to [0, 1] in x and y
        targets = np.array([(logo_origin[0] + logo_to_paste.size[0] / 2.0) / image_size[0],
                            (logo_origin[1] + logo_to_paste.size[1] / 2.0) / image_size[1]])
        inputs = 2.0 * background_image / 255.0 - 1.0
        yield inputs, targets
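A quick way to sanity-check the generator before wiring it into tf.data is to pull one sample and plot it with its target marked. A minimal sketch (the '/.../' placeholders stand for the same directories configured in the training script below):

import matplotlib.pyplot as plt

gen = _data_generator(textures_dir='/.../train-backgrounds',
                      logo_filename='/.../logo.jpg')
inputs, targets = next(gen)
plt.imshow(inputs / 2.0 + 0.5)          # undo the [-1, 1] scaling
plt.plot(targets[0] * inputs.shape[1],  # normalized x -> pixel column
         targets[1] * inputs.shape[0],  # normalized y -> pixel row
         'or')
plt.show()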
Training and webcam demo script
import os
import uuid

import cv2
import matplotlib.pyplot as plt
import tensorflow as tf

from data import get_dataset, get_webcam_input_fn
# Directories - PLEASE CONFIGURE THESE
models_root = '/.../models'                   # Output dir for the model and TensorBoard logs
train_images_root = '/.../train-backgrounds'  # Training backgrounds (download random images from the internet)
valid_images_root = '/.../valid-backgrounds'  # Validation backgrounds (download random images from the internet)
# A jpg of the "Data Driven Montreal" logo: download
# https://cdn.evbuc.com/images/39269779/90171842189/2/logo.png , print it,
# take a picture of the printout with your webcam, crop a square around the
# logo and save it
logo_jpg_filename = '/.../logo.jpg'
# Config
image_size = (240, 320, 3)
batch_size = 16
train_steps = 200

# Model ID
model_uuid = uuid.uuid4()
model_dir = os.path.join(models_root, f'{model_uuid}')
# Input Function
def input_fn(mode):
    if mode == tf.estimator.ModeKeys.TRAIN:
        dataset = get_dataset(image_size, train_images_root, logo_jpg_filename)
    else:
        dataset = get_dataset(image_size, valid_images_root, logo_jpg_filename)
    dataset = dataset \
        .batch(batch_size) \
        .prefetch(1)
    return dataset
# Model Function
def model_fn(features, labels, mode):
    with tf.name_scope('Input'):
        input_tensor = tf.identity(features)
    with tf.name_scope('Model'):
        # Six conv/max-pool blocks followed by dense layers; the sigmoid head
        # outputs the predicted (x, y) logo center, normalized to [0, 1]
        net = input_tensor
        net = tf.layers.conv2d(net, 32, (3, 3), activation=tf.nn.relu)
        net = tf.layers.max_pooling2d(net, (2, 2), (2, 2))
        net = tf.layers.conv2d(net, 64, (3, 3), activation=tf.nn.relu)
        net = tf.layers.max_pooling2d(net, (2, 2), (2, 2))
        net = tf.layers.conv2d(net, 128, (3, 3), activation=tf.nn.relu)
        net = tf.layers.max_pooling2d(net, (2, 2), (2, 2))
        net = tf.layers.conv2d(net, 128, (3, 3), activation=tf.nn.relu)
        net = tf.layers.max_pooling2d(net, (2, 2), (2, 2))
        net = tf.layers.conv2d(net, 128, (3, 3), activation=tf.nn.relu)
        net = tf.layers.max_pooling2d(net, (2, 2), (2, 2))
        net = tf.layers.conv2d(net, 128, (3, 3), activation=tf.nn.relu)
        net = tf.layers.max_pooling2d(net, (2, 2), (2, 2))
        net = tf.layers.flatten(net)
        net = tf.layers.dense(net, 256, tf.nn.relu)
        net = tf.layers.dense(net, 512, tf.nn.relu)
        net = tf.layers.dense(net, 2, tf.nn.sigmoid, name='Predictions')
    predictions = {'predictions': net}
    if mode == tf.estimator.ModeKeys.PREDICT:
        predictions['image'] = input_tensor
        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)
    with tf.name_scope('GroundTruth'):
        labels = tf.identity(labels, name='labels')
    with tf.name_scope('Loss'):
        loss = tf.losses.mean_squared_error(labels=labels, predictions=predictions['predictions'])
    with tf.name_scope('Summaries'):
        tf.summary.image('images', input_tensor)
        tf.summary.scalar('loss', loss)
        tf.summary.histogram('predictions', predictions['predictions'])
        tf.summary.histogram('targets', labels)
    if mode == tf.estimator.ModeKeys.EVAL:
        return tf.estimator.EstimatorSpec(mode=mode, loss=loss)
    with tf.name_scope('Optimizer'):
        optimizer = tf.train.AdamOptimizer()
        updates = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(updates):
            train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
    return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)
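The localization problem is framed as plain regression: the sigmoid head keeps both outputs in [0, 1], matching the normalized center coordinates produced by data.py, and mean squared error suffices as a loss because every synthetic image contains exactly one logo.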
# Estimator
run_config = tf.estimator.RunConfig(
    model_dir=model_dir,
    save_summary_steps=10,
    save_checkpoints_secs=900,
    keep_checkpoint_max=1,
    log_step_count_steps=10
)
estimator = tf.estimator.Estimator(model_fn=model_fn, config=run_config)

# Training Loop
train_spec = tf.estimator.TrainSpec(input_fn=input_fn, max_steps=train_steps)
eval_spec = tf.estimator.EvalSpec(input_fn=input_fn, steps=20, name='eval', throttle_secs=30, start_delay_secs=1)
tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
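Training progress (the image, scalar and histogram summaries written by model_fn) can be followed live by pointing TensorBoard at the models directory, for example:

tensorboard --logdir /.../models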
# Visualize predictions on a validation batch
predictions = estimator.predict(input_fn)
for i, p in zip(range(batch_size), predictions):
    plt.subplot(4, 4, i + 1)
    img = p['image']
    coords = p['predictions']
    plt.imshow(img / 2.0 + 0.5)  # undo the [-1, 1] scaling
    plt.plot(coords[0] * image_size[1], coords[1] * image_size[0], 'or')
plt.show()
# Webcam
cap, webcam_input_fn = get_webcam_input_fn(image_size)

# Predictions
predictions = estimator.predict(webcam_input_fn)

# Visualize Webcam
cv2.namedWindow('Press "q" to quit', cv2.WINDOW_NORMAL)
for p in predictions:
    # Webcam image (back to BGR for OpenCV, rescaled to [0, 1]) + predictions
    img = (p['image'][:, :, ::-1] + 1.0) / 2.0
    coords = p['predictions']
    # Draw the predicted logo center
    img = cv2.circle(img, (int(coords[0] * img.shape[1]), int(coords[1] * img.shape[0])), 2, (0, 255, 0), 19)
    # Display the resulting frame, mirrored for a natural "selfie" view
    cv2.imshow('Press "q" to quit', img[:, ::-1, :])
    # Quit on 'q'
    if cv2.waitKey(1) & 0xFF == ord('q'):
        cap.release()
        cv2.destroyAllWindows()
        break