import torch
import itertools
""" | |
an example training an undirected graphical model | |
over a small number of binary variables, brute force | |
computing the partition function. | |
The model has a pairwise energy function E = w_ij x_i x_j + a_i x_i | |
Three examples: | |
i) compute gradients in pytorch | |
-grad (log p) = - grad( log Z - sum_i(E_i)) | |
ii) compute gradients analytically through moments | |
-grad_w_ij (log p) = <x_i x_j>_data - <x_i x_j>_model | |
-grad_a_i (log p) = <x_i>_data - <x_i>_model | |
iii) imposing L^2 lagrangian constraints on the loss expressing | |
a particular graph topology, for example a linear chain | |
loss = -log p + lambda * sum_{(i,j)} w_ij ^ 2 | |
""" | |
DIM = 3  # number of binary variables
NUMBER_OF_SAMPLES = 1000
def partition_function(w, a):
    """Brute-force partition function Z = sum_x exp(-E(x)) over x in {-1, 1}^d."""
    z = torch.tensor(0.)
    d = w.size()[0]
    lists = [[-1., 1.] for _ in range(d)]
    for element in itertools.product(*lists):
        x = torch.tensor(element)
        term = torch.exp(-energy(x, w, a))
        z += term
    return z
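# A minimal sanity check, not part of the original script (the helper
# name check_normalization is illustrative): the Boltzmann probabilities
# p(x) = exp(-E(x)) / Z should sum to 1 over all 2^d configurations.
def check_normalization(w, a):
    z = partition_function(w, a)
    total = torch.tensor(0.)
    for element in itertools.product(*[[-1., 1.]] * w.size()[0]):
        x = torch.tensor(element)
        total += torch.exp(-energy(x, w, a)) / z
    assert torch.allclose(total, torch.tensor(1.))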
# <x_i>_data
def data_first_moment(x):
    n, d = x.size()
    qe = torch.zeros((d,))
    for element in x:
        for i in range(d):
            qe[i] += element[i]
    return qe / n
# <x_i>_model
def model_first_moment(w, a):
    z = partition_function(w, a)
    d = w.size()[0]
    qe = torch.zeros((d,))
    lists = [[-1., 1.] for _ in range(d)]
    for element in itertools.product(*lists):
        x = torch.tensor(element)
        exp_term = torch.exp(-energy(x, w, a))
        for i in range(d):
            qe[i] += x[i] * exp_term
    qe /= z
    return qe
# <x_i x_j>_data
def data_second_moment(x):
    n, d = x.size()
    qe = torch.zeros((d, d))
    for element in x:
        for i in range(d):
            for j in range(d):
                qe[i, j] += element[i] * element[j]
    return qe / n
# <x_i x_j>_model
def model_second_moment(w, a):
    z = partition_function(w, a)
    d = w.size()[0]
    qe = torch.zeros((d, d))
    lists = [[-1., 1.] for _ in range(d)]
    for element in itertools.product(*lists):
        x = torch.tensor(element)
        exp_term = torch.exp(-energy(x, w, a))
        for i in range(d):
            for j in range(d):
                qe[i, j] += x[i] * x[j] * exp_term
    qe /= z
    return qe
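# A minimal cross-check, not in the original (the helper name and
# tolerance are assumptions): autograd on log Z should reproduce the
# moment identities grad_w log Z = -<x_i x_j>_model and
# grad_a log Z = -<x_i>_model, since E enters p(x) as exp(-E).
def check_model_moments():
    w = torch.randn((DIM, DIM), requires_grad=True)
    a = torch.randn(DIM, requires_grad=True)
    torch.log(partition_function(w, a)).backward()
    with torch.no_grad():
        assert torch.allclose(w.grad, -model_second_moment(w, a), atol=1e-5)
        assert torch.allclose(a.grad, -model_first_moment(w, a), atol=1e-5)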
def negative_log_prob(x, w, a):
    """-log p averaged over the batch: log Z + mean_n E(x_n)."""
    return torch.log(partition_function(w, a)) \
        + torch.mean(torch.stack([energy(x[i], w, a) for i in range(x.size()[0])]), dim=0)
def energy(x, w, a):
    """Pairwise energy E(x) = x^T w x + a . x."""
    return torch.dot(x, torch.matmul(w, x)) + torch.dot(a, x)
# <x_i x_j>_data - <x_i x_j>_model
def w_grad(x, w, a):
    return data_second_moment(x) - model_second_moment(w, a)
# <x_i>_data - <x_i>_model
def a_grad(x, w, a):
    return data_first_moment(x) - model_first_moment(w, a)
def make_sample_data():
    """
    95% deterministic + 5% random data in {-1, 1}^DIM
    """
    k = int(NUMBER_OF_SAMPLES * 0.95)
    deterministic = torch.FloatTensor(1, DIM).random_(2) * 2 - 1
    noise = torch.FloatTensor(NUMBER_OF_SAMPLES - k, DIM).random_(2) * 2 - 1
    z = torch.squeeze(torch.stack([deterministic] * k))
    return torch.cat((z, noise), 0)
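# A minimal cross-check of examples (i) and (ii), not in the original
# (helper name and tolerance are assumptions): the analytic moment
# gradients should match autograd gradients of the negative log prob.
def check_analytic_grads():
    x = make_sample_data()
    w = torch.randn((DIM, DIM), requires_grad=True)
    a = torch.randn(DIM, requires_grad=True)
    negative_log_prob(x, w, a).backward()
    with torch.no_grad():
        assert torch.allclose(w.grad, w_grad(x, w, a), atol=1e-4)
        assert torch.allclose(a.grad, a_grad(x, w, a), atol=1e-4)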
def train_with_analytic_grad():
    eps = 0.001
    n_epochs = 10**5
    batch_size = 50
    x0 = make_sample_data()
    # no autograd needed here; the gradients are computed analytically
    w = torch.randn((DIM, DIM))
    a = torch.randn(DIM)
    for i in range(n_epochs):
        perm = torch.randperm(NUMBER_OF_SAMPLES)
        for j in range(0, NUMBER_OF_SAMPLES, batch_size):
            indices = perm[j:j + batch_size]
            batch_x = x0[indices]
            neg_lp = negative_log_prob(batch_x, w, a)
            # evaluate both gradients at the current (w, a) before updating
            grad_w = w_grad(batch_x, w, a)
            grad_a = a_grad(batch_x, w, a)
            w = w - eps * grad_w
            a = a - eps * grad_a
        if i % 2 == 0:
            print(f'negative log prob: {neg_lp}')
def train_with_grad_implemented():
    eps = 0.001
    n_epochs = 10**5
    batch_size = 50
    x0 = make_sample_data()
    w = torch.randn((DIM, DIM), requires_grad=True)
    a = torch.randn(DIM, requires_grad=True)
    params = [w, a]
    for i in range(n_epochs):
        perm = torch.randperm(NUMBER_OF_SAMPLES)
        for j in range(0, NUMBER_OF_SAMPLES, batch_size):
            indices = perm[j:j + batch_size]
            batch_x = x0[indices]
            neg_lp = negative_log_prob(batch_x, w, a)
            for p in params:
                if p.grad is not None:
                    p.grad.zero_()
            neg_lp.backward()
            # update in place so w and a remain leaf tensors with
            # requires_grad=True (rebinding them would detach params)
            with torch.no_grad():
                w -= eps * w.grad
                a -= eps * a.grad
        if i % 2 == 0:
            print(f'negative log prob: {neg_lp}')
def linear_chain_constraint_indices():
    # indices excluded from a linear chain, i.e. every (i, j) that is
    # not a nearest-neighbour pair (the diagonal is excluded as well);
    # returned as [(x_indices), (y_indices)] for advanced indexing into w
    pairs = [(i, j) for i in range(DIM) for j in range(DIM)
             if not (i == j + 1 or i == j - 1)]
    return [tuple(p[0] for p in pairs), tuple(p[1] for p in pairs)]
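# A quick demo of the exclusion set, not in the original (the helper
# name is an assumption). For DIM = 3 only the nearest-neighbour
# couplings (0, 1), (1, 0), (1, 2), (2, 1) escape the penalty.
def check_chain_indices():
    xs, ys = linear_chain_constraint_indices()
    assert set(zip(xs, ys)) == {(0, 0), (0, 2), (1, 1), (2, 0), (2, 2)}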
def train_with_constraints(constraint_indices):
    eps = 0.001
    n_epochs = 10**5
    batch_size = 50
    x0 = make_sample_data()
    w = torch.randn((DIM, DIM), requires_grad=True)
    a = torch.randn(DIM, requires_grad=True)
    lambda_multiplier = 3.
    params = [w, a]
    for i in range(n_epochs):
        perm = torch.randperm(NUMBER_OF_SAMPLES)
        for j in range(0, NUMBER_OF_SAMPLES, batch_size):
            indices = perm[j:j + batch_size]
            batch_x = x0[indices]
            # L^2 penalty pushing the excluded couplings toward zero
            lagrangian_constraint = lambda_multiplier * \
                torch.sum(w[constraint_indices] ** 2)
            neg_lp = negative_log_prob(batch_x, w, a) + lagrangian_constraint
            for p in params:
                if p.grad is not None:
                    p.grad.zero_()
            neg_lp.backward()
            with torch.no_grad():
                w -= eps * w.grad
                a -= eps * a.grad
        if i % 20 == 0:
            print(f'negative log prob: {neg_lp}')
if __name__ == "__main__": | |
#train_with_grad_implemented() | |
#train_with_analytic_grad() | |
train_with_constraints(linear_chain_constraint_indices()) |