import numpy as np
from uuid import uuid4

import tensorflow as tf
L = tf.contrib.keras.layers


class MinFoxSolver:
    def __init__(self, n_components, p, p_b=None,
                 gen_optimizer=tf.train.AdamOptimizer(5e-4), pred_optimizer=tf.train.AdamOptimizer(5e-4),
                 make_generator=lambda n_components: L.Dense(n_components, name='he_who_generates_unpredictable'),
                 make_predictor=lambda n_components: L.Dense(n_components, name='he_who_predicts_generated_variable'),
                 sess=None, device=None,
                 ):
"""
Given two matrices A and B, predict a variable f(A) that is impossible to predict from matrix B
:param p: last dimension of A
:param p_b: dimension of B, default p_b = p
:param optimizer: tf optimizer to be used on both generator and discriminator
:param make_generator: callback to create a keras model for target variable generator given A
:param make_predictor: callback to create a keras model for target variable predictor given B
:param sess: tensorflow session. If not specified, uses default session or creates new one.
:param device: 'cpu', 'gpu' or a specific tf device to run on.
/* Маааленькая лисёнка */
"""
        config = tf.ConfigProto(device_count={'GPU': int(device != 'cpu')}) if device is not None else tf.ConfigProto()
        self.session = sess = sess or tf.get_default_session() or tf.Session(config=config)
        self.n_components = n_components
        self.gen_optimizer, self.pred_optimizer = gen_optimizer, pred_optimizer

        with sess.as_default(), sess.graph.as_default(), tf.device(device), tf.variable_scope(str(uuid4())) as self.scope:
            A = self.A = tf.placeholder(tf.float32, [None, p])
            B = self.B = tf.placeholder(tf.float32, [None, p_b or p])

            self.generator = make_generator(n_components)
            self.predictor = make_predictor(n_components)
            prediction = self.predictor(B)
            target_raw = self.generator(A)

            # orthogonalize target columns, then rescale so each component has unit mean square
            target = orthogonalize_columns(target_raw)
            target *= tf.sqrt(tf.to_float(tf.shape(target)[0]))

            self.loss_values = self.compute_loss(target, prediction)
            self.loss = tf.reduce_mean(self.loss_values)
            # penalty keeping the raw generator output close to its orthogonalized version
            self.reg = tf.reduce_mean(tf.squared_difference(target, target_raw))
            with tf.variable_scope('gen_optimizer'):
                # the generator maximizes the predictor's loss (note the minus sign)
                self.update_gen = gen_optimizer.minimize(-self.loss + self.reg,
                                                         var_list=self.generator.trainable_variables)
            with tf.variable_scope('pred_optimizer') as pred_optimizer_scope:
                self.update_pred = pred_optimizer.minimize(self.loss,
                                                           var_list=self.predictor.trainable_variables)

            # variables to re-initialize when resetting the predictor: its weights
            # plus the predictor optimizer's internal state (e.g. Adam moments)
            pred_state = self.predictor.trainable_variables
            pred_state += tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=pred_optimizer_scope.name)
            self.reset_pred = tf.variables_initializer(pred_state)

            self.prediction, self.target = prediction, target
            self.target_raw = target_raw
    def compute_loss(self, target, prediction):
        """ Returns the loss for each sample and each component; output.shape == target.shape """
        return tf.squared_difference(target, prediction)
    def fit(self, A, B, max_iters=10 ** 4, tolerance=1e-4, batch_size=None, pred_steps=5, gen_steps=1,
            warm_start=False, reset_predictor=False, reorder=True, verbose=False, report_every=100):
        """
        Trains the fox.
        :param batch_size: number of rows per minibatch; None means full-batch training
        :param pred_steps: predictor g(B) training iterations per training step
        :param gen_steps: generator f(A) training iterations per training step
        :param max_iters: maximum number of optimization steps before termination
        :param tolerance: terminates if the loss changes by less than this value between
            two consecutive evaluations (every report_every steps); set to 0 to iterate for max_iters
        :param warm_start: unused; repeated fit() calls always warm-start, since only
            uninitialized variables are (re)initialized
        :param reset_predictor: if True, resets the predictor network after every step
        :param reorder: if True, reorders components from highest loss to lowest
        """
        sess = self.session
        step = 0
        with sess.as_default(), sess.graph.as_default():
            initialize_uninitialized_variables(sess)
            prev_loss = float('inf')
            for batch_a, batch_b in iterate_minibatches(A, B, batch_size, cycle=True, shuffle=True):
                step += 1
                if step > max_iters:
                    if verbose:
                        print("Done: reached max_iters")
                    break
                feed = {self.A: batch_a, self.B: batch_b}

                # train predictor
                for j in range(pred_steps):
                    sess.run(self.update_pred, feed)

                # evaluate loss and check convergence
                if step % report_every == 0:
                    loss_t = sess.run(self.loss, feed)
                    if verbose:
                        print("step %i; loss=%.4f; delta=%.4f" % (step, loss_t, abs(prev_loss - loss_t)))
                    if abs(prev_loss - loss_t) < tolerance:
                        if verbose:
                            print("Done: reached target tolerance")
                        break
                    prev_loss = loss_t

                # update generator
                for j in range(gen_steps):
                    sess.run(self.update_gen, feed)
                if reset_predictor:
                    sess.run(self.reset_pred)

            # record components ordered by their loss value (highest to lowest)
            if reorder:
                if verbose:
                    print("Ordering components by loss values...")
                self.loss_per_component = np.zeros([self.n_components])
                for batch_a, batch_b in iterate_minibatches(A, B, batch_size, cycle=False, shuffle=False):
                    batch_loss_values = sess.run(self.loss_values, {self.A: batch_a, self.B: batch_b})
                    # ^-- [batch_size, n_components]
                    self.loss_per_component += batch_loss_values.sum(0)
                self.component_order = np.argsort(-self.loss_per_component)
                self.loss_per_component_ordered = self.loss_per_component[self.component_order]
            else:
                self.component_order = np.arange(self.n_components)

            if verbose:
                print("Training finished.")
        return self
    def predict(self, A=None, B=None, ordered=True, raw=False):
        """
        Computes the learned target f(A) (if given A) or the prediction g(B) (if given B).
        :param ordered: if True, returns components sorted by the order found in fit()
        :param raw: if True (and A is given), returns the generator output before orthogonalization
        """
        assert (A is None) != (B is None), "Please use either predict(A=...) or predict(B=...)"
        sess = self.session
        with sess.as_default(), sess.graph.as_default():
            if A is not None:
                if not raw:
                    out = sess.run(self.target, {self.A: A})
                else:
                    out = sess.run(self.target_raw, {self.A: A})
            else:
                out = sess.run(self.prediction, {self.B: B})
        if ordered:
            out = out[:, self.component_order]
        return out

    def get_weights(self):
        return self.session.run({'generator': self.generator.trainable_variables,
                                 'predictor': self.predictor.trainable_variables})
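
# A minimal customization sketch (illustrative, not part of the original gist):
# the adversarial game only touches compute_loss(), so a subclass can swap in a
# different per-sample, per-component loss. The name AbsFoxSolver and the choice
# of absolute error here are assumptions for the example.
class AbsFoxSolver(MinFoxSolver):
    def compute_loss(self, target, prediction):
        """ Same contract as the parent: per-sample, per-component loss values """
        return tf.abs(target - prediction)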

def orthogonalize_rows(matrix):
    """
    Gram-Schmidt orthogonalizer for each row of matrix; source: https://bit.ly/2FMOp40
    :param matrix: 2d float tensor [nrow, ncol]
    :returns: row-orthogonalized matrix [nrow, ncol] s.t.
        * output[i, :].dot(output[j, :]) ~= 0 for all i != j
        * norm(output[i, :]) == 1 for all i
    """
    basis = tf.expand_dims(matrix[0, :] / tf.norm(matrix[0, :]), 0)      # [1, ncol]
    for i in range(1, matrix.shape[0]):
        v = tf.expand_dims(matrix[i, :], 0)                              # [1, ncol]
        w = v - tf.matmul(tf.matmul(v, basis, transpose_b=True), basis)  # [1, ncol]
        basis = tf.concat([basis, w / tf.norm(w)], axis=0)               # [i + 1, ncol]
    return basis

def orthogonalize_columns(matrix):
    """
    Gram-Schmidt orthogonalizer for each column of matrix; source: https://bit.ly/2FMOp40
    :param matrix: 2d float tensor [nrow, ncol]
    :returns: column-orthogonalized matrix [nrow, ncol] s.t.
        * output[:, i].dot(output[:, j]) ~= 0 for all i != j
        * norm(output[:, j]) == 1 for all j
    """
    basis = tf.expand_dims(matrix[:, 0] / tf.norm(matrix[:, 0]), 1)      # [nrow, 1]
    for i in range(1, matrix.shape[1]):
        v = tf.expand_dims(matrix[:, i], 1)                              # [nrow, 1]
        w = v - tf.matmul(basis, tf.matmul(basis, v, transpose_a=True))  # [nrow, 1]
        basis = tf.concat([basis, w / tf.norm(w)], axis=1)               # [nrow, i + 1]
    return basis
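
# A quick sanity check for the orthogonalizer (an illustrative sketch, not part of
# the original gist; run it manually). Orthonormal columns should give
# basis^T @ basis ~= identity; orthogonalize_rows admits the analogous check.
def _check_orthogonal_columns():
    rng = np.random.RandomState(0)
    with tf.Session() as check_sess:
        random_matrix = tf.constant(rng.randn(50, 4), dtype=tf.float32)
        basis = orthogonalize_columns(random_matrix)
        gram = check_sess.run(tf.matmul(basis, basis, transpose_a=True))  # [4, 4]
        assert np.allclose(gram, np.eye(4), atol=1e-4)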

def initialize_uninitialized_variables(sess=None, var_list=None):
    """ Runs the initializer for every global variable that is not yet initialized """
    with tf.name_scope("initialize"):
        sess = sess or tf.get_default_session() or tf.InteractiveSession()
        uninitialized_names = set(sess.run(tf.report_uninitialized_variables(var_list)))
        uninitialized_vars = [var for var in tf.global_variables()
                              if var.name[:-2].encode() in uninitialized_names]  # strip the ':0' suffix
        sess.run(tf.variables_initializer(uninitialized_vars))

def iterate_minibatches(x, y, batch_size=None, cycle=False, shuffle=False):
    """ Yields minibatches of (x, y); batch_size=None yields the full dataset as one batch """
    indices = np.arange(len(x))
    while True:
        if batch_size is not None:
            if shuffle:
                indices = np.random.permutation(indices)
            for batch_start in range(0, len(x), batch_size):
                batch_ix = indices[batch_start: batch_start + batch_size]
                yield x[batch_ix], y[batch_ix]
        else:
            yield x, y
        if not cycle:
            break
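
# A minimal end-to-end usage sketch (illustrative, not part of the original gist).
# Here B carries only the first of A's two coordinates, so the solver should learn
# a component of A that the predictor cannot recover from B; all shapes, sizes and
# hyperparameters below are arbitrary assumptions for the demo.
if __name__ == '__main__':
    rng = np.random.RandomState(42)
    A_data = rng.randn(1024, 2).astype('float32')
    B_data = A_data[:, :1] + 0.1 * rng.randn(1024, 1).astype('float32')  # B "knows" only A's first coordinate

    solver = MinFoxSolver(n_components=1, p=2, p_b=1)
    solver.fit(A_data, B_data, max_iters=2000, batch_size=256, verbose=True)

    target = solver.predict(A=A_data)      # the unpredictable component f(A)
    prediction = solver.predict(B=B_data)  # the predictor's best guess g(B)
    print("residual predictor loss:", np.mean((target - prediction) ** 2))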