@redwrasse
Last active February 21, 2020 19:38
import torch
import itertools
"""
an example training an undirected graphical model
over a small number of binary variables, brute force
computing the partition function.
The model has a pairwise energy function E = w_ij x_i x_j + a_i x_i
Three examples:
i) compute gradients in pytorch
-grad (log p) = - grad( log Z - sum_i(E_i))
ii) compute gradients analytically through moments
-grad_w_ij (log p) = <x_i x_j>_data - <x_i x_j>_model
-grad_a_i (log p) = <x_i>_data - <x_i>_model
iii) imposing L^2 lagrangian constraints on the loss expressing
a particular graph topology, for example a linear chain
loss = -log p + lambda * sum_{(i,j)} w_ij ^ 2
"""
DIM = 3 # number of binary variables
NUMBER_OF_SAMPLES = 1000
def partition_function(w, a):
    z = torch.tensor(0.)
    d = w.size()[0]
    lists = [[-1., 1.] for _ in range(d)]
    for element in itertools.product(*lists):
        x = torch.tensor(element)
        term = torch.exp(-energy(x, w, a))
        z += term
    return z
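# A minimal sanity-check sketch (illustrative helper, not used in training):
# with the brute-force partition function, the probabilities exp(-E(x)) / Z
# over all 2^DIM configurations should sum to 1.
def check_normalization(w, a):
    z = partition_function(w, a)
    total = torch.tensor(0.)
    for element in itertools.product(*[[-1., 1.] for _ in range(DIM)]):
        x = torch.tensor(element)
        total += torch.exp(-energy(x, w, a)) / z
    return total  # expected to be ~1.0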
# <x_i>_data
def data_first_moment(x):
    n, d = x.size()
    qe = torch.zeros((d,))
    for element in x:
        for i in range(d):
            qe[i] += element[i]
    return qe / n
# <x_i>_model
def model_first_moment(w, a):
    z = partition_function(w, a)
    d = w.size()[0]
    qe = torch.zeros((d,))
    lists = [[-1., 1.] for _ in range(d)]
    for element in itertools.product(*lists):
        x = torch.tensor(element)
        exp_term = torch.exp(-energy(x, w, a))
        for i in range(d):
            qe[i] += x[i] * exp_term
    qe /= z
    return qe
# <x_i x_j>_data
def data_second_moment(x):
    n, d = x.size()
    qe = torch.zeros((d, d))
    for element in x:
        for i in range(d):
            for j in range(d):
                qe[i, j] += element[i] * element[j]
    return qe / n
# <x_i x_j>_model
def model_second_moment(w, a):
    z = partition_function(w, a)
    d = w.size()[0]
    qe = torch.zeros((d, d))
    lists = [[-1., 1.] for _ in range(d)]
    for element in itertools.product(*lists):
        x = torch.tensor(element)
        exp_term = torch.exp(-energy(x, w, a))
        for i in range(d):
            for j in range(d):
                qe[i, j] += x[i] * x[j] * exp_term
    qe /= z
    return qe
def negative_log_prob(x, w, a):
    return torch.log(partition_function(w, a)) \
        + torch.mean(torch.stack([energy(x[i], w, a) for i in range(x.size()[0])]), dim=0)
def energy(x, w, a):
    return torch.dot(x, torch.matmul(w, x)) + torch.dot(a, x)
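# A small usage sketch of the energy (values illustrative): since x_i^2 = 1,
# the diagonal terms w_ii add the same constant to every configuration, e.g.
#   x_ex = torch.tensor([1., -1., 1.])
#   w_ex = torch.zeros((DIM, DIM))
#   a_ex = torch.ones(DIM)
#   energy(x_ex, w_ex, a_ex)  # = a . x = 1 - 1 + 1 = 1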
# <x_i x_j>_data - <x_i x_j>_model
def w_grad(x, w, a):
    return data_second_moment(x) - model_second_moment(w, a)
# <x_i>_data - <x_i>_model
def a_grad(x, w, a):
    return data_first_moment(x) - model_first_moment(w, a)
def make_sample_data():
    """
    95% deterministic + 5% random data in {-1, 1}^DIM
    """
    k = int(NUMBER_OF_SAMPLES * 0.95)
    deterministic = torch.FloatTensor(1, DIM).random_(2) * 2 - 1
    noise = torch.FloatTensor(NUMBER_OF_SAMPLES - k, DIM).random_(2) * 2 - 1
    z = torch.squeeze(torch.stack([deterministic] * k))
    return torch.cat((z, noise), 0)
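# Rough expected behaviour of the sample data: x0 has shape
# (NUMBER_OF_SAMPLES, DIM); 95% of rows repeat one fixed pattern s, the rest
# are uniform noise, so approximately
#   x0 = make_sample_data()
#   data_first_moment(x0)   # roughly 0.95 * s per coordinate
#   data_second_moment(x0)  # diagonal entries exactly 1, since x_i^2 = 1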
def train_with_analytic_grad():
    eps = 0.001
    n_epochs = 10**5
    batch_size = 50
    x0 = make_sample_data()
    # no autograd needed here; the gradients come from the moment formulas
    w = torch.randn((DIM, DIM))
    a = torch.randn(DIM)
    for i in range(n_epochs):
        perm = torch.randperm(NUMBER_OF_SAMPLES)
        for j in range(0, NUMBER_OF_SAMPLES, batch_size):
            indices = perm[j:j + batch_size]
            batch_x = x0[indices]
            neg_lp = negative_log_prob(batch_x, w, a)
            # evaluate both gradients at the current (w, a) before updating
            dw = w_grad(batch_x, w, a)
            da = a_grad(batch_x, w, a)
            w = w - eps * dw
            a = a - eps * da
        if i % 2 == 0:
            print(f'negative log prob: {neg_lp}')
def train_with_grad_implemented():
    eps = 0.001
    n_epochs = 10**5
    batch_size = 50
    x0 = make_sample_data()
    w = torch.randn((DIM, DIM), requires_grad=True)
    a = torch.randn(DIM, requires_grad=True)
    params = [w, a]
    for i in range(n_epochs):
        perm = torch.randperm(NUMBER_OF_SAMPLES)
        for j in range(0, NUMBER_OF_SAMPLES, batch_size):
            indices = perm[j:j + batch_size]
            batch_x = x0[indices]
            neg_lp = negative_log_prob(batch_x, w, a)
            for p in params:
                if p.grad is not None:
                    p.grad.zero_()
            neg_lp.backward()
            # update in place so w and a remain leaf tensors and params stays valid
            with torch.no_grad():
                w -= eps * w.grad
                a -= eps * a.grad
        if i % 2 == 0:
            print(f'negative log prob: {neg_lp}')
def linear_chain_constraint_indices():
    # excluded (penalized) indices for a linear chain, returned as index
    # tuples (x_indices, y_indices) for advanced indexing w[constraint_indices]
    pairs = [(i, j) for i in range(DIM) for j in range(DIM)
             if not (i == j + 1 or i == j - 1)]
    return (tuple(p[0] for p in pairs), tuple(p[1] for p in pairs))
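# For DIM = 3 the penalized pairs are every (i, j) not adjacent on the chain,
# including the diagonal:
#   linear_chain_constraint_indices()
#   # -> ((0, 0, 1, 2, 2), (0, 2, 1, 0, 2))
# i.e. pairs (0,0), (0,2), (1,1), (2,0), (2,2), leaving only
# w_01, w_10, w_12, w_21 unpenalized.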
def train_with_constraints(constraint_indices):
    eps = 0.001
    n_epochs = 10**5
    batch_size = 50
    x0 = make_sample_data()
    w = torch.randn((DIM, DIM), requires_grad=True)
    a = torch.randn(DIM, requires_grad=True)
    lambda_multiplier = 3.
    params = [w, a]
    for i in range(n_epochs):
        perm = torch.randperm(NUMBER_OF_SAMPLES)
        for j in range(0, NUMBER_OF_SAMPLES, batch_size):
            indices = perm[j:j + batch_size]
            batch_x = x0[indices]
            # L^2 penalty on the weights excluded by the target topology
            lagrangian_constraint = lambda_multiplier * \
                torch.sum(w[constraint_indices] ** 2)
            neg_lp = negative_log_prob(batch_x, w, a) + lagrangian_constraint
            for p in params:
                if p.grad is not None:
                    p.grad.zero_()
            neg_lp.backward()
            # update in place so w and a remain leaf tensors and params stays valid
            with torch.no_grad():
                w -= eps * w.grad
                a -= eps * a.grad
        if i % 20 == 0:
            print(f'negative log prob: {neg_lp}')
if __name__ == "__main__":
#train_with_grad_implemented()
#train_with_analytic_grad()
train_with_constraints(linear_chain_constraint_indices())