Created
October 19, 2020 12:22
-
-
Save vene/5fb0b44166001c8e8a8cc575a8949754 to your computer and use it in GitHub Desktop.
Energy model with Langevin dynamics
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # Density estimation with energy-based models | |
| # Langevin sampling, contrastive divergence training. | |
| # Author: Vlad Niculae <vlad@vene.ro> | |
| # License: MIT | |
| import numpy as np | |
| import torch | |
| from sklearn import datasets | |
| import matplotlib.pyplot as plt | |
# Resolution and extent of the rectangular window used for contour plots.
N = 100
xmin, xmax = -2, 3
ymin, ymax = -2, 2

# Regular N x N lattice over the window. `mesh_x` / `mesh_y` keep the 2-D
# layout expected by contour plotting, while `grid_pts` flattens the same
# lattice into an (N * N, 2) float32 tensor that can be fed to a model.
grid_x = np.linspace(xmin, xmax, N)
grid_y = np.linspace(ymin, ymax, N)
mesh_x, mesh_y = np.meshgrid(grid_x, grid_y)
grid_pts = torch.from_numpy(
    np.stack([mesh_x.reshape(-1), mesh_y.reshape(-1)], axis=1)
).float()
class EnergyModel(torch.nn.Sequential):
    """Small MLP that assigns a scalar energy E(x) to each input point.

    The unnormalized density is taken to be proportional to ``exp(-E(x))``.
    The model is trained with a contrastive-divergence style objective whose
    negative samples come from Langevin dynamics started at the data.

    Parameters
    ----------
    input_dim : int
        Dimensionality of the input points.
    hid_dim : int
        Width of the two hidden layers.
    """

    def __init__(self, input_dim=2, hid_dim=64):
        super().__init__(
            torch.nn.Linear(input_dim, hid_dim),
            torch.nn.ReLU(),
            torch.nn.Dropout(.2),
            torch.nn.Linear(hid_dim, hid_dim),
            torch.nn.ReLU(),
            torch.nn.Dropout(.2),
            torch.nn.Linear(hid_dim, 1),
        )
        # Plain attributes are assigned after the Module machinery is set up.
        self.input_dim = input_dim
        self.hid_dim = hid_dim

    def plot(self, ax):
        """Draw filled contours of the negative energy on ``ax``.

        The model is evaluated on the module-level ``grid_pts`` lattice and
        ``-E(x)`` is plotted (exponentiate it to view the unnormalized
        density instead). The min and max of the plotted values are printed.

        Side effect: the model is switched to eval mode and left there.
        """
        self.eval()
        # Only values are needed, so skip building an autograd graph.
        with torch.no_grad():
            neg_energy = -self(grid_pts)
        Z = neg_energy.reshape(N, N).numpy()
        print(Z.min(), Z.max())
        ax.contourf(mesh_x, mesh_y, Z, levels=30)

    def langevin_sample(self, seed=None, n_samples=100, n_iter=200):
        """Draw approximate samples from the model with Langevin dynamics.

        Each step moves the chains along the negative energy gradient and
        adds Gaussian noise, with step size ``eps_t = 0.1 / (10 + t)``.

        Parameters
        ----------
        seed : torch.Tensor or None
            Starting points of the chains. The tensor is copied and never
            modified. When None, chains start from standard normal noise of
            shape ``(n_samples, input_dim)``.
        n_samples : int
            Number of chains when ``seed`` is None; ignored otherwise (the
            number of chains then follows the shape of ``seed``).
        n_iter : int
            Number of Langevin steps.

        Returns
        -------
        torch.Tensor
            Detached tensor holding the final chain positions, with the same
            shape as the starting points.

        Notes
        -----
        The model is switched to eval mode (disabling dropout) and left
        there. Parameter ``requires_grad`` flags are restored to their prior
        values on exit, even if sampling raises.
        """
        if seed is not None:
            # detach() guarantees a leaf tensor, so that X.grad is populated.
            X = seed.detach().clone().requires_grad_()
        else:
            X = torch.randn(n_samples, self.input_dim).requires_grad_()

        # Freeze the parameters so backward() only accumulates into X.grad,
        # remembering the previous flags so frozen layers stay frozen.
        params = list(self.parameters())
        was_trainable = [p.requires_grad for p in params]
        for p in params:
            p.requires_grad = False
        self.eval()

        try:
            for t in range(n_iter):
                self(X).sum().backward()
                eps = .1 / (10 + t)
                sigma = eps ** 0.5
                with torch.no_grad():
                    # Noise follows the shape of X, so seeded chains work
                    # regardless of the value of n_samples.
                    noise = sigma * torch.randn_like(X)
                    X -= (eps / 2) * X.grad - noise
                X.grad.zero_()
        finally:
            for p, flag in zip(params, was_trainable):
                p.requires_grad = flag

        return X.detach()

    def loss(self, X):
        """Return the contrastive loss ``mean(E(X) - E(X_neg))``.

        ``X_neg`` are Langevin samples from chains started at the data ``X``.
        They are detached, so gradients flow only through the model
        parameters. The model is put into train mode before the energies
        are evaluated.
        """
        X_spl = self.langevin_sample(seed=X, n_samples=X.shape[0])
        self.train()
        return torch.mean(self(X) - self(X_spl))

    def fit(self, X, n_iter=500, callback=None):
        """Train on the full batch ``X`` with AdamW for ``n_iter`` steps.

        Parameters
        ----------
        X : torch.Tensor
            Training points.
        n_iter : int
            Number of optimizer steps.
        callback : callable or None
            If given, called as ``callback(model, t)`` after every step.
        """
        opt = torch.optim.AdamW(self.parameters(), lr=0.01, weight_decay=.01)
        self.train()
        for t in range(n_iter):
            opt.zero_grad()
            lval = self.loss(X)
            lval.backward()
            opt.step()
            if callback is not None:
                callback(self, t)
