Created
January 20, 2019 20:13
-
-
Save justheuristic/41febdb1ad36b55db4de6ae50a2f9f2e to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import tensorflow as tf | |
L = tf.keras.layers | |
class MinFoxSolver:
    """
    Adversarial solver that learns a variable f(A) which is hard to predict from B.

    A generator network maps A to a target variable, a predictor network tries to
    recover that target from B, and the two are trained against each other on the
    same mean squared error (the predictor minimizes it, the generator maximizes it).
    """

    def __init__(self, p, p_b=None, pred_steps=5, gen_steps=1, max_iters=10 ** 5, tolerance=1e-3,
                 optimizer=None,
                 make_generator=lambda: L.Dense(1, name='he_who_generates_unpredictable'),
                 make_predictor=lambda: L.Dense(1, name='he_who_predicts_generated_variable'),
                 sess=None, verbose=False, reset_predictor=False, eps=1e-9
                 ):
        """
        Given two matrices A and B, predict a variable f(A) that is impossible to predict from matrix B
        :param p: last dimension of A
        :param p_b: dimension of B, default p_b = p
        :param pred_steps: predictor g(B) training iterations per one training step
        :param gen_steps: generator f(A) training iterations per one training step
        :param max_iters: maximum number of optimization steps till termination
        :param tolerance: terminates if the loss changes by less than this value between two
            consecutive convergence checks (one check every 100 iterations);
            set to 0 to iterate for max_iters
        :param optimizer: tf optimizer to be used on both generator and predictor;
            default (None) creates a fresh tf.train.AdamOptimizer(5e-4) for this solver
        :param make_generator: callback to create a keras model for target variable generator given A
        :param make_predictor: callback to create a keras model for target variable predictor given B
        :param sess: tf session to build and run the graph in; defaults to the current default
            session, or to a new CPU-only session if there is none
        :param verbose: if True, prints progress at every convergence check and on termination
        :param reset_predictor: if True, resets predictor network after every step
        :param eps: accepted for compatibility; not used anywhere in this class
        /* Tiiiny little fox */
        """
        self.session = sess = sess or tf.get_default_session() \
            or tf.Session(config=tf.ConfigProto(device_count={'GPU': 0}))
        self.pred_steps, self.gen_steps = pred_steps, gen_steps
        self.reset_predictor = reset_predictor
        self.max_iters, self.tolerance = max_iters, tolerance
        self.verbose = verbose

        with sess.as_default(), sess.graph.as_default():
            if optimizer is None:
                # Built per instance: a default argument would be evaluated only once,
                # at definition time, and then shared by every solver ever created.
                optimizer = tf.train.AdamOptimizer(5e-4)

            A = self.A = tf.placeholder(tf.float32, [None, p])
            B = self.B = tf.placeholder(tf.float32, [None, p_b or p])

            self.generator = make_generator()
            self.predictor = make_predictor()

            prediction = self.predictor(B)
            target_raw = self.generator(A)

            # Orthonormalize the target columns over the batch, then rescale by
            # sqrt(batch size) so that each column has a mean square of one.
            target = orthogonalize_columns(target_raw)
            target *= tf.sqrt(tf.to_float(tf.shape(target)[0]))

            self.loss = self.compute_loss(target, prediction)
            # Penalty on the gap between the raw generator output and its normalized version.
            self.reg = tf.reduce_mean(tf.squared_difference(target, target_raw))

            # The predictor minimizes the loss; the generator maximizes it (minus the penalty).
            self.update_pred = optimizer.minimize(self.loss, var_list=self.predictor.trainable_variables)
            self.reset_pred = tf.variables_initializer(self.predictor.trainable_variables)
            self.update_gen = optimizer.minimize(-self.loss + self.reg,
                                                 var_list=self.generator.trainable_variables)

            self.prediction, self.target = prediction, target

    def compute_loss(self, target, prediction):
        """ Mean squared difference between the generated target and its prediction. """
        return tf.reduce_mean(tf.squared_difference(target, prediction))

    def fit(self, A, B):
        """
        Alternate predictor and generator updates on (A, B) until convergence or max_iters.

        Every update is run on the full A and B that were passed in. All global variables
        of the session graph are (re-)initialized before training starts.
        :param A: values fed to the A placeholder, last dimension p
        :param B: values fed to the B placeholder, last dimension p_b
        :returns: self
        """
        sess = self.session
        with sess.as_default(), sess.graph.as_default():
            sess.run(tf.global_variables_initializer())
            feed = {self.A: A, self.B: B}
            prev_loss = sess.run(self.loss, feed)

            for i in range(1, self.max_iters + 1):
                for _ in range(self.pred_steps):
                    sess.run(self.update_pred, feed)
                for _ in range(self.gen_steps):
                    sess.run(self.update_gen, feed)

                # Convergence is only checked once per 100 iterations.
                if i % 100 == 0:
                    loss_i = sess.run(self.loss, feed)
                    if self.verbose:
                        print("step %i; loss=%.3f; delta=%.3f" % (i, loss_i, abs(prev_loss - loss_i)))
                    if abs(prev_loss - loss_i) < self.tolerance:
                        if self.verbose:
                            print("Done: reached target tolerance")
                        break
                    prev_loss = loss_i

                # NOTE(review): the text this block was restored from had lost its
                # indentation. The reset is placed at loop level to match the documented
                # "after every step" contract -- confirm against the original file.
                if self.reset_predictor:
                    sess.run(self.reset_pred)
            else:
                # for/else: reached only when the loop finished without hitting the tolerance.
                if self.verbose:
                    print("Done: reached max steps")
        return self

    def predict(self, A=None, B=None):
        """
        Evaluate one of the two networks; exactly one of A and B must be given.

        :param A: if given, returns the generated target for A. The target is
            orthonormalized over the rows of A, so it depends on the whole batch.
        :param B: if given, returns the predictor output for B
        """
        assert (A is None) != (B is None), "Please use either predict(A=...) or predict(B=...)"
        sess = self.session
        with sess.as_default(), sess.graph.as_default():
            if A is not None:
                return sess.run(self.target, {self.A: A})
            else:
                return sess.run(self.prediction, {self.B: B})

    def get_weights(self):
        """ Returns a dict with current values of 'generator' and 'predictor' trainable variables. """
        return self.session.run({'generator': self.generator.trainable_variables,
                                 'predictor': self.predictor.trainable_variables})
