Wrap PyTorch objective functions so they can be optimised with scipy's optimize.minimize: https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.minimize.html. (I also made a repo to do this, https://github.com/gngdb/pytorch-minimize, although I had forgotten about this gist at the time.)
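For context, scipy.optimize.minimize operates on a flat NumPy vector: it needs a callable fun(x) returning a scalar and, ideally, jac(x) returning the gradient as a vector of the same shape as x. A minimal standalone illustration of that interface (this toy quadratic is my own example, not part of the gist):

import numpy as np
from scipy import optimize

# minimize f(x) = ||x - 3||^2, whose gradient is 2 * (x - 3)
fun = lambda x: np.sum((x - 3.0) ** 2)
jac = lambda x: 2.0 * (x - 3.0)

res = optimize.minimize(fun, x0=np.zeros(5), method='BFGS', jac=jac)
print(res.x)  # approximately [3. 3. 3. 3. 3.]

The wrapper below exposes a PyTorch module's parameters and loss through exactly this interface.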
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
from scipy import optimize
from obj import PyTorchObjective

from tqdm import tqdm

if __name__ == '__main__':
    # whatever this initialises to is our "true" W
    linear = nn.Linear(32, 32)
    linear = linear.eval()

    # input X
    N = 10000
    X = torch.Tensor(N, 32)
    X.uniform_(0., 1.)  # fill with uniform
    eps = torch.Tensor(N, 32)
    eps.normal_(0., 1e-4)

    # output Y
    with torch.no_grad():
        Y = linear(X)  # + eps

    # make module executing the experiment
    class Objective(nn.Module):
        def __init__(self):
            super(Objective, self).__init__()
            self.linear = nn.Linear(32, 32)
            self.linear = self.linear.train()
            self.X, self.Y = X, Y

        def forward(self):
            output = self.linear(self.X)
            return F.mse_loss(output, self.Y).mean()

    objective = Objective()

    maxiter = 100
    with tqdm(total=maxiter) as pbar:
        def verbose(xk):
            pbar.update(1)
        # try to optimize that function with scipy
        obj = PyTorchObjective(objective)
        xL = optimize.minimize(obj.fun, obj.x0, method='BFGS', jac=obj.jac,
                               callback=verbose,
                               options={'gtol': 1e-6, 'disp': True,
                                        'maxiter': maxiter})
        #xL = optimize.minimize(obj.fun, obj.x0, method='CG', jac=obj.jac)  # , options={'gtol': 1e-2})
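optimize.minimize returns its result in xL, so the fitted flat vector xL.x can be loaded back into the module using the wrapper's own unpack_parameters. The gist stops at the call above, but a sketch of that final step might look like this (the printout is just an illustration):

fitted = obj.unpack_parameters(xL.x)
objective.load_state_dict(fitted)
with torch.no_grad():
    print('final MSE:', objective().item())  # loss at the solution scipy returned

The PyTorchObjective wrapper imported above as obj is defined in the second file (presumably obj.py):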
import torch
from scipy import optimize
import torch.nn.functional as F
import math
import numpy as np
from functools import reduce
from collections import OrderedDict


class PyTorchObjective(object):
    """PyTorch objective function, wrapped to be called by scipy.optimize."""

    def __init__(self, obj_module):
        self.f = obj_module  # some pytorch module that produces a scalar loss
        # make an x0 from the parameters in this module
        parameters = OrderedDict(obj_module.named_parameters())
        self.param_shapes = {n: parameters[n].size() for n in parameters}
        # ravel and concatenate all parameters to make x0
        self.x0 = np.concatenate([parameters[n].data.numpy().ravel()
                                  for n in parameters])

    def unpack_parameters(self, x):
        """optimize.minimize supplies a 1D array; chop it up into one tensor per parameter."""
        i = 0
        named_parameters = OrderedDict()
        for n in self.param_shapes:
            param_len = reduce(lambda x, y: x * y, self.param_shapes[n])
            # slice out a section of this length
            param = x[i:i + param_len]
            # reshape according to this size, and cast to torch
            param = param.reshape(*self.param_shapes[n])
            named_parameters[n] = torch.from_numpy(param)
            # update index
            i += param_len
        return named_parameters

    def pack_grads(self):
        """Pack all the gradients from the parameters in the module into a numpy array."""
        grads = []
        for p in self.f.parameters():
            grad = p.grad.data.numpy()
            grads.append(grad.ravel())
        return np.concatenate(grads)

    def is_new(self, x):
        # if this is the first thing we've seen
        if not hasattr(self, 'cached_x'):
            return True
        else:
            # compare x to cached_x to determine if we've been given a new input
            x, self.cached_x = np.array(x), np.array(self.cached_x)
            error = np.abs(x - self.cached_x)
            return error.max() > 1e-8

    def cache(self, x):
        # unpack x and load into module
        state_dict = self.unpack_parameters(x)
        self.f.load_state_dict(state_dict)
        # store the raw array as well
        self.cached_x = x
        # zero the gradient
        self.f.zero_grad()
        # use it to calculate the objective
        obj = self.f()
        # backprop the objective
        obj.backward()
        self.cached_f = obj.item()
        self.cached_jac = self.pack_grads()

    def fun(self, x):
        # scalar objective for scipy; shares a cached forward/backward pass with jac
        if self.is_new(x):
            self.cache(x)
        return self.cached_f

    def jac(self, x):
        # gradient for scipy, taken from the same cached backward pass as fun
        if self.is_new(x):
            self.cache(x)
        return self.cached_jac
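A quick way to sanity-check that fun and jac stay consistent is scipy's finite-difference checker, scipy.optimize.check_grad. A sketch under my own assumptions (a tiny double-precision module so the finite-difference estimate is cheap and accurate; none of this is in the gist):

import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.optimize import check_grad

class TinyObjective(nn.Module):
    # small float64 problem: finite differences are both fast and accurate here
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4).double()
        self.x = torch.randn(8, 4, dtype=torch.float64)
        self.y = torch.randn(8, 4, dtype=torch.float64)

    def forward(self):
        return F.mse_loss(self.linear(self.x), self.y)

obj = PyTorchObjective(TinyObjective())
# check_grad returns the norm of the difference between jac and a
# finite-difference estimate of the gradient; it should be tiny
print(check_grad(obj.fun, obj.jac, obj.x0))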