def main():
    """Fit an energy model to three Gaussian blobs and save snapshots.

    Every 10th training iteration a figure showing the energy contours, the
    training data and 10 fresh Langevin samples is written to
    ``imgs/<iteration>.png``. The ``imgs`` directory is created if missing.
    """
    import os

    out_dir = 'imgs'
    # savefig does not create missing directories, so create it up front.
    os.makedirs(out_dir, exist_ok=True)

    # Alternative toy datasets:
    #   datasets.make_moons(n_samples=500, noise=.05)
    #   datasets.make_circles(n_samples=500, noise=.05, factor=.1)
    X_np, _ = datasets.make_blobs(
        n_samples=500,
        centers=[[-1, -1], [-1, 1], [1, -1]],
        cluster_std=.1,
    )
    # Keep the NumPy array for plotting and a float32 tensor for training.
    X = torch.from_numpy(X_np).to(dtype=torch.float32)

    def callback(em, t):
        # Only snapshot every 10th iteration.
        if t % 10 != 0:
            return
        fig = plt.figure()
        ax = plt.gca()
        em.plot(ax)
        plt.scatter(X_np[:, 0], X_np[:, 1], color='C3', marker='x', alpha=.2,
                    label="data")
        X_samp = em.langevin_sample(n_samples=10, n_iter=200)
        plt.scatter(X_samp[:, 0], X_samp[:, 1], color='C5', marker=".", s=8,
                    label="samples")
        plt.xlim(xmin, xmax)
        plt.ylim(ymin, ymax)
        plt.legend()
        plt.title(f"Iteration {t:03d}")
        plt.savefig(f'{out_dir}/{t:04d}.png')
        plt.close(fig)

    em = EnergyModel()
    em.fit(X, callback=callback)


if __name__ == '__main__':
    main()
Author
vene
commented
Oct 19, 2020



Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment