Standalone discrete diffusion (discrete flow matching) model example on 2D data.
import torch
import matplotlib.pyplot as plt
from torch import nn, Tensor
from sklearn.datasets import make_moons

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class DiscreteFlow(nn.Module):
    def __init__(self, dim: int = 2, h: int = 128, v: int = 128):
        super().__init__()
        self.v = v
        self.embed = nn.Embedding(v, h)
        self.net = nn.Sequential(
            nn.Linear(dim * h + 1, h), nn.ELU(),
            nn.Linear(h, h), nn.ELU(),
            nn.Linear(h, h), nn.ELU(),
            nn.Linear(h, dim * v))

    def forward(self, x_t: Tensor, t: Tensor) -> Tensor:
        # The time input is overridden with a constant, so the network is effectively
        # time-unconditional: with an all-mask source distribution, the fraction of
        # masked tokens in x_t already encodes t.
        t = torch.ones_like(t)
        return self.net(torch.cat((t[:, None], self.embed(x_t).flatten(1, 2)), -1)).reshape(list(x_t.shape) + [self.v])

batch_size = 512
vocab_size = 128
mask_id = vocab_size                 # extra token id reserved for the mask
expand_vocab_size = vocab_size + 1   # vocabulary including the mask token

model = DiscreteFlow(v=expand_vocab_size).to(device)
optim = torch.optim.Adam(model.parameters(), lr=0.001)

# Training: mask-based discrete flow matching on the two-moons dataset,
# quantized to integer tokens in [0, vocab_size).
for i in range(10000):
    x_1 = Tensor(make_moons(batch_size, noise=0.00)[0]).to(device)
    x_1 = torch.round(torch.clip(x_1 * 35 + 50, min=0.0, max=vocab_size - 1)).long()

    # Source distribution x_0: the `< 1` threshold makes it all-mask;
    # lower it to mix in uniform noise.
    mask = torch.zeros_like(x_1) + mask_id
    uniform = torch.randint(low=0, high=vocab_size, size=(batch_size, 2), device=device)
    x_0 = torch.where(torch.rand(batch_size, 2, device=device) < 1, mask, uniform)

    # Sample t away from both endpoints so the 1/(1-t) loss weight stays finite.
    t = torch.rand(batch_size, device=device).clamp(min=0.01, max=0.99)
    # Interpolate: each token comes from the data x_1 with probability t, else from x_0.
    x_t = torch.where(torch.rand(batch_size, 2, device=device) < t[:, None], x_1, x_0)

    logits = model(x_t, t)
    # Per-token cross-entropy against the clean data, weighted per sample by 1/(1-t).
    loss_weight = 1 / (1 - t[:, None])
    loss = nn.functional.cross_entropy(logits.flatten(0, 1), x_1.flatten(0, 1), reduction='none')
    loss = (loss * loss_weight.expand(-1, 2).flatten()).mean()

    if i % 200 == 0:
        print(f"Iteration {i}, Loss: {loss.item():.2f}")
    optim.zero_grad()
    loss.backward()
    optim.step()

# Sampling: start from the all-mask source and integrate the probability
# velocity u with Euler steps of size h.
with torch.no_grad():
    x_t = torch.full((batch_size, 2), mask_id, dtype=torch.long, device=device)
    t = 0.0
    results = [(x_t.cpu(), t)]  # move to CPU for plotting
    while t < 1.0 - 1e-3:
        p1 = torch.softmax(model(x_t, torch.ones(x_t.shape[0], device=device) * t), dim=-1)
        h = min(0.1, 1.0 - t)  # Euler step size, clipped so the last step lands on t = 1
        one_hot_x_t = nn.functional.one_hot(x_t, expand_vocab_size).float()
        u = (p1 - one_hot_x_t) / (1.0 - t)
        x_t = torch.distributions.Categorical(probs=one_hot_x_t + h * u).sample()
        t += h
        results.append((x_t.cpu(), t))  # move to CPU for plotting

# Plot the sampling trajectory: one panel per Euler step.
fig, axes = plt.subplots(1, len(results), figsize=(15, 2), sharex=True, sharey=True)
for (x_t, t), ax in zip(results, axes):
    ax.scatter(x_t[:, 0], x_t[:, 1], s=10)
    ax.set_title(f't={t:.2f}')
plt.tight_layout()
plt.savefig('discrete_flow_matching.png')
plt.show()
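
Usage note (a sketch, assuming the gist is saved as `discrete_flow_matching.py`): with torch, scikit-learn, and matplotlib installed, running the script trains for 10,000 iterations on CPU or GPU and then writes `discrete_flow_matching.png`, a row of scatter plots showing samples evolving from the all-mask source at t=0 to the two-moons distribution at t=1.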