@pengzhangzhi · Created March 11, 2025
Standalone discrete diffusion model example on 2D data.
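
The script below trains a small MLP to denoise masked 2D tokens: two-moons points are quantized onto a 128-bin integer grid, each token is replaced by a [MASK] id with probability 1 - t, and the model learns to predict the clean tokens under a 1/(1 - t) weighted cross-entropy. Sampling starts from the all-mask state and Euler-integrates the discrete flow matching velocity from t = 0 to t = 1.
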
import torch
import matplotlib.pyplot as plt
from torch import nn, Tensor
from sklearn.datasets import make_moons
# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
class DiscreteFlow(nn.Module):
    def __init__(self, dim: int = 2, h: int = 128, v: int = 128):
        super().__init__()
        self.v = v
        self.embed = nn.Embedding(v, h)
        self.net = nn.Sequential(
            nn.Linear(dim * h + 1, h), nn.ELU(),
            nn.Linear(h, h), nn.ELU(),
            nn.Linear(h, h), nn.ELU(),
            nn.Linear(h, dim * v))

    def forward(self, x_t: Tensor, t: Tensor) -> Tensor:
        # The time input is overridden with a constant: with an all-mask
        # source distribution, t is recoverable from the number of masked
        # tokens, so explicit time conditioning is redundant here.
        t = torch.ones_like(t)
        return self.net(torch.cat((t[:, None], self.embed(x_t).flatten(1, 2)), -1)).reshape(list(x_t.shape) + [self.v])
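
# The network maps token ids (batch, dim) to per-dimension logits of shape
# (batch, dim, v); the vocabulary is expanded by one id below so that
# `mask_id` can serve as the [MASK] token.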
batch_size = 512
vocab_size = 128
mask_id = vocab_size
expand_vocab_size = vocab_size + 1
model = DiscreteFlow(v=expand_vocab_size).to(device)
optim = torch.optim.Adam(model.parameters(), lr=0.001)
for i in range(10000):
    # Quantize the two-moons point cloud onto the [0, vocab_size) integer grid.
    x_1 = Tensor(make_moons(batch_size, noise=0.00)[0]).to(device)
    x_1 = torch.round(torch.clip(x_1 * 35 + 50, min=0.0, max=vocab_size - 1)).long()

    # Source distribution x_0: the `< 1` comparison is always true, so every
    # token starts as [MASK] and the uniform branch is never taken.
    mask = torch.zeros_like(x_1) + mask_id
    uniform = torch.randint(low=0, high=vocab_size, size=(batch_size, 2), device=device)
    x_0 = torch.where(torch.rand(batch_size, 2, device=device) < 1, mask, uniform)

    # Interpolant: each token shows x_1 with probability t, else x_0 (mask).
    # t is clamped away from 1 so the 1/(1 - t) loss weight stays bounded.
    t = torch.rand(batch_size, device=device).clamp(min=0.01, max=0.99)
    x_t = torch.where(torch.rand(batch_size, 2, device=device) < t[:, None], x_1, x_0)

    logits = model(x_t, t)
    # Per-token cross-entropy against the clean data, weighted by 1/(1 - t);
    # reduction='none' is required so the weight applies per token rather
    # than to an already-averaged scalar.
    loss = nn.functional.cross_entropy(logits.flatten(0, 1), x_1.flatten(0, 1), reduction='none')
    loss_weight = 1 / (1 - t[:, None])
    loss = (loss.reshape(batch_size, -1) * loss_weight).mean()

    if i % 200 == 0:
        print(f"Iteration {i}, Loss: {loss.item():.2f}")
    optim.zero_grad()
    loss.backward()
    optim.step()
# Sampling
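# Start from the all-mask state and Euler-integrate the probability velocity
#   u_t = (p1 - onehot(x_t)) / (1 - t),
# where p1 is the model's predicted distribution over clean tokens. A step of
# size h <= 1 - t keeps onehot(x_t) + h * u_t a valid categorical distribution.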
mask = torch.zeros_like(x_1) + mask_id
uniform = torch.randint(low=0, high=vocab_size, size=(batch_size, 2), device=device)
x_t = torch.where(torch.rand(batch_size, 2, device=device) < 1, mask, uniform)
t = 0.0
results = [(x_t.cpu(), t)]  # move to CPU for plotting
while t < 1.0 - 1e-3:
    p1 = torch.softmax(model(x_t, torch.ones(x_t.shape[0], device=device) * t), dim=-1)
    h = min(0.1, 1.0 - t)  # Euler step size, shrunk near t = 1
    one_hot_x_t = nn.functional.one_hot(x_t, expand_vocab_size).float()
    u = (p1 - one_hot_x_t) / (1.0 - t)
    x_t = torch.distributions.Categorical(probs=one_hot_x_t + h * u).sample()
    t += h
    results.append((x_t.cpu(), t))

# Plot each intermediate state along the sampling trajectory.
fig, axes = plt.subplots(1, len(results), figsize=(15, 2), sharex=True, sharey=True)
for (x_t, t), ax in zip(results, axes):
    ax.scatter(x_t.detach()[:, 0], x_t.detach()[:, 1], s=10)
    ax.set_title(f't={t:.2f}')
plt.tight_layout()
plt.savefig('discrete_flow_matching.png')
plt.show()
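
# Optional: invert the training quantization x_grid = round(x * 35 + 50) to
# map the final integer samples back to the continuous two-moons frame.
# (A minimal sketch; at t ≈ 1 no mask ids should remain, but any stragglers
# would land outside the valid [0, 127] range.)
coords = (results[-1][0].float() - 50.0) / 35.0
print(coords[:5])  # a few recovered 2D points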