Layer-wise SPN Example Usage: Forward, Sampling, Conditional Sampling
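The script below builds a minimal layer-wise SPN out of three layers (a Gaussian leaf layer, a product layer, and a single root sum node), fits it to two Gaussian blobs by minimizing the negative mean log-likelihood with SGD, and then draws both unconditional samples and samples conditioned on one of the two features via provide_evidence, where missing values are marked as NaN. The resulting figure is saved to sampling-result.png.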
import matplotlib.pyplot as plt
import numpy as np
import torch
from sklearn.datasets import make_blobs
from torch import nn
from tqdm import trange

from spn.algorithms.layerwise.layers import Product, Sum
from spn.algorithms.layerwise.utils import provide_evidence
from spn.experiments.RandomSPNs_layerwise.distributions import RatNormal
if __name__ == "__main__":

    class LayerSpn(nn.Module):
        def __init__(self):
            super().__init__()
            # Normal leaf layer, output shape: [N=?, D=2, C=2, R=1]
            self.leaf = RatNormal(in_features=2, out_channels=2)
            # Product layer, output shape: [N=?, D=1, C=2, R=1]
            self.p = Product(in_features=2, cardinality=2)
            # Sum layer, root node, output shape: [N=?, D=1, C=1, R=1]
            self.s = Sum(in_channels=2, in_features=1, out_channels=1)

        def forward(self, x):
            # Forward pass, bottom-up: leaf -> product -> sum
            x = self.leaf(x)
            x = self.p(x)
            x = self.s(x)
            return x

        def sample(self, n=100):
            # Sample top-down: root sum -> product -> leaves
            ctx = self.s.sample(n=n)
            ctx = self.p.sample(context=ctx)
            samples = self.leaf.sample(context=ctx)
            return samples
    # Generate two Gaussian blobs
    n_labels = 2
    n_samples = 500
    data, y = make_blobs(
        n_samples=n_samples, centers=n_labels, n_features=2, random_state=0, center_box=(-15, 15), cluster_std=0.5
    )
    data = torch.from_numpy(data).float()

    # Plot the original data
    plt.figure()
    plt.subplot(2, 2, 1)
    plt.title("Original training data")
    for i in range(n_labels):
        plt.scatter(*data[y == i].T, label=f"Blob {i}", alpha=0.7)
    plt.legend()
    plt.xlabel("$x_0$")
    plt.ylabel("$x_1$")
    xlim, ylim = plt.xlim(), plt.ylim()

    # Create SPN model
    spn = LayerSpn()

    # Use SGD
    optimizer = torch.optim.SGD(spn.parameters(), lr=0.5, weight_decay=0.0)

    batch_size = 250
    n_epochs = 1000
    with trange(n_epochs) as epoch_iter:
        for epoch in epoch_iter:
            running_loss = 0.0
            for batch_idx in np.arange(data.shape[0], step=batch_size):
                batch = data[batch_idx : batch_idx + batch_size]

                # Reset gradients
                optimizer.zero_grad()

                # Inference
                output = spn(batch)

                # Compute loss: negative mean log-likelihood
                loss = -1 * output.mean()

                # Backprop
                loss.backward()
                optimizer.step()

                # Collect loss
                running_loss += loss.item()

            epoch_iter.set_description(f"Loss: {running_loss / (data.shape[0] // batch_size):<3.4f}")
    with torch.no_grad():
        # Sample unconditionally
        samples = spn.sample(n=1000)
        plt.subplot(2, 2, 2)
        plt.xlabel("$x_0$")
        plt.ylabel("$x_1$")
        plt.title("Unconditioned Samples")
        plt.scatter(*samples.T, alpha=0.7, c="black")
        plt.xlim(xlim)
        plt.ylim(ylim)

        # Sample, conditioned on x_0 (x_1 is marked as missing with NaN)
        plt.subplot(2, 2, 3)
        plt.xlabel("$x_0$")
        plt.ylabel("$x_1$")
        plt.title("Conditioned Samples (on $x_0$)")
        for i in range(n_labels):
            data_i = data[y == i]
            data_i[:, 1] = float("nan")
            with provide_evidence(spn, data_i):
                samples = spn.sample(n=data_i.shape[0])
            plt.scatter(*samples.T, label=f"Blob {i}", alpha=0.7)
        plt.xlim(xlim)
        plt.ylim(ylim)
        plt.legend()

        # Sample, conditioned on x_1 (x_0 is marked as missing with NaN)
        plt.subplot(2, 2, 4)
        plt.xlabel("$x_0$")
        plt.ylabel("$x_1$")
        plt.title("Conditioned Samples (on $x_1$)")
        for i in range(n_labels):
            data_i = data[y == i]
            data_i[:, 0] = float("nan")
            with provide_evidence(spn, data_i):
                samples = spn.sample(n=data_i.shape[0])
            plt.scatter(*samples.T, label=f"Blob {i}", alpha=0.7)
        plt.xlim(xlim)
        plt.ylim(ylim)
        plt.legend()

    plt.tight_layout()
    plt.savefig("sampling-result.png", dpi=180)
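Not part of the original gist, but a small follow-up sketch: because forward() returns per-sample log-likelihoods (the same quantity the training loss negates and averages), the trained model can score data with a plain forward call. The snippet below assumes the spn, data, y, and n_labels variables from the script above and only reports mean log-likelihoods; a proper evaluation would use a held-out split rather than the training blobs.

# Sketch: score data under the trained SPN (reuses spn, data, y, n_labels from above).
with torch.no_grad():
    # Forward pass yields log-likelihoods; averaging mirrors the (negated) training loss.
    mean_ll = spn(data).mean().item()
    print(f"Mean log-likelihood on the training blobs: {mean_ll:.4f}")

    # Per-blob scores, e.g. to check that both modes are covered by the model.
    for i in range(n_labels):
        blob_ll = spn(data[y == i]).mean().item()
        print(f"Blob {i}: mean log-likelihood {blob_ll:.4f}")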