#%%
import math
import matplotlib.pyplot as plt
import torch
import torch.nn as nn

plt.rcParams['axes.spines.left'] = False
plt.rcParams['axes.spines.right'] = False
plt.rcParams['axes.spines.top'] = False
plt.rcParams['axes.spines.bottom'] = False
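
# Toy training data: a 1D mixture of four Gaussian clusters centred at
# -0.75, -0.25, 0.25 and 0.75 with standard deviation 0.05 (4000 points total).
# `labels` records one id per cluster and is not used further below.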
torch.manual_seed(3407)
dataset = []
labels = []
plt.figure(figsize=(12, 3))
for label, mean in enumerate([-0.75, -0.25, 0.25, 0.75]):
    mu = mean + torch.randn(1000) * 0.05
    plt.hist(mu, bins=100, alpha=0.5)
    dataset.append(mu)
    labels.append(label)
dataset = torch.cat(dataset, dim=0).unsqueeze(-1)
labels = torch.tensor(labels)
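
# SIREN-style MLP with sinusoidal activations (in the spirit of Sitzmann et al. 2020):
# it takes the noisy running mean mu concatenated with the time t and its output is
# used as the noise estimate in Sampler.step below.  The first layer is initialised
# uniformly in [-1/2, 1/2]; the other layers use the sqrt(6/dim)/bandwidth scaling.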
class Siren(nn.Module):
    def __init__(self, channels=1, dim=512, bandwidth=20):
        super().__init__()
        self.channels = channels
        self.input = nn.Linear(channels + 1, dim, bias=False)
        self.hidden = nn.Linear(dim, dim, bias=False)
        self.output = nn.Linear(dim, channels, bias=False)
        self.bandwidth = bandwidth
        with torch.no_grad():
            self.input.weight.uniform_(-1 / 2, 1 / 2)
            l = (6 / dim) ** 0.5 / bandwidth
            self.hidden.weight.uniform_(-l, l)
            self.output.weight.uniform_(-l, l)

    def forward(self, mu, t):
        x = self.input(torch.cat([mu, t], dim=-1))
        x = (self.bandwidth * x).sin()
        x = self.hidden(x)
        x = (self.bandwidth * x).sin()
        x = self.output(x)
        # x.register_hook(lambda grad: print('output grad', grad.norm()))
        return x
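
# Sampler: a continuous-data generative model in the style of Bayesian Flow Networks
# (Graves et al. 2023).  The noise schedule is gamma(t) = 1 - sigma**(2*t); training
# regresses a clipped estimate of the clean data from a noised running mean, and
# generation runs `unroll_steps` Bayesian updates of a Gaussian posterior over x.
# The correspondence to the paper is approximate, not exact.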
class Sampler(nn.Module):
    def __init__(self, channels=1, sigma=0.01, unroll_steps=10):
        super().__init__()
        self.state_dim = channels
        self.sigma = nn.Parameter(torch.tensor(sigma), requires_grad=False)
        self.h_min = -1
        self.h_max = 1
        self.unroll_steps = unroll_steps
        self.score = Siren(channels=channels, dim=256, bandwidth=20)
        self.step_ids = nn.Parameter(torch.arange(self.unroll_steps + 1), requires_grad=False)
        self.times = nn.Parameter(((self.step_ids / self.unroll_steps).repeat(1, 1).T).clip(1e-6, 1), requires_grad=False)  # T,1
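
    # Given the running mean `state` at time t, reconstruct an estimate of the clean
    # sample: x_hat = state / gamma - sqrt((1 - gamma) / gamma) * noise_estimate,
    # clipped to [h_min, h_max]; for t close to 0 the estimate is pinned to zero.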
    def step(self, state, t):
        update = self.score(state, t)
        gamma = 1 - self.sigma ** (2 * t)
        f = 1 / gamma
        i = -((1 - gamma) / gamma + 1e-6).sqrt()
        h = f * state + i * update
        return torch.where(t < 1e-6, torch.zeros_like(h), h.clip(self.h_min, self.h_max))
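
    # Training objective: for each example, draw a random time t, form a noisy mean
    # mu ~ N(gamma * x, gamma * (1 - gamma)), reconstruct x_hat with `step`, and weight
    # the error ||x_hat - x|| by -log(sigma) * sigma**(-2*t).  This mirrors the
    # continuous-time loss of Bayesian Flow Networks, except that the L2 norm is used
    # in place of the squared error.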
    def forward(self, x_NC):
        x_NTC = x_NC.unsqueeze(1)  # N,T,C
        # t = self.times.T.unsqueeze(-1).repeat(x_NTC.shape[0], 1, 1)  # pretend timesteps are fixed
        t = torch.rand(x_NC.shape[0], self.unroll_steps, 1, device=x_NC.device).clip(1e-6, 1)
        gamma_NT1 = 1 - self.sigma ** (2 * t)
        std_NT1 = (gamma_NT1 * (1 - gamma_NT1) + 1e-6).sqrt()
        mu_NTC = gamma_NT1 * x_NTC + torch.randn_like(x_NTC) * std_NT1
        x1_NTC = self.step(mu_NTC, t)
        scale_NT1 = math.log(self.sigma) / self.sigma ** (2 * t)
        diff_NTC = x1_NTC - x_NTC
        norm_NT = (diff_NTC.square().sum(-1) + 1e-6).sqrt()
        mse_NT = -norm_NT * scale_NT1.squeeze(-1)
        return mse_NT
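
    # Posterior mean of a Gaussian with known precisions: combining a prior
    # N(state, 1/prior_precision) with an observation `obs` of precision
    # likelihood_precision gives mean (rho * state + alpha * obs) / (rho + alpha).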
    def normal_update(self, prior_precision, state, likelihood_precision, obs):
        return (prior_precision * state + likelihood_precision * obs) / (prior_precision + likelihood_precision)
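
    # Sampling: start from a zero mean, then for each of T steps simulate a noisy
    # observation y of the current reconstruction with accuracy
    # alpha_i = sigma**(-2*i/T) * (1 - sigma**(2/T)), fold it into the running mean
    # with `normal_update`, and re-estimate x_hat.  If `ax` is given, the trajectory
    # of every sample's running mean is plotted.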
    def generate(self, batch_size, ax=None):
        device = next(self.parameters()).device
        state = torch.zeros(batch_size, self.unroll_steps + 1, self.state_dim, device=device)
        t = self.times
        ids = self.step_ids
        T = self.unroll_steps
        likelihood_precision = self.sigma ** (-2 * (ids + 1) / T) * (1 - self.sigma.pow(2 / T))
        prior_precision = torch.cat([torch.ones(1, device=device), likelihood_precision.cumsum(0)], dim=0)
        out = self.step(state[:, 0], t[None, 0].repeat(batch_size, 1))
        for step in range(1, T + 1):
            y = out + torch.randn_like(out) / likelihood_precision[step - 1].sqrt()
            state[:, step] = self.normal_update(prior_precision[step - 1], state[:, step - 1], likelihood_precision[step - 1], y)
            out = self.step(state[:, step], t[None, step].repeat(batch_size, 1))
        if ax is not None:
            for i in range(batch_size):
                ax.plot(state[i, :, 0].detach().numpy(), alpha=0.3)
        return out
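

# Full-batch training with Adam for 200 steps; gradients are clipped to norm 1.0 and
# anomaly detection is enabled to surface NaNs early.  Requires a CUDA device.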
torch.manual_seed(6)
torch.set_anomaly_enabled(True)
flow = Sampler(channels=dataset.size(1), sigma=0.01).to('cuda')
dataset = dataset.to('cuda')
opt = torch.optim.Adam(flow.parameters(), lr=1e-3)
train_steps = 200
trace_loss = torch.zeros(train_steps)
trace_gnorm = torch.zeros(train_steps)
for i in range(train_steps):
    opt.zero_grad()
    minibatch = torch.arange(len(dataset))  # full batch training
    x = dataset[minibatch]
    losses = flow(x)
    loss = losses.mean()
    assert not torch.isnan(loss), f'loss is nan at step {i}'
    loss.backward()
    trace_loss[i] = loss.item()
    trace_gnorm[i] = torch.nn.utils.clip_grad_norm_(flow.parameters(), 1.0)
    opt.step()
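
# Diagnostics: loss and gradient-norm curves, a histogram of 1000 generated samples,
# and the per-sample trajectories of the running mean during generation.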
fig, (axl, axc, axr, axf) = plt.subplots(1, 4, figsize=(20, 3))
axl.plot(trace_loss)
axl.set_title('loss')
axc.plot(trace_gnorm)
axc.set_title('gradient norms')
axl.set_xlim(0, train_steps)
axc.set_xlim(0, train_steps)
flow = flow.to('cpu')
with torch.no_grad():
    gen = flow.generate(batch_size=1000, ax=axf)
axr.set_title('sampled data')
axf.set_title('sample flows')
axf.set_ylim(-1, 1)
for i in range(dataset.size(1)):
    axr.hist(gen[:, i].numpy(), bins=100, alpha=0.5);
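
# Optional sanity check, not part of the original gist: compare the first and second
# moments of the generated samples with those of the training data.
with torch.no_grad():
    data_cpu = dataset.to('cpu')
    print('data  mean/std:', data_cpu.mean().item(), data_cpu.std().item())
    print('model mean/std:', gen.mean().item(), gen.std().item())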