Created
September 10, 2019 14:06
-
-
Save vene/53f242c058d71aed9d9b38c67596046c to your computer and use it in GitHub Desktop.
Generate samples from a discrete latent variable classification model
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """ | |
| Generate data triples (x, y, z) for deterministic classification p(y | x; z) | |
| Generative story: | |
| Given: n_clusters; for each cluster: | |
| - a cluster center (mean) center[z] | |
| - a linear model y=sign(w[z] * x + b[z]) | |
| pick z from uniform Categorical(n_clusters) | |
| pick cluster center c = center[z] | |
| sample x from N(mu[z], sigma * I) | |
| generate deterministic y = sign(w[z] * x + b[z]) | |
| """ | |
| # author: vlad niculae <vlad@vene.ro> | |
| # license: bsd 2 clause | |
| import torch | |
| def make_latent_triples(n_samples, n_features, n_clusters, data_std=.1, | |
| cluster_std=1): | |
| # generate cluster centers | |
| centers = cluster_std * torch.randn(n_clusters, n_features) | |
| # generate a linear model for each cluster | |
| W = torch.randn(n_clusters, n_features) | |
| # draw cluster assignments | |
| z = torch.randint(low=0, high=n_clusters, size=(n_samples,)) | |
| # draw data X | |
| c_ = centers[z] | |
| X = c_ + data_std * torch.randn(n_samples, n_features) | |
| # choose linear model to use for each sample | |
| W_ = W[z] | |
| # compute true label y | |
| y_score = (W_ * X).sum(dim=-1) | |
| # pick a threshold for each class | |
| # (note: this is done like this to ensure there are always roughly balanced | |
| # positive and negative samples in each class) | |
| b = torch.zeros(n_clusters) | |
| for c in range(n_clusters): | |
| b[c] = y_score[z == c].mean() | |
| y = torch.sign(y_score - b[z]) | |
| return X, y, z | |
| def main(): | |
| torch.manual_seed(41) | |
| n_samples = 100 | |
| n_features = 2 | |
| n_clusters = 4 | |
| X, y, z = make_latent_triples(n_samples, n_features, n_clusters) | |
| import matplotlib.pyplot as plt | |
| Xp, zp = X[y > 0], z[y > 0] | |
| Xn, zn = X[y < 0], z[y < 0] | |
| print(zp) | |
| print(zn) | |
| plt.scatter(Xp[:, 0], Xp[:, 1], c=zp, marker='+') | |
| plt.scatter(Xn[:, 0], Xn[:, 1], c=zn, marker='.') | |
| plt.show() | |
| if __name__ == '__main__': | |
| main() | |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment