Skip to content

Instantly share code, notes, and snippets.

@vene
Created October 19, 2020 12:22
Show Gist options
  • Select an option

  • Save vene/5fb0b44166001c8e8a8cc575a8949754 to your computer and use it in GitHub Desktop.

Select an option

Save vene/5fb0b44166001c8e8a8cc575a8949754 to your computer and use it in GitHub Desktop.
Energy model with Langevin dynamics
# Density estimation with energy-based models
# Langevin sampling, contrastive divergence training.
# Author: Vlad Niculae <vlad@vene.ro>
# License: MIT
import numpy as np
import torch
from sklearn import datasets
import matplotlib.pyplot as plt
# Evaluation grid used to draw the energy landscape: N x N points covering
# the rectangle [xmin, xmax] x [ymin, ymax].
N = 100
xmin, xmax = -2, 3
ymin, ymax = -2, 2

grid_x = np.linspace(xmin, xmax, num=N)
grid_y = np.linspace(ymin, ymax, num=N)
mesh_x, mesh_y = np.meshgrid(grid_x, grid_y)

# Flatten the mesh into an (N * N, 2) float32 tensor of (x, y) coordinates,
# in the same row-major order as the mesh so results reshape back to (N, N).
grid_pts = torch.from_numpy(
    np.stack((mesh_x.reshape(-1), mesh_y.reshape(-1)), axis=1)
).float()
class EnergyModel(torch.nn.Sequential):
    """Multi-layer perceptron mapping each input point to a scalar energy.

    The network is Linear -> ReLU -> Dropout(0.2), twice, followed by a
    final Linear layer with a single output. Samples are drawn with
    Langevin dynamics and the model is trained with a contrastive
    objective (energy of the data minus energy of the samples).

    Parameters
    ----------
    input_dim : int
        Dimensionality of the input points.
    hid_dim : int
        Width of the two hidden layers.
    """

    def __init__(self, input_dim=2, hid_dim=64):
        super().__init__(
            torch.nn.Linear(input_dim, hid_dim),
            torch.nn.ReLU(),
            torch.nn.Dropout(.2),
            torch.nn.Linear(hid_dim, hid_dim),
            torch.nn.ReLU(),
            torch.nn.Dropout(.2),
            torch.nn.Linear(hid_dim, 1))
        self.input_dim = input_dim
        self.hid_dim = hid_dim

    def plot(self, ax):
        """Draw filled contours of the negative energy on ``ax``.

        Evaluates the model on the module-level ``grid_pts`` tensor and
        reshapes to the ``N`` x ``N`` mesh (``mesh_x``, ``mesh_y``). The
        min and max of the plotted values are printed. The module's
        train/eval mode is restored before returning.
        """
        was_training = self.training
        self.eval()  # disable dropout so the plotted surface is deterministic
        try:
            # No autograd graph is needed just to draw the surface.
            with torch.no_grad():
                # Negative energy is plotted; exp(-energy) would be the
                # unnormalised density instead.
                neg_energy = -self(grid_pts)
        finally:
            self.train(was_training)

        Z = neg_energy.reshape(N, N).numpy()
        print(Z.min(), Z.max())
        ax.contourf(mesh_x, mesh_y, Z, levels=30)

    def langevin_sample(self, seed=None, n_samples=100, n_iter=200):
        """Draw samples by running Langevin dynamics on the energy.

        Parameters
        ----------
        seed : torch.Tensor or None
            Starting points for the chains. When given, the number of
            chains is taken from ``seed`` and ``n_samples`` is ignored.
            When None, chains start from standard normal noise.
        n_samples : int
            Number of chains to start when ``seed`` is None.
        n_iter : int
            Number of Langevin steps.

        Returns
        -------
        torch.Tensor
            Detached tensor of final chain positions, same shape as the
            starting points.

        Notes
        -----
        Parameters are frozen and the module is put in eval mode while
        sampling; both the per-parameter ``requires_grad`` flags and the
        train/eval mode are restored afterwards, even on error.
        """
        if seed is not None:
            # detach() guarantees a leaf tensor, so that .grad is populated
            # even if the caller's seed already tracks gradients.
            X = seed.detach().clone().requires_grad_()
        else:
            X = torch.randn(n_samples, self.input_dim).requires_grad_()

        params = list(self.parameters())
        grad_flags = [p.requires_grad for p in params]
        was_training = self.training

        # Only the gradient w.r.t. the samples is needed here.
        for p in params:
            p.requires_grad_(False)
        self.eval()

        try:
            for t in range(n_iter):
                self(X).sum().backward()

                # Decaying step size; noise std is sqrt(step size).
                eps = .1 * (10 + t) ** -1
                sigma = float(np.sqrt(eps))
                # Noise is shaped like X (not n_samples) so that a seed of
                # any batch size works.
                eta = sigma * torch.randn_like(X)

                with torch.no_grad():
                    X -= (eps / 2) * X.grad - eta
                X.grad.zero_()
        finally:
            # Restore exactly the state found on entry instead of forcing
            # every parameter to be trainable.
            for p, flag in zip(params, grad_flags):
                p.requires_grad_(flag)
            self.train(was_training)

        return X.detach()

    def loss(self, X):
        """Return mean energy of ``X`` minus mean energy of model samples.

        The samples are produced by Langevin chains started at the data
        points themselves. The module is switched to training mode before
        the energies are evaluated.
        """
        X_spl = self.langevin_sample(seed=X, n_samples=X.shape[0])
        self.train()
        return torch.mean(self(X) - self(X_spl))

    def fit(self, X, n_iter=500, callback=None):
        """Train on ``X`` with AdamW for ``n_iter`` full-batch steps.

        Parameters
        ----------
        X : torch.Tensor
            Training points.
        n_iter : int
            Number of optimisation steps.
        callback : callable or None
            If given, called as ``callback(model, t)`` after every step.
        """
        opt = torch.optim.AdamW(lr=0.01, weight_decay=.01,
                                params=self.parameters())
        self.train()
        for t in range(n_iter):
            opt.zero_grad()
            lval = self.loss(X)
            lval.backward()
            opt.step()
            if callback:
                callback(self, t)
def main():
    """Train an EnergyModel on three Gaussian blobs and save progress plots.

    Every 10th training step a figure is written to ``imgs/<step>.png``
    showing the energy surface, the training data and a few fresh Langevin
    samples. The ``imgs`` directory is created if it does not exist.
    """
    # Local import: only this entry point touches the filesystem.
    import os

    out_dir = 'imgs'
    # savefig does not create missing directories.
    os.makedirs(out_dir, exist_ok=True)

    # Alternative toy datasets:
    # X, y = datasets.make_moons(n_samples=500, noise=.05)
    # X, y = datasets.make_circles(n_samples=500, noise=.05, factor=.1)
    X_np, _ = datasets.make_blobs(
        n_samples=500,
        centers=[[-1, -1], [-1, 1], [1, -1]],
        cluster_std=.1)
    X = torch.from_numpy(X_np).to(dtype=torch.float32)

    def callback(em, t):
        """Save a diagnostic figure on every 10th iteration."""
        if t % 10 != 0:
            return

        fig = plt.figure()
        ax = plt.gca()
        em.plot(ax)
        plt.scatter(X_np[:, 0], X_np[:, 1], color='C3', marker='x', alpha=.2,
                    label="data")
        X_samp = em.langevin_sample(n_samples=10, n_iter=200)
        plt.scatter(X_samp[:, 0], X_samp[:, 1], color='C5', marker=".", s=8,
                    label="samples")
        plt.xlim(xmin, xmax)
        plt.ylim(ymin, ymax)
        plt.legend()
        plt.title(f"Iteration {t:03d}")
        plt.savefig(os.path.join(out_dir, f'{t:04d}.png'))
        plt.close(fig)

    em = EnergyModel()
    em.fit(X, callback=callback)
# Run the training demo only when executed as a script, not on import.
if __name__ == "__main__":
    main()
@vene
Copy link
Copy Markdown
Author

vene commented Oct 19, 2020

blobs
moons
rings

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment