import numpy as np
from contextlib import contextmanager
from uuid import uuid4
import tensorflow as tf
import tfnn.layers.basic as L

class MinFoxSolver:
    def __init__(self, n_components, p, p_b=None, double_grad_steps=0, double_grad_lr=0.01,
                 gen_optimizer=tf.train.AdamOptimizer(5e-4), pred_optimizer=tf.train.AdamOptimizer(5e-4),
                 make_generator=lambda name, inp_size, out_size: L.Dense(name, inp_size, out_size, activ=lambda x: x),
                 make_predictor=lambda name, inp_size, out_size: L.Dense(name, inp_size, out_size, activ=lambda x: x,
                                                                         matrix_initializer=tf.zeros_initializer()),
                 sess=None, device=None,
                 ):
""" | |
Given two matrices A and B, predict a variable f(A) that is impossible to predict from matrix B | |
:param p: last dimension of A | |
:param p_b: dimension of B, default p_b = p | |
:param gen_optimizer: tf optimizer used to train generator | |
:param pred_optimizer: tf optimizer used to train predictor out-of-graph (between iterations) | |
:param make_generator: callback to create a keras model for target variable generator given A | |
:param make_predictor: callback to create a keras model for target variable predictor given B | |
Both models returned by make_generator and make_predictor should have | |
all their trainable variables created inside name scope (first arg) | |
:param sess: tensorflow session. If not specified, uses default session or creates new one. | |
:param device: 'cpu', 'gpu' or a specific tf device to run on. | |
/* Маааленькая лисёнка */ | |
""" | |
        config = tf.ConfigProto(device_count={'GPU': int(device != 'cpu')}) if device is not None else tf.ConfigProto()
        self.session = sess = sess or tf.get_default_session() or tf.Session(config=config)
        self.n_components = n_components
        self.gen_optimizer, self.pred_optimizer = gen_optimizer, pred_optimizer

        with sess.as_default(), sess.graph.as_default(), tf.device(device), tf.variable_scope(str(uuid4())) as self.scope:
            A = self.A = tf.placeholder(tf.float32, [None, p])
            B = self.B = tf.placeholder(tf.float32, [None, p_b or p])

            self.generator = make_generator('generator', p, n_components)
            self.predictor = make_predictor('predictor', p_b or p, n_components)
            prefix = tf.get_variable_scope().name + '/'
            self.generator_weights = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                                       scope=prefix + 'generator')
            self.predictor_weights = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                                       scope=prefix + 'predictor')

            prediction = self.predictor(B)
            target_raw = self.generator(A)

            # orthogonalize target columns and rescale them to norm sqrt(n_samples), i.e. unit mean square per component
            target = orthogonalize_columns(target_raw)
            target *= tf.sqrt(tf.to_float(tf.shape(target)[0]))

            self.loss_values = self.compute_loss(target, prediction)
            self.loss = tf.reduce_mean(self.loss_values)

            # update predictor
            with tf.variable_scope('pred_optimizer') as pred_optimizer_scope:
                self.update_pred = pred_optimizer.minimize(self.loss,
                                                           var_list=self.predictor_weights)
                pred_state = list(self.predictor_weights)
                pred_state += tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=tf.get_variable_scope().name)
                self.reset_pred = tf.variables_initializer(pred_state)

            # update generator against a (possibly) in-graph-updated predictor
            if double_grad_steps == 0:
                self.generator_loss = self.loss
            else:
                make_model = lambda: make_predictor('predictor', p_b or p, self.n_components)
                get_loss = lambda predictor: tf.reduce_mean(self.compute_loss(target, predictor(B)))
                self.updated_predictor = get_updated_model(self.predictor, make_model, get_loss,
                                                           n_steps=double_grad_steps, learning_rate=double_grad_lr,
                                                           model_variables=self.predictor_weights)
                self.generator_loss = get_loss(self.updated_predictor)

            self.generator_reg = tf.reduce_mean(tf.squared_difference(target, target_raw))
            with tf.variable_scope('gen_optimizer') as gen_optimizer_scope:
                self.update_gen = gen_optimizer.minimize(-self.generator_loss + self.generator_reg,
                                                         var_list=self.generator_weights)

            self.prediction, self.target, self.target_raw = prediction, target, target_raw

    def compute_loss(self, target, prediction):
        """ Return the loss for each sample and each component; output.shape == target.shape """
        return tf.squared_difference(target, prediction)

    def fit(self, A, B, max_iters=10 ** 4, min_iters=0, tolerance=1e-4, batch_size=None, pred_steps=5, gen_steps=1,
            reset_predictor=False, reorder=True, verbose=False, report_every=100):
        """
        Trains the fox.
        :param pred_steps: predictor g(B) training iterations per training step
        :param gen_steps: generator f(A) training iterations per training step
        :param max_iters: maximum number of optimization steps before termination
        :param min_iters: the guaranteed number of steps the algorithm will make before it may terminate by tolerance
        :param tolerance: terminate once the loss change between consecutive reports (every report_every steps)
            falls below this value; set to 0 to always run for max_iters steps
        :param reset_predictor: if True, resets the predictor network after every step
        :param reorder: if True, reorders components from highest loss to lowest
        """
        sess = self.session
        step = 0
        assert gen_steps > 0
        with sess.as_default(), sess.graph.as_default():
            initialize_uninitialized_variables(sess)
            prev_loss = float('inf')
            for batch_a, batch_b in iterate_minibatches(A, B, batch_size, cycle=True, shuffle=True):
                step += 1
                if step > max_iters:
                    break
                feed = {self.A: batch_a, self.B: batch_b}

                # update generator
                for j in range(gen_steps):
                    loss_t_gen, _ = sess.run([self.generator_loss, self.update_gen], feed)

                # update predictor (optionally from a fresh initialization)
                if reset_predictor:
                    sess.run(self.reset_pred)
                for j in range(pred_steps):
                    loss_t_pred, _ = sess.run([self.loss, self.update_pred], feed)

                # eval loss and metrics
                if step % report_every == 0:
                    if pred_steps == 0:
                        loss_t_pred = sess.run(self.loss, feed)
                    loss_delta = abs(prev_loss - loss_t_gen)
                    prev_loss = loss_t_gen

                    if verbose:
                        if pred_steps == 0:
                            print("step %i; loss=%.4f; delta=%.4f" % (step, loss_t_gen, loss_delta))
                        else:
                            print("step %i; loss(gen update)=%.4f; loss(pred update)=%.4f; delta=%.4f" % (
                                step, loss_t_gen, loss_t_pred, loss_delta))

                    if loss_delta < tolerance and step > min_iters:
                        if verbose: print("Done: reached target tolerance")
                        break
            else:
                if verbose:
                    print("Done: reached max steps")

            # record components ordered by their loss value (highest to lowest)
            if reorder:
                if verbose:
                    print("Ordering components by loss values...")
                self.loss_per_component = np.zeros([self.n_components])
                for batch_a, batch_b in iterate_minibatches(A, B, batch_size, cycle=False, shuffle=False):
                    batch_loss_values = sess.run(self.loss_values, {self.A: batch_a, self.B: batch_b})
                    # ^-- [batch_size, n_components]
                    self.loss_per_component += batch_loss_values.sum(0)
                self.component_order = np.argsort(-self.loss_per_component)
                self.loss_per_component_ordered = self.loss_per_component[self.component_order]
            else:
                self.component_order = np.arange(self.n_components)

            if verbose:
                print("Training finished.")
        return self

    def predict(self, A=None, B=None, ordered=True, raw=False):
        assert (A is None) != (B is None), "Please use either predict(A=...) or predict(B=...)"
        sess = self.session
        with sess.as_default(), sess.graph.as_default():
            if A is not None:
                if not raw:
                    out = sess.run(self.target, {self.A: A})
                else:
                    out = sess.run(self.target_raw, {self.A: A})
            else:
                out = sess.run(self.prediction, {self.B: B})
        if ordered:
            out = out[:, self.component_order]
        return out

    def get_weights(self):
        return self.session.run({'generator': self.generator.trainable_variables,
                                 'predictor': self.predictor.trainable_variables})


class DoubleFoxSolver(MinFoxSolver):
    def __init__(self, n_components, p, p_b=None, double_grad_steps=10, double_grad_lr=0.01,
                 gen_optimizer=tf.train.AdamOptimizer(5e-4), **kwargs):
        """
        A wrapper for MinFoxSolver that trains purely through double gradient, without any other predictor updates.
        """
        dummy_opt = tf.train.GradientDescentOptimizer(learning_rate=0.0)
        super().__init__(n_components, p, p_b, double_grad_steps=double_grad_steps, double_grad_lr=double_grad_lr,
                         gen_optimizer=gen_optimizer, pred_optimizer=dummy_opt, **kwargs)

    def fit(self, A, B, max_iters=10 ** 4, min_iters=0, tolerance=1e-4, batch_size=None,
            reset_predictor=True, reorder=True, verbose=False, report_every=100, **kwargs):
        """
        A wrapper for MinFoxSolver that trains purely through double gradient, without any other predictor updates.
        """
        return super().fit(A, B, max_iters, min_iters, tolerance, batch_size=batch_size,
                           gen_steps=1, pred_steps=0, reset_predictor=reset_predictor,
                           reorder=reorder, verbose=verbose, report_every=report_every, **kwargs)


DoubleFoxSolver.__init__.__doc__ += MinFoxSolver.__init__.__doc__
DoubleFoxSolver.fit.__doc__ += MinFoxSolver.fit.__doc__


def orthogonalize_rows(matrix):
    """
    Gram-Schmidt orthogonalizer for the rows of a matrix; source: https://bit.ly/2FMOp40
    :param matrix: 2d float tensor [nrow, ncol]
    :returns: row-orthogonalized matrix [nrow, ncol] s.t.
        * output[i, :].dot(output[j, :]) ~= 0 for all i != j
        * norm(output[i, :]) == 1 for all i
    """
    basis = tf.expand_dims(matrix[0, :] / tf.norm(matrix[0, :]), 0)  # [1, ncol]
    for i in range(1, matrix.shape[0]):
        v = tf.expand_dims(matrix[i, :], 0)  # [1, ncol]
        w = v - tf.matmul(tf.matmul(v, basis, transpose_b=True), basis)  # [1, ncol]
        basis = tf.concat([basis, w / tf.norm(w)], axis=0)  # [i, ncol]
    return basis


def orthogonalize_columns(matrix):
    """
    Gram-Schmidt orthogonalizer for the columns of a matrix; source: https://bit.ly/2FMOp40
    :param matrix: 2d float tensor [nrow, ncol]
    :returns: column-orthogonalized matrix [nrow, ncol] s.t.
        * output[:, i].dot(output[:, j]) ~= 0 for all i != j
        * norm(output[:, j]) == 1 for all j
    """
    basis = tf.expand_dims(matrix[:, 0] / tf.norm(matrix[:, 0]), 1)  # [nrow, 1]
    for i in range(1, matrix.shape[1]):
        v = tf.expand_dims(matrix[:, i], 1)  # [nrow, 1]
        w = v - tf.matmul(basis, tf.matmul(basis, v, transpose_a=True))  # [nrow, 1]
        basis = tf.concat([basis, w / tf.norm(w)], axis=1)  # [nrow, i]
    return basis
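

# Hedged aside, not part of the original gist: a small sanity-check sketch for the column
# orthogonalizer above. The function name and tolerance are assumptions added purely for
# illustration; it relies only on numpy and tensorflow, which are already imported in this module.
def check_column_orthonormality(matrix_np, atol=1e-4):
    """ Returns True if orthogonalize_columns maps matrix_np to (numerically) orthonormal columns. """
    with tf.Graph().as_default(), tf.Session() as check_sess:
        q = check_sess.run(orthogonalize_columns(tf.constant(matrix_np, dtype=tf.float32)))
    # columns are orthonormal iff Q^T Q is (approximately) the identity matrix
    return np.allclose(q.T.dot(q), np.eye(q.shape[1]), atol=atol)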


def initialize_uninitialized_variables(sess=None, var_list=None):
    with tf.name_scope("initialize"):
        sess = sess or tf.get_default_session() or tf.InteractiveSession()
        uninitialized_names = set(sess.run(tf.report_uninitialized_variables(var_list)))
        uninitialized_vars = []
        for var in tf.global_variables():
            if var.name[:-2].encode() in uninitialized_names:
                uninitialized_vars.append(var)
        sess.run(tf.variables_initializer(uninitialized_vars))


def iterate_minibatches(x, y, batch_size=None, cycle=False, shuffle=False):
    indices = np.arange(len(x))
    while True:
        if batch_size is not None:
            if shuffle:
                indices = np.random.permutation(indices)
            for batch_start in range(0, len(x), batch_size):
                batch_ix = indices[batch_start: batch_start + batch_size]
                yield x[batch_ix], y[batch_ix]
        else:
            yield x, y
        if not cycle:
            break


# Double grad utils

@contextmanager
def replace_variables(replacement_dict, strict=True, verbose=False, canonicalize_names=True, scope='', **kwargs):
    """ A context that replaces all newly created tf variables using replacement_dict {name -> value} """
    if canonicalize_names:
        new_replacement_dict = {canonicalize(key): var for key, var in replacement_dict.items()}
        assert len(new_replacement_dict) == len(replacement_dict), \
            "multiple variables got the same canonical name, output: {}".format(new_replacement_dict)
        replacement_dict = new_replacement_dict

    def _custom_getter(getter, name, shape, *args, **kwargs):
        name = canonicalize(name) if canonicalize_names else name
        assert not strict or name in replacement_dict, "variable {} not found".format(name)
        if name in replacement_dict:
            if verbose:
                print("variable {} replaced with {}".format(name, replacement_dict[name]))
            return replacement_dict[name]
        else:
            if verbose:
                print("variable {} not found, creating new".format(name))
            return getter(name=name, shape=shape, *args, **kwargs)

    with tf.variable_scope(scope, custom_getter=_custom_getter, **kwargs):
        yield


def canonicalize(name):
    """ Canonicalize a variable name: remove empty scopes (//) and the ':0' output suffix """
    if ':' in name:
        name = name[:name.index(':')]
    while '//' in name:
        name = name.replace('//', '/')
    return name


def get_updated_model(initial_model, make_model, get_loss, n_steps=1, learning_rate=0.01,
                      model_variables=None, **kwargs):
    """
    Performs in-graph SGD steps on the model.
    :param initial_model: initial tfnn model (a sack of variables)
    :param make_model: a function with no inputs that creates a new model;
        the new model should "live" in the same variable scope as initial_model
    :param get_loss: a function(model) -> scalar loss to be optimized
    :param n_steps: number of gradient descent steps
    :param learning_rate: sgd learning rate
    :param model_variables: a list of model variables; defaults to all trainable variables in the model scope
    """
    model = initial_model
    assert hasattr(model, 'scope') or model_variables is not None, \
        "model has no .scope, please add it or specify model_variables=[list_of_trainable_weights]"
    model_variables = model_variables or tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                                           scope=model.scope.name)
    assert isinstance(model_variables, (list, tuple))
    variable_names = [canonicalize(var.name) for var in model_variables]

    for step in range(n_steps):
        grads = tf.gradients(get_loss(model), model_variables)
        # Perform an SGD update. Note: if you use an adaptive optimizer (e.g. Adam), implement it HERE
        updated_variables = {
            name: (var - learning_rate * grad) if grad is not None else var
            for name, var, grad in zip(variable_names, model_variables, grads)
        }
        with replace_variables(updated_variables, strict=True, **kwargs):
            model = make_model()
        model_variables = [updated_variables[name] for name in variable_names]
    return model
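

# Usage sketch, not part of the original gist: a minimal, hypothetical example of fitting
# MinFoxSolver on random data. The shapes, hyperparameters and names below (A_demo, B_demo,
# n_components=2, etc.) are assumptions chosen purely for illustration, and running it requires
# the tfnn dependency imported at the top of this file.
if __name__ == '__main__':
    rng = np.random.RandomState(0)
    A_demo = rng.randn(1024, 8).astype('float32')  # matrix A with p = 8 features
    B_demo = rng.randn(1024, 8).astype('float32')  # matrix B with the same dimensionality

    solver = MinFoxSolver(n_components=2, p=8)
    solver.fit(A_demo, B_demo, max_iters=500, batch_size=256, verbose=True)

    components = solver.predict(A=A_demo)  # [1024, 2] components of f(A), hardest-to-predict first
    print("per-component loss (high to low):", solver.loss_per_component_ordered)
    print("components shape:", components.shape)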