ramtorch single-node, multi-GPU example
import os
import wandb
from tqdm import tqdm
import argparse
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.checkpoint import checkpoint
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.utils.data.distributed import DistributedSampler
from torch.profiler import profile, record_function, ProfilerActivity
import numpy as np
from torchvision.utils import save_image, make_grid
from torchvision import datasets, transforms
from ramtorch import AdamW, Linear
from ramtorch.helpers import replace_linear_with_ramtorch
from ramtorch.zero1 import create_zero_param_groups, broadcast_zero_params
from ramtorch.zero2 import setup_grad_sharding_hooks
from einops import rearrange
import math
def rotate_half(x):
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, offset=0):
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
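# The two helpers above implement rotary position embeddings (RoPE): treating
# each head dimension as pairs (x1, x2), combining x * cos with
# rotate_half(x) * sin rotates every pair by a position-dependent angle,
# i.e. a complex multiplication by exp(i * theta). The `offset` argument is
# accepted for API compatibility but unused in this script.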
def soft_clamp(x, scale, alpha, shift):
    return scale * F.tanh(x * alpha) + shift
class RotaryPositionalEmbedding(nn.Module):
    def __init__(self, dim, max_seq_len=2048):
        super().__init__()
        # The frequency base is max_seq_len rather than the conventional fixed
        # base of 10000, so rotation frequencies scale with the cache length.
        inv_freq = 1.0 / (max_seq_len ** (torch.arange(0, dim, 2).float() / dim))
        t = torch.arange(max_seq_len).type_as(inv_freq)
        freqs = torch.einsum("i,j->ij", t, inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer(
            "cos_cached", emb.cos()[None, None, :, :], persistent=False
        )
        self.register_buffer(
            "sin_cached", emb.sin()[None, None, :, :], persistent=False
        )

    def forward(self, x, seq_dim=1):
        return (
            self.cos_cached[:, :, : x.shape[seq_dim], :],
            self.sin_cached[:, :, : x.shape[seq_dim], :],
        )
class SoftClamp(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(1) * 0.5)
        self.scale = nn.Parameter(torch.ones(dim))
        self.shift = nn.Parameter(torch.zeros(dim))
        self.use_compiled = False

    def forward(self, x):
        if self.use_compiled:
            return torch.compile(soft_clamp)(x, self.scale, self.alpha, self.shift)
        else:
            return soft_clamp(x, self.scale, self.alpha, self.shift)
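# SoftClamp is used in most places where a normalization layer would sit:
# it applies a learnable, per-channel bounded activation
# y = scale * tanh(alpha * x) + shift, keeping activations in a bounded range
# without computing per-token statistics.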
class AttentionBlock(nn.Module):
    def __init__(self, dim, num_heads=8, max_seq_len=2048, use_rope=True):
        super(AttentionBlock, self).__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim**-0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=False)
        self.wo = nn.Linear(dim, dim, bias=True)
        self.layer_norm = SoftClamp(dim)
        self.rope = RotaryPositionalEmbedding(self.head_dim, max_seq_len)
        self.q_norm = SoftClamp(dim)
        self.k_norm = SoftClamp(dim)
        self.add_module("layer_norm", self.layer_norm)
        # Zero-init the output projection weight so the block's contribution
        # starts near zero and the residual path dominates early in training.
        nn.init.zeros_(self.wo.weight)
        self.use_compiled = False
        self.use_rope = use_rope

    @property
    def device(self):
        return next(self.parameters()).device

    def forward(self, x, attention_mask=None):
        residual = x
        x = self.layer_norm(x)
        qkv = self.qkv(x)
        q, k, v = qkv.chunk(3, dim=-1)
        q = self.q_norm(q)
        k = self.k_norm(k)
        q = q.view(q.shape[0], q.shape[1], self.num_heads, self.head_dim).transpose(
            1, 2
        )
        k = k.view(k.shape[0], k.shape[1], self.num_heads, self.head_dim).transpose(
            1, 2
        )
        v = v.view(v.shape[0], v.shape[1], self.num_heads, self.head_dim).transpose(
            1, 2
        )
        cos, sin = self.rope(x, seq_dim=1)
        if self.use_rope:
            if self.use_compiled:
                q, k = torch.compile(apply_rotary_pos_emb)(q, k, cos, sin)
            else:
                q, k = apply_rotary_pos_emb(q, k, cos, sin)
        out = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, attn_mask=attention_mask
        )
        out = out.transpose(1, 2).contiguous().view(x.shape[0], x.shape[1], -1)
        out = self.wo(out)
        return out + residual
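# Shape flow through AttentionBlock.forward, for input x of shape [B, S, D]
# with H = num_heads and Dh = D // H:
#   qkv(x)           -> [B, S, 3 * D]
#   q, k, v          -> [B, S, D] each, reshaped to [B, H, S, Dh]
#   rope cos/sin     -> [1, 1, S, Dh], broadcast over batch and heads
#   attention output -> [B, H, S, Dh], merged back to [B, S, D]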
class GLU(nn.Module):
    def __init__(self, dim, exp_fac=4):
        super(GLU, self).__init__()
        self.wi_0 = nn.Linear(dim, dim * exp_fac, bias=False)
        self.wi_1 = nn.Linear(dim, dim * exp_fac, bias=False)
        self.wo = nn.Linear(dim * exp_fac, dim, bias=True)
        self.layer_norm = nn.LayerNorm(dim, elementwise_affine=False)
        nn.init.zeros_(self.wo.weight)
        self.use_compiled = False

    @property
    def device(self):
        return next(self.parameters()).device

    def _fwd_glu(self, x, residual):
        return self.wo(F.silu(self.wi_0(x)) * self.wi_1(x)) + residual

    def forward(self, x):
        residual = x
        x = self.layer_norm(x)
        if self.use_compiled:
            return torch.compile(self._fwd_glu)(x, residual)
        else:
            return self._fwd_glu(x, residual)
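# GLU here is a SwiGLU-style feed-forward block: the hidden state is expanded
# by exp_fac, gated as silu(wi_0(x)) * wi_1(x), and projected back by wo, with
# a pre-norm (LayerNorm without affine parameters) and a residual add.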
class TransformerNetwork(nn.Module):
    def __init__(
        self,
        input_dim,
        output_dim,
        dim,
        num_layers,
        num_heads=8,
        exp_fac=4,
        rope_seq_length=2048,
        use_rope=True,
        final_head=True,
        input_proj=True,
    ):
        super(TransformerNetwork, self).__init__()
        if input_proj:
            self.input_layer = nn.Linear(input_dim, dim)
        else:
            self.input_layer = nn.Identity()
            input_dim = dim
        self.blocks = nn.ModuleList(
            [
                nn.ModuleDict(
                    {
                        "attn": AttentionBlock(
                            dim, num_heads, rope_seq_length, use_rope
                        ),
                        "glu": GLU(dim, exp_fac),
                    }
                )
                for _ in range(num_layers)
            ]
        )
        self.out_norm = SoftClamp(dim)
        if final_head:
            self.output_layer = nn.Linear(dim, output_dim)
        else:
            self.output_layer = nn.Identity()

    def set_use_compiled(self):
        for name, module in self.named_modules():
            if hasattr(module, "use_compiled"):
                print(f"Setting 'use_compiled' to True in module: {name}")
                setattr(module, "use_compiled", True)

    def run_attn(self, block, x, mask):
        return block["attn"](x, mask)

    def forward(self, x, attention_mask=None, act_ckpt=True):
        if act_ckpt:
            x = checkpoint(self.input_layer, x)
            for block in self.blocks:
                if isinstance(block, nn.Identity):
                    continue
                x = checkpoint(self.run_attn, block, x, attention_mask)
                x = checkpoint(block["glu"], x)
            x = checkpoint(self.out_norm, x)
            x = checkpoint(self.output_layer, x)
        else:
            x = self.input_layer(x)
            for block in self.blocks:
                if isinstance(block, nn.Identity):
                    continue
                x = block["attn"](x, attention_mask)
                x = block["glu"](x)
            x = self.out_norm(x)
            x = self.output_layer(x)
        return x
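# With act_ckpt=True (the default), every sub-block is wrapped in
# torch.utils.checkpoint.checkpoint, so activations are recomputed during the
# backward pass instead of being stored. This trades extra compute for a much
# smaller activation memory footprint, complementing the CPU-offloaded
# parameters that ramtorch provides later in this script.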
def image_flatten(latents, shuffle_size=2):
    """Convert image from NCHW to flattened sequence with pixel shuffling."""
    return (
        rearrange(
            latents,
            "n c (h dh) (w dw) -> n (h w) (c dh dw)",
            dh=shuffle_size,
            dw=shuffle_size,
        ),
        latents.shape,
    )
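# Example: with the 32x32 CelebA crops used below and shuffle_size=2, a batch
# of shape [N, 3, 32, 32] becomes a sequence of shape [N, 16 * 16, 3 * 2 * 2]
# = [N, 256, 12], matching input_dim = output_dim = 3 * 4 = 12 in the model
# config at the bottom of this file.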
def image_unflatten(latents, shape, shuffle_size=2):
    """Reverse of image_flatten: reconstruct NCHW from flattened sequence."""
    n, c, h, w = shape
    return rearrange(
        latents,
        "n (h w) (c dh dw) -> n c (h dh) (w dw)",
        dh=shuffle_size,
        dw=shuffle_size,
        c=c,
        h=h // shuffle_size,
        w=w // shuffle_size,
    )
def sample_from_distribution(x, probabilities, n):
    indices = torch.multinomial(probabilities, n, replacement=True)
    return x[indices]


def create_distribution(num_points, device=None):
    """Create custom probability distribution for timestep sampling."""
    x = torch.linspace(0, 1, num_points, device=device)
    probabilities = -7.7 * ((x - 0.5) ** 2) + 2
    probabilities /= probabilities.sum()
    return x, probabilities
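# The unnormalized density is an inverted parabola peaking at t = 0.5
# (value 2) and falling to -7.7 * 0.25 + 2 = 0.075 at t = 0 and t = 1, so
# mid-range timesteps are sampled about 27x more often than the endpoints
# (2 / 0.075 ~ 27) while every timestep keeps a nonzero probability.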
def repeat_along_dim(tensor, repeats, dim):
    """Repeat tensor elements along a specified dimension."""
    permute_order = list(range(tensor.dim()))
    permute_order[dim], permute_order[0] = permute_order[0], permute_order[dim]
    tensor = tensor.permute(permute_order)
    tensor = tensor.unsqueeze(1)
    repeated_tensor = tensor.repeat(1, repeats, *([1] * (tensor.dim() - 2)))
    repeated_tensor = repeated_tensor.view(-1, *repeated_tensor.shape[2:])
    permute_order[dim], permute_order[0] = permute_order[0], permute_order[dim]
    repeated_tensor = repeated_tensor.permute(permute_order)
    return repeated_tensor
class Flow(nn.Module):
    """
    A flow-matching model based on a Transformer architecture.

    This model learns the vector field that transports a noise distribution (z)
    to a data distribution (x1) over a normalized time range [0, 1].
    It uses classifier-free guidance for conditional generation.
    """

    def __init__(
        self,
        input_dim,
        output_dim,
        dim,
        num_layers,
        num_heads=8,
        exp_fac=4,
        rope_seq_length=784,
        class_count=10,
        cond_seq_len=40,
    ):
        super().__init__()
        self.dim = dim
        self.class_count = class_count
        self.input_layer = nn.Linear(input_dim, dim)
        self.timestep_vector = nn.Linear(1, dim)
        self.class_embed = nn.Linear(cond_seq_len, dim)
        self.class_norm = SoftClamp(dim=dim)
        self.cond_seq_len = cond_seq_len
        self.transformer = TransformerNetwork(
            input_dim=dim,
            output_dim=output_dim,
            dim=dim,
            num_layers=num_layers,
            num_heads=num_heads,
            exp_fac=exp_fac,
            rope_seq_length=rope_seq_length,
            final_head=True,
            input_proj=False,
        )
        self._init_weights()

    def _init_weights(self):
        """Initialize weights for stability."""
        nn.init.zeros_(self.timestep_vector.weight)
        nn.init.zeros_(self.timestep_vector.bias)
        nn.init.zeros_(self.class_embed.weight)
        nn.init.zeros_(self.class_embed.bias)

    @property
    def device(self):
        return torch.cuda.current_device()

    def forward(self, x, t, condition, attention_mask=None):
        """
        Forward pass for the flow model.

        Args:
            x (torch.Tensor): Input tensor of shape [B, SeqLen, Dim].
            t (torch.Tensor): Time step tensor of shape [B, 1].
            condition (torch.Tensor): Class condition tensor of shape
                [B, cond_seq_len].
            attention_mask (torch.Tensor, optional): Mask for attention.

        Returns:
            torch.Tensor: The predicted velocity vector.
        """
        x_proj = self.input_layer(x)
        time_vec = self.timestep_vector(t.view(-1, 1)).unsqueeze(1)
        class_vec = self.class_embed(condition.to(self.class_embed.weight.dtype))[
            :, None, :
        ]
        class_vec = self.class_norm(class_vec)
        # Concatenate conditioning tokens with sequence tokens
        tokens = torch.cat((time_vec, class_vec, x_proj), dim=1)
        output_tokens = self.transformer(tokens, attention_mask)
        # Extract only the sequence tokens (skip time and class tokens)
        velocity_pred = output_tokens[:, 2:, ...]
        return velocity_pred

    def compute_loss(self, x1, x0, condition, class_dropout_ratio=0.1):
        """Calculate the rectified flow loss for a batch."""
        B = x1.shape[0]
        # Apply classifier-free guidance dropout
        cond_clone = condition.clone()
        is_dropped = torch.rand(B, device=self.device) < class_dropout_ratio
        cond_clone[is_dropped] = 0
        # Sample timesteps from custom distribution
        num_points = 1000
        x, probabilities = create_distribution(num_points, device=self.device)
        t = sample_from_distribution(x, probabilities, B)[:, None, None]
        t = t.to(x0.dtype)
        # Interpolate between noise and data
        xt = x0 * (1 - t) + x1 * t
        target_velocity = x1 - x0
        predicted_velocity = self.forward(xt, t, cond_clone)
        loss = F.mse_loss(predicted_velocity, target_velocity)
        return loss
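    # Rectified-flow objective in compute_loss: sample t from the custom
    # distribution, build xt = (1 - t) * x0 + t * x1 between noise x0 and data
    # x1, and regress the model output toward the constant ground-truth
    # velocity x1 - x0 with an MSE loss. Conditioning rows are zeroed with
    # probability class_dropout_ratio so the same network also learns the
    # unconditional velocity needed for classifier-free guidance at sampling.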
    def euler_cfg(
        self,
        x,
        pos_cond,
        cfg_scale=4.0,
        num_steps=100,
        skip_last_n=0,
        return_intermediates=False,
    ):
        """Euler method sampler with classifier-free guidance."""
        if return_intermediates:
            trajectories = [x.cpu()]
        else:
            trajectories = None
        neg_cond = torch.zeros_like(pos_cond)
        dt = 1.0 / num_steps
        effective_steps = num_steps - skip_last_n
        for i in tqdm(range(effective_steps), desc="Euler CFG Sampling"):
            with torch.no_grad():
                t_val = i * dt
                t = torch.ones(x.shape[0], 1).to(self.device, x.dtype) * t_val
                v_pos = self.forward(x, t, pos_cond)
                v_neg = self.forward(x, t, neg_cond)
                velocity = v_neg + cfg_scale * (v_pos - v_neg)
                x = x + velocity * dt
                if return_intermediates:
                    trajectories.append(x.cpu())
        return x, trajectories
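# Each Euler step integrates dx/dt = v(x, t) with step size dt = 1 / num_steps,
# using the classifier-free-guidance combination
# v = v_uncond + cfg_scale * (v_cond - v_uncond). cfg_scale = 1 recovers the
# purely conditional velocity, and larger values push samples toward the
# condition. The zero vector serves as the "negative" (unconditional) label,
# matching the dropout value used during training.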
def setup(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)


def cleanup():
    dist.destroy_process_group()
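# setup() joins each worker to a single NCCL process group and pins it to its
# GPU via torch.cuda.set_device(rank). MASTER_ADDR/MASTER_PORT are hardcoded
# for a single-node run, so change the port if 12355 is already in use.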
def main(rank, world_size, training_config, model):
    latch = True
    setup(rank, world_size)
    torch.manual_seed(0)
    transform = transforms.Compose(
        [
            transforms.Resize(32),
            transforms.CenterCrop(32),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )
    dataset = datasets.CelebA(
        root="celeba/", split="train", transform=transform, download=True
    )
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    loader = DataLoader(
        dataset,
        batch_size=training_config["batch_size"],
        shuffle=False,
        sampler=sampler,
        prefetch_factor=3,
        num_workers=1,
    )
    DEVICE = rank

    # Replace standard Linear layers with ramtorch offloaded layers.
    # ramtorch stores weights on CPU (shared across workers) and streams them
    # to the GPU on demand.
    model = replace_linear_with_ramtorch(model, rank)
    model.to(rank)

    # Shard optimizer states across workers (ZeRO-1 style).
    # Each worker only maintains optimizer states for a subset of parameters.
    all_params = list(model.parameters())
    param_groups = [{"params": all_params, "lr": training_config["lr"]}]
    rank_param_groups = create_zero_param_groups(param_groups, world_size)

    # Set up gradient sharding hooks (ZeRO-2 style).
    # Gradients are partitioned and only stored (as a CPU buffer) on the worker
    # responsible for updating them.
    setup_grad_sharding_hooks(rank_param_groups, rank)

    optim = AdamW(rank_param_groups[rank])
    lr_scheduler = torch.optim.lr_scheduler.LinearLR(
        optim, start_factor=1e-5, end_factor=1.0, total_iters=100
    )

    if rank == 0 and training_config["wandb_project"]:
        wandb.login(key=training_config["wandb_key"], host="https://api.wandb.ai")
        wandb.init(
            project=training_config["wandb_project"],
            name=training_config["preview_path"],
        )

    with profile(
        activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
        record_shapes=True,
        profile_memory=True,
        with_stack=True,
        with_flops=True,
    ) as prof:
        for epoch in range(training_config["num_epochs"]):
            torch.manual_seed(epoch)
            sampler.set_epoch(epoch)
            progress_bar = tqdm(
                total=len(loader), desc="Processing", smoothing=0.1, disable=rank != 0
            )
            for batch_idx, (real, label) in enumerate(loader):
                real = real.to(DEVICE)
                label = label.to(DEVICE)
                x1, image_shape = image_flatten(real)
                x1 = x1.requires_grad_(True)
                x0 = torch.randn_like(x1)
                with torch.autocast("cuda", torch.bfloat16):
                    loss = (
                        model.compute_loss(
                            x1=x1,
                            x0=x0,
                            condition=label,
                            class_dropout_ratio=training_config["class_dropout_ratio"],
                        )
                        / world_size
                    )
                # Synchronize before backward to ensure all workers are ready
                torch.cuda.synchronize()
                loss.backward()
                torch.cuda.synchronize()
                optim.step()
                # Broadcast non-ramtorch parameters across workers.
                # ramtorch parameters are already shared via CPU memory, but
                # standard PyTorch parameters need explicit broadcasting to
                # stay synchronized.
                broadcast_zero_params(rank_param_groups)
                lr_scheduler.step()
                # Use model.zero_grad() instead of optimizer.zero_grad().
                # Each worker only handles partial gradients, so we need to zero
                # gradients at the model level to properly clear all workers'
                # buffers.
                model.zero_grad()
                torch.cuda.synchronize()
                if rank == 0:
                    progress_bar.set_description(
                        f"Epoch [{epoch}/{training_config['num_epochs']}] Step [{batch_idx}/{len(loader)}] Loss: {loss * world_size:.4f}"
                    )
                    if training_config["wandb_project"]:
                        wandb.log({"Loss": loss.detach() * world_size, "Epoch": epoch})
                if rank == 0 and batch_idx % training_config["eval_interval"] == 0:
                    with torch.no_grad():
                        z = torch.randn_like(x1)
                        with torch.autocast("cuda", torch.bfloat16):
                            fake_images_list = []
                            for cfg, steps in training_config[
                                "inference_cfg_and_steps"
                            ]:
                                fake_images_cfg, _ = model.euler_cfg(
                                    z, label, cfg, num_steps=steps
                                )
                                fake_images_list.append(fake_images_cfg)
                        real_unflattened = image_unflatten(x1, image_shape)
                        fake_images_unflattened = [
                            image_unflatten(img, image_shape)
                            for img in fake_images_list
                        ]
                        all_images = torch.cat(
                            fake_images_unflattened + [real_unflattened], dim=0
                        )
                        os.makedirs(training_config["preview_path"], exist_ok=True)
                        img_path = f"{training_config['preview_path']}/epoch_{epoch}_{batch_idx}.jpg"
                        save_image(
                            make_grid(
                                (all_images.clip(-1, 1) + 1) / 2,
                                nrow=training_config["batch_size"],
                            ),
                            img_path,
                        )
                        if training_config["wandb_project"]:
                            wandb.log({"example_image": wandb.Image(img_path)})
                if rank == 0:
                    progress_bar.update(1)
                profile_steps = 15
                if batch_idx == profile_steps and latch:
                    prof.stop()
                    prof.export_chrome_trace(f"trace_{rank}.json")
                    print(f"Stopped profiling and saved trace at step {batch_idx}")
                    latch = False
            if rank == 0:
                os.makedirs(training_config["ckpt_path"], exist_ok=True)
                torch.save(
                    model.state_dict(), f"{training_config['ckpt_path']}/{epoch}.pth"
                )
    cleanup()
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    world_size = torch.cuda.device_count()
    training_config = {
        "batch_size": 64 * 4,
        "lr": 1e-4 * (4 ** (1 / 2)),
        "num_epochs": 1000,
        "eval_interval": 100,
        "preview_path": "celeba_flowmatching_DiT-B/2_ramtorch_full_integration_test",
        "wandb_project": "flowmatch_celeba_2_ramtorch_full_integration_test",
        "wandb_key": None,
        "ckpt_path": "celeba_flowmatching_DiT-B/2_ramtorch_full_integration_test",
        "class_dropout_ratio": 0.1,
        "model_config": {
            "input_dim": 3 * 4,
            "output_dim": 3 * 4,
            "dim": 768,
            "num_layers": 12,
            "num_heads": 12,
            "exp_fac": 4,
            "rope_seq_length": 64**2 + 30,
            "class_count": 40,
        },
        "inference_cfg_and_steps": [
            [1, 30],
            [3, 30],
        ],
        "model_checkpoint": None,
    }
    # Model must be instantiated before spawning GPU workers.
    # ramtorch shares CPU tensors across workers, so the model needs to exist
    # in the parent process before forking to enable proper memory sharing.
    model = Flow(**training_config["model_config"])
    if training_config["model_checkpoint"]:
        model.load_state_dict(torch.load(training_config["model_checkpoint"]))
    mp.spawn(main, args=(world_size, training_config, model), nprocs=world_size)
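# To run: execute this script directly on a single machine with one or more
# CUDA GPUs, e.g. `python train_ramtorch_celeba.py` (the filename is a
# placeholder; use whatever name you saved this gist as). mp.spawn launches
# one worker process per visible GPU and passes its rank as the first argument
# to main(). Note that the argparse parser is created but defines no flags;
# all settings live in training_config above.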