import numpy as np
from uuid import uuid4
import tensorflow as tf

L = tf.contrib.keras.layers
class MinFoxSolver:
    def __init__(self, n_components, p, p_b=None,
                 gen_optimizer=tf.train.AdamOptimizer(5e-4), pred_optimizer=tf.train.AdamOptimizer(5e-4),
                 make_generator=lambda n_components: L.Dense(n_components, name='he_who_generates_unpredictable'),
                 make_predictor=lambda n_components: L.Dense(n_components, name='he_who_predicts_generated_variable'),
                 sess=None, device=None,
                 ):
        """
        Given two matrices A and B, learns a variable f(A) that is impossible to predict from matrix B.
        :param n_components: number of components (columns) of f(A)
        :param p: last dimension of A
        :param p_b: last dimension of B, defaults to p
        :param gen_optimizer: tf optimizer used to train the generator f(A)
        :param pred_optimizer: tf optimizer used to train the predictor g(B)
        :param make_generator: callback that creates a keras model for the target variable generator given A
        :param make_predictor: callback that creates a keras model for the target variable predictor given B
        :param sess: tensorflow session. If not specified, uses the default session or creates a new one.
        :param device: 'cpu', 'gpu' or a specific tf device to run on.
        /* A teeny-tiny fox cub */
        """
        config = tf.ConfigProto(device_count={'GPU': int(device != 'cpu')}) if device is not None else tf.ConfigProto()
        self.session = sess = sess or tf.get_default_session() or tf.Session(config=config)
        self.n_components = n_components
        self.gen_optimizer, self.pred_optimizer = gen_optimizer, pred_optimizer

        with sess.as_default(), sess.graph.as_default(), tf.device(device), tf.variable_scope(str(uuid4())) as self.scope:
            A = self.A = tf.placeholder(tf.float32, [None, p])
            B = self.B = tf.placeholder(tf.float32, [None, p_b or p])

            self.generator = make_generator(n_components)
            self.predictor = make_predictor(n_components)
            prediction = self.predictor(B)
            target_raw = self.generator(A)

            # orthogonalize target columns, then rescale each to norm sqrt(n_samples) (unit mean square)
            target = orthogonalize_columns(target_raw)
            target *= tf.sqrt(tf.to_float(tf.shape(target)[0]))

            self.loss_values = self.compute_loss(target, prediction)
            self.loss = tf.reduce_mean(self.loss_values)
            # penalty that keeps the raw generator output close to its orthogonalized version
            self.reg = tf.reduce_mean(tf.squared_difference(target, target_raw))

            with tf.variable_scope('gen_optimizer'):
                # the generator maximizes the prediction loss, hence the minus sign
                self.update_gen = gen_optimizer.minimize(-self.loss + self.reg,
                                                         var_list=self.generator.trainable_variables)
            with tf.variable_scope('pred_optimizer'):
                self.update_pred = pred_optimizer.minimize(self.loss,
                                                           var_list=self.predictor.trainable_variables)

                # predictor weights plus any optimizer state created in this scope, reset together
                pred_state = self.predictor.trainable_variables
                pred_state += tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=tf.get_variable_scope().name)
                self.reset_pred = tf.variables_initializer(pred_state)

            self.prediction, self.target = prediction, target
            self.target_raw = target_raw
    def compute_loss(self, target, prediction):
        """ Returns the loss for each sample and each component; output.shape == target.shape """
        return tf.squared_difference(target, prediction)
    def fit(self, A, B, max_iters=10 ** 4, tolerance=1e-4, batch_size=None, pred_steps=5, gen_steps=1,
            warm_start=False, reset_predictor=False, reorder=True, verbose=False, report_every=100):
        """
        Trains the fox
        :param pred_steps: predictor g(B) training iterations per training step
        :param gen_steps: generator f(A) training iterations per training step
        :param max_iters: maximum number of optimization steps before termination
        :param tolerance: terminates if the loss change between two consecutive reports
            (report_every steps apart) falls below this value; set to 0 to always run max_iters steps
        :param batch_size: minibatch size; if None, trains on the full dataset each step
        :param warm_start: currently unused
        :param reset_predictor: if True, re-initializes the predictor network after every step
        :param reorder: if True, reorders components from highest loss to lowest
        :param report_every: evaluate loss and check convergence every this many steps
        """
        sess = self.session
        step = 0
        with sess.as_default(), sess.graph.as_default():
            initialize_uninitialized_variables(sess)
            prev_loss = float('inf')
            for batch_a, batch_b in iterate_minibatches(A, B, batch_size, cycle=True, shuffle=True):
                step += 1
                if step > max_iters:
                    if verbose:
                        print("Done: reached max steps")
                    break
                feed = {self.A: batch_a, self.B: batch_b}

                # train predictor
                for j in range(pred_steps):
                    sess.run(self.update_pred, feed)

                # eval loss and metrics
                if step % report_every == 0:
                    loss_t = sess.run(self.loss, feed)
                    if verbose:
                        print("step %i; loss=%.4f; delta=%.4f" % (step, loss_t, abs(prev_loss - loss_t)))
                    if abs(prev_loss - loss_t) < tolerance:
                        if verbose:
                            print("Done: reached target tolerance")
                        break
                    prev_loss = loss_t

                # update generator
                for j in range(gen_steps):
                    sess.run(self.update_gen, feed)
                if reset_predictor:
                    sess.run(self.reset_pred)

            # record components ordered by their loss value (highest to lowest)
            if reorder:
                if verbose:
                    print("Ordering components by loss values...")
                self.loss_per_component = np.zeros([self.n_components])
                for batch_a, batch_b in iterate_minibatches(A, B, batch_size, cycle=False, shuffle=False):
                    batch_loss_values = sess.run(self.loss_values, {self.A: batch_a, self.B: batch_b})
                    # ^-- [batch_size, n_components]
                    self.loss_per_component += batch_loss_values.sum(0)
                self.component_order = np.argsort(-self.loss_per_component)
                self.loss_per_component_ordered = self.loss_per_component[self.component_order]
            else:
                self.component_order = np.arange(self.n_components)

        if verbose:
            print("Training finished.")
        return self
    def predict(self, A=None, B=None, ordered=True, raw=False):
        """ Computes the target f(A) if given A, or the prediction g(B) if given B. """
        assert (A is None) != (B is None), "Please use either predict(A=...) or predict(B=...)"
        sess = self.session
        with sess.as_default(), sess.graph.as_default():
            if A is not None:
                if not raw:
                    out = sess.run(self.target, {self.A: A})
                else:
                    # raw=True skips the orthogonalization of the generator output
                    out = sess.run(self.target_raw, {self.A: A})
            else:
                out = sess.run(self.prediction, {self.B: B})
            if ordered:
                out = out[:, self.component_order]
        return out
    def get_weights(self):
        return self.session.run({'generator': self.generator.trainable_variables,
                                 'predictor': self.predictor.trainable_variables})
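
# A hedged sketch (not part of the original gist): compute_loss is the extension
# point for changing what "predictable" means. The subclass name below is
# hypothetical; it swaps squared error for absolute error while keeping the
# per-sample, per-component shape contract that __init__ relies on.
class AbsFoxSolver(MinFoxSolver):
    def compute_loss(self, target, prediction):
        """ Same contract as the parent: one loss value per sample and per component """
        return tf.abs(target - prediction)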
def orthogonalize_rows(matrix):
    """
    Gram-Schmidt orthogonalization of the rows of a matrix; source: https://bit.ly/2FMOp40
    :param matrix: 2d float tensor [nrow, ncol] with statically known nrow
    :returns: row-orthogonalized matrix [nrow, ncol] s.t.
        * output[i, :].dot(output[j, :]) ~= 0 for all i != j
        * norm(output[i, :]) == 1 for all i
    """
    basis = tf.expand_dims(matrix[0, :] / tf.norm(matrix[0, :]), 0)  # [1, ncol]
    for i in range(1, matrix.shape[0]):
        v = tf.expand_dims(matrix[i, :], 0)  # [1, ncol]
        w = v - tf.matmul(tf.matmul(v, basis, transpose_b=True), basis)  # [1, ncol]
        basis = tf.concat([basis, w / tf.norm(w)], axis=0)  # [i, ncol]
    return basis
def orthogonalize_columns(matrix):
    """
    Gram-Schmidt orthogonalization of the columns of a matrix; source: https://bit.ly/2FMOp40
    :param matrix: 2d float tensor [nrow, ncol] with statically known ncol
    :returns: column-orthogonalized matrix [nrow, ncol] s.t.
        * output[:, i].dot(output[:, j]) ~= 0 for all i != j
        * norm(output[:, j]) == 1 for all j
    """
    basis = tf.expand_dims(matrix[:, 0] / tf.norm(matrix[:, 0]), 1)  # [nrow, 1]
    for i in range(1, matrix.shape[1]):
        v = tf.expand_dims(matrix[:, i], 1)  # [nrow, 1]
        w = v - tf.matmul(basis, tf.matmul(basis, v, transpose_a=True))  # [nrow, 1]
        basis = tf.concat([basis, w / tf.norm(w)], axis=1)  # [nrow, i]
    return basis
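
# A minimal sanity check (not part of the original gist; the helper name and
# sizes are illustrative assumptions): orthonormal columns should give a Gram
# matrix close to the identity.
def _check_orthogonalize_columns(nrow=100, ncol=5, seed=0):
    rng = np.random.RandomState(seed)
    x = tf.constant(rng.randn(nrow, ncol), dtype=tf.float32)
    ortho = orthogonalize_columns(x)
    gram = tf.matmul(ortho, ortho, transpose_a=True)  # expected ~eye(ncol)
    with tf.Session() as check_sess:
        gram_value = check_sess.run(gram)
    assert np.allclose(gram_value, np.eye(ncol), atol=1e-3), gram_value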
def initialize_uninitialized_variables(sess=None, var_list=None):
    with tf.name_scope("initialize"):
        sess = sess or tf.get_default_session() or tf.InteractiveSession()
        uninitialized_names = set(sess.run(tf.report_uninitialized_variables(var_list)))
        uninitialized_vars = []
        for var in tf.global_variables():
            if var.name[:-2].encode() in uninitialized_names:  # strip the ':0' output suffix
                uninitialized_vars.append(var)
        sess.run(tf.variables_initializer(uninitialized_vars))
def iterate_minibatches(x, y, batch_size=None, cycle=False, shuffle=False):
    """
    Yields pairs (x[batch], y[batch]). If batch_size is None, yields the full (x, y) instead,
    and shuffle has no effect. If cycle is True, iterates over the data indefinitely.
    """
    indices = np.arange(len(x))
    while True:
        if batch_size is not None:
            if shuffle:
                indices = np.random.permutation(indices)
            for batch_start in range(0, len(x), batch_size):
                batch_ix = indices[batch_start: batch_start + batch_size]
                yield x[batch_ix], y[batch_ix]
        else:
            yield x, y
        if not cycle:
            break
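
# Usage sketch (not part of the original gist; the data, shapes and
# hyperparameters below are illustrative assumptions). A shares most of its
# columns with B plus one independent column; the least predictable component
# of f(A) should roughly align with that independent part.
if __name__ == '__main__':
    rng = np.random.RandomState(42)
    n, p = 1000, 8
    B_data = rng.randn(n, p).astype('float32')
    independent = rng.randn(n, 1).astype('float32')
    A_data = np.concatenate([B_data[:, :p - 1], independent], axis=1)  # [n, p]

    solver = MinFoxSolver(n_components=2, p=p)
    solver.fit(A_data, B_data, max_iters=2000, batch_size=256, verbose=True)

    targets = solver.predict(A=A_data)  # f(A), orthogonalized components
    preds = solver.predict(B=B_data)    # g(B), the best attempt at predicting f(A)
    mse = ((targets - preds) ** 2).mean(axis=0)
    print("per-component prediction MSE, most unpredictable first:", mse)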