def orthogonalize_rows(matrix):
    """
    Gram-Schmidt orthogonalizer; source: https://bit.ly/2FMOp40

    Returns a tensor of the same row count whose rows are built one at a time: each
    input row has its projections onto the previously built rows subtracted and is
    then divided by its norm. The first row is only normalized.

    The loop runs in Python over matrix.shape[0], so the graph gets one group of ops
    per row. There is no guard against a zero norm (e.g. linearly dependent rows).
    :param matrix: 2-D tensor, indexed as [row, column]
    """
    first_row = matrix[0, :]
    # keep a leading axis so that the basis stays a matrix usable in matmul
    ortho_rows = tf.expand_dims(first_row / tf.norm(first_row), 0)

    for row_index in range(1, matrix.shape[0]):
        row = tf.expand_dims(matrix[row_index, :], 0)  # [1, dim], again for matmul
        # coordinates of this row in the basis built so far
        coefficients = tf.matmul(row, ortho_rows, transpose_b=True)
        residual = row - tf.matmul(coefficients, ortho_rows)
        ortho_rows = tf.concat([ortho_rows, residual / tf.norm(residual)], axis=0)

    return ortho_rows
def orthogonalize_columns(matrix):
    """ Applies orthogonalize_rows to the transposed matrix and transposes the result back. """
    columns_as_rows = tf.transpose(matrix)
    orthogonal_rows = orthogonalize_rows(columns_as_rows)
    return tf.transpose(orthogonal_rows)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment