ramtorch single-node, multi-GPU example: a flow-matching transformer (DiT-B-style) trained on CelebA with CPU-offloaded linear layers and ZeRO-1/ZeRO-2-style optimizer and gradient sharding.
import os
import wandb
from tqdm import tqdm
import argparse
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.checkpoint import checkpoint
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.utils.data.distributed import DistributedSampler
from torch.profiler import profile, record_function, ProfilerActivity
import numpy as np
from torchvision.utils import save_image, make_grid
from torchvision import datasets, transforms
from ramtorch import AdamW, Linear
from ramtorch.helpers import replace_linear_with_ramtorch
from ramtorch.zero1 import create_zero_param_groups, broadcast_zero_params
from ramtorch.zero2 import setup_grad_sharding_hooks
from einops import rearrange
import math

def rotate_half(x):
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, offset=0):
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

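# Note (added commentary, not part of the original logic): apply_rotary_pos_emb is the
# standard rotary position embedding. For each channel pair (a, b) split across the
# two halves of the last dimension, it applies the rotation
#   a' = a * cos - b * sin
#   b' = b * cos + a * sin
# which is exactly what q * cos + rotate_half(q) * sin computes, since cos and sin are
# duplicated across both halves (see RotaryPositionalEmbedding below). The `offset`
# argument is accepted but unused in this example.
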
def soft_clamp(x, scale, alpha, shift):
    return scale * F.tanh(x * alpha) + shift


class RotaryPositionalEmbedding(nn.Module):
    def __init__(self, dim, max_seq_len=2048):
        super().__init__()
        inv_freq = 1.0 / (max_seq_len ** (torch.arange(0, dim, 2).float() / dim))
        t = torch.arange(max_seq_len).type_as(inv_freq)
        freqs = torch.einsum("i,j->ij", t, inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer(
            "cos_cached", emb.cos()[None, None, :, :], persistent=False
        )
        self.register_buffer(
            "sin_cached", emb.sin()[None, None, :, :], persistent=False
        )

    def forward(self, x, seq_dim=1):
        return (
            self.cos_cached[:, :, : x.shape[seq_dim], :],
            self.sin_cached[:, :, : x.shape[seq_dim], :],
        )

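# Note (added commentary): cos_cached / sin_cached have shape
# [1, 1, max_seq_len, head_dim], so the sliced cos/sin broadcast cleanly against
# q and k of shape [B, num_heads, seq_len, head_dim]. forward() receives the
# pre-split hidden states ([B, seq_len, dim]) with seq_dim=1, so the slice length
# is the token count, which must not exceed max_seq_len (rope_seq_length in the
# config). One quirk: the frequency base here is max_seq_len rather than the
# conventional 10000, which appears to be a deliberate choice in this gist.
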
class SoftClamp(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(1) * 0.5)
        self.scale = nn.Parameter(torch.ones(dim))
        self.shift = nn.Parameter(torch.zeros(dim))
        self.use_compiled = False

    def forward(self, x):
        if self.use_compiled:
            return torch.compile(soft_clamp)(x, self.scale, self.alpha, self.shift)
        else:
            return soft_clamp(x, self.scale, self.alpha, self.shift)

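# Note (added commentary): SoftClamp stands in for LayerNorm-style normalization in
# this model. soft_clamp(x) = scale * tanh(alpha * x) + shift bounds each channel to
# roughly (shift - scale, shift + scale). At initialization (alpha=0.5, scale=1,
# shift=0) it is approximately linear with slope 0.5 for small activations and a
# smooth clamp for large ones; all three parameters are learned per layer.
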
class AttentionBlock(nn.Module):
    def __init__(self, dim, num_heads=8, max_seq_len=2048, use_rope=True):
        super(AttentionBlock, self).__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=False)
        self.wo = nn.Linear(dim, dim, bias=True)
        self.layer_norm = SoftClamp(dim)
        self.rope = RotaryPositionalEmbedding(self.head_dim, max_seq_len)
        self.q_norm = SoftClamp(dim)
        self.k_norm = SoftClamp(dim)
        self.add_module("layer_norm", self.layer_norm)

        nn.init.zeros_(self.wo.weight)
        self.use_compiled = False
        self.use_rope = use_rope

    @property
    def device(self):
        return next(self.parameters()).device

    def forward(self, x, attention_mask=None):
        residual = x
        x = self.layer_norm(x)

        qkv = self.qkv(x)
        q, k, v = qkv.chunk(3, dim=-1)
        q = self.q_norm(q)
        k = self.k_norm(k)

        q = q.view(q.shape[0], q.shape[1], self.num_heads, self.head_dim).transpose(
            1, 2
        )
        k = k.view(k.shape[0], k.shape[1], self.num_heads, self.head_dim).transpose(
            1, 2
        )
        v = v.view(v.shape[0], v.shape[1], self.num_heads, self.head_dim).transpose(
            1, 2
        )

        cos, sin = self.rope(x, seq_dim=1)
        if self.use_rope:
            if self.use_compiled:
                q, k = torch.compile(apply_rotary_pos_emb)(q, k, cos, sin)
            else:
                q, k = apply_rotary_pos_emb(q, k, cos, sin)

        out = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, attn_mask=attention_mask
        )

        out = out.transpose(1, 2).contiguous().view(x.shape[0], x.shape[1], -1)
        out = self.wo(out)
        return out + residual

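# Note (added commentary): a quick shape walk-through of AttentionBlock.forward,
# assuming x is [B, S, D] with D = num_heads * head_dim:
#   qkv(x)             -> [B, S, 3 * D], chunked into q, k, v of [B, S, D] each
#   view + transpose   -> [B, num_heads, S, head_dim]
#   SDPA               -> [B, num_heads, S, head_dim]
#   transpose + view   -> [B, S, D], then wo(...) + residual -> [B, S, D]
# Since wo.weight is zero-initialized, each block starts out close to a pure
# residual pass-through, which helps keep the 12-layer stack stable early on.
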
class GLU(nn.Module):
    def __init__(self, dim, exp_fac=4):
        super(GLU, self).__init__()
        self.wi_0 = nn.Linear(dim, dim * exp_fac, bias=False)
        self.wi_1 = nn.Linear(dim, dim * exp_fac, bias=False)
        self.wo = nn.Linear(dim * exp_fac, dim, bias=True)
        self.layer_norm = nn.LayerNorm(dim, elementwise_affine=False)
        nn.init.zeros_(self.wo.weight)
        self.use_compiled = False

    @property
    def device(self):
        return next(self.parameters()).device

    def _fwd_glu(self, x, residual):
        return self.wo(F.silu(self.wi_0(x)) * self.wi_1(x)) + residual

    def forward(self, x):
        residual = x
        x = self.layer_norm(x)
        if self.use_compiled:
            return torch.compile(self._fwd_glu)(x, residual)
        else:
            return self._fwd_glu(x, residual)


class TransformerNetwork(nn.Module):
    def __init__(
        self,
        input_dim,
        output_dim,
        dim,
        num_layers,
        num_heads=8,
        exp_fac=4,
        rope_seq_length=2048,
        use_rope=True,
        final_head=True,
        input_proj=True,
    ):
        super(TransformerNetwork, self).__init__()
        if input_proj:
            self.input_layer = nn.Linear(input_dim, dim)
        else:
            self.input_layer = nn.Identity()
            input_dim = dim

        self.blocks = nn.ModuleList(
            [
                nn.ModuleDict(
                    {
                        "attn": AttentionBlock(
                            dim, num_heads, rope_seq_length, use_rope
                        ),
                        "glu": GLU(dim, exp_fac),
                    }
                )
                for _ in range(num_layers)
            ]
        )

        self.out_norm = SoftClamp(dim)
        if final_head:
            self.output_layer = nn.Linear(dim, output_dim)
        else:
            self.output_layer = nn.Identity()

    def set_use_compiled(self):
        for name, module in self.named_modules():
            if hasattr(module, "use_compiled"):
                print(f"Setting 'use_compiled' to True in module: {name}")
                setattr(module, "use_compiled", True)

    def run_attn(self, block, x, mask):
        return block["attn"](x, mask)

    def forward(self, x, attention_mask=None, act_ckpt=True):
        if act_ckpt:
            x = checkpoint(self.input_layer, x)
            for block in self.blocks:
                if type(block) == nn.Identity:
                    continue
                x = checkpoint(self.run_attn, block, x, attention_mask)
                x = checkpoint(block["glu"], x)
            x = checkpoint(self.out_norm, x)
            x = checkpoint(self.output_layer, x)
        else:
            x = self.input_layer(x)
            for block in self.blocks:
                if type(block) == nn.Identity:
                    continue
                x = block["attn"](x, attention_mask)
                x = block["glu"](x)
            x = self.out_norm(x)
            x = self.output_layer(x)
        return x


def image_flatten(latents, shuffle_size=2):
    """Convert image from NCHW to flattened sequence with pixel shuffling."""
    return (
        rearrange(
            latents,
            "n c (h dh) (w dw) -> n (h w) (c dh dw)",
            dh=shuffle_size,
            dw=shuffle_size,
        ),
        latents.shape,
    )


def image_unflatten(latents, shape, shuffle_size=2):
    """Reverse of image_flatten: reconstruct NCHW from flattened sequence."""
    n, c, h, w = shape
    return rearrange(
        latents,
        "n (h w) (c dh dw) -> n c (h dh) (w dw)",
        dh=shuffle_size,
        dw=shuffle_size,
        c=c,
        h=h // shuffle_size,
        w=w // shuffle_size,
    )

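# Note (added commentary): a worked example of the flatten/unflatten round trip with
# the settings used below (32x32 CelebA crops, shuffle_size=2):
#   input           [N, 3, 32, 32]
#   image_flatten   -> tokens of shape [N, 16 * 16, 3 * 2 * 2] = [N, 256, 12],
#                      returned together with the original shape
#   image_unflatten -> [N, 3, 32, 32] again
# This is also why model_config sets input_dim = output_dim = 3 * 4 = 12.
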
def sample_from_distribution(x, probabilities, n):
    indices = torch.multinomial(probabilities, n, replacement=True)
    return x[indices]


def create_distribution(num_points, device=None):
    """Create custom probability distribution for timestep sampling."""
    x = torch.linspace(0, 1, num_points, device=device)
    probabilities = -7.7 * ((x - 0.5) ** 2) + 2
    probabilities /= probabilities.sum()
    return x, probabilities

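# Note (added commentary): the parabola -7.7 * (x - 0.5)^2 + 2 puts most of the
# sampling mass on mid-trajectory timesteps: the unnormalized weight is 2.0 at
# t = 0.5 and -7.7 * 0.25 + 2 = 0.075 at t = 0 or t = 1, so timesteps near 0.5 are
# drawn roughly 27x more often than the endpoints. This mirrors the common practice
# of concentrating flow-matching training on intermediate noise levels.
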
def repeat_along_dim(tensor, repeats, dim):
    """Repeat tensor elements along a specified dimension."""
    permute_order = list(range(tensor.dim()))
    permute_order[dim], permute_order[0] = permute_order[0], permute_order[dim]
    tensor = tensor.permute(permute_order)
    tensor = tensor.unsqueeze(1)
    repeated_tensor = tensor.repeat(1, repeats, *([1] * (tensor.dim() - 2)))
    repeated_tensor = repeated_tensor.view(-1, *repeated_tensor.shape[2:])
    permute_order[dim], permute_order[0] = permute_order[0], permute_order[dim]
    repeated_tensor = repeated_tensor.permute(permute_order)
    return repeated_tensor

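# Note (added commentary): repeat_along_dim repeats whole slices along `dim`; for
# example, a tensor of shape (2, 3) with repeats=2, dim=0 becomes (4, 3) with rows
# ordered [r0, r0, r1, r1]. The helper is defined here but not used elsewhere in
# this script.
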
class Flow(nn.Module):
    """
    A flow-matching model based on a Transformer architecture.

    This model learns the vector field that transports a noise distribution (z)
    to a data distribution (x1) over a normalized time range [0, 1].
    It uses classifier-free guidance for conditional generation.
    """

    def __init__(
        self,
        input_dim,
        output_dim,
        dim,
        num_layers,
        num_heads=8,
        exp_fac=4,
        rope_seq_length=784,
        class_count=10,
        cond_seq_len=40,
    ):
        super().__init__()
        self.dim = dim
        self.class_count = class_count
        self.input_layer = nn.Linear(input_dim, dim)
        self.timestep_vector = nn.Linear(1, dim)
        self.class_embed = nn.Linear(cond_seq_len, dim)
        self.class_norm = SoftClamp(dim=dim)
        self.cond_seq_len = cond_seq_len

        self.transformer = TransformerNetwork(
            input_dim=dim,
            output_dim=output_dim,
            dim=dim,
            num_layers=num_layers,
            num_heads=num_heads,
            exp_fac=exp_fac,
            rope_seq_length=rope_seq_length,
            final_head=True,
            input_proj=False,
        )

        self._init_weights()

    def _init_weights(self):
        """Initialize weights for stability."""
        nn.init.zeros_(self.timestep_vector.weight)
        nn.init.zeros_(self.timestep_vector.bias)
        nn.init.zeros_(self.class_embed.weight)
        nn.init.zeros_(self.class_embed.bias)

    @property
    def device(self):
        return torch.cuda.current_device()

    def forward(self, x, t, condition, attention_mask=None):
        """
        Forward pass for the flow model.

        Args:
            x (torch.Tensor): Input tensor of shape [B, SeqLen, Dim].
            t (torch.Tensor): Time step tensor of shape [B, 1].
            condition (torch.Tensor): Condition tensor of shape [B, cond_seq_len].
            attention_mask (torch.Tensor, optional): Mask for attention.

        Returns:
            torch.Tensor: The predicted velocity vector.
        """
        x_proj = self.input_layer(x)
        time_vec = self.timestep_vector(t.view(-1, 1)).unsqueeze(1)
        class_vec = self.class_embed(condition.to(self.class_embed.weight.dtype))[
            :, None, :
        ]
        class_vec = self.class_norm(class_vec)

        # Concatenate conditioning tokens with sequence tokens
        tokens = torch.cat((time_vec, class_vec, x_proj), dim=1)
        output_tokens = self.transformer(tokens, attention_mask)

        # Extract only the sequence tokens (skip time and class tokens)
        velocity_pred = output_tokens[:, 2:, ...]
        return velocity_pred

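    # Note (added commentary): with the CelebA setup below, forward() sees
    # 1 time token + 1 attribute token + 256 image tokens = 258 tokens per sample,
    # well under the configured rope_seq_length. Slicing output_tokens[:, 2:] drops
    # the two conditioning tokens so the prediction lines up one-to-one with the
    # flattened image tokens.
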
    def compute_loss(self, x1, x0, condition, class_dropout_ratio=0.1):
        """Calculate the rectified flow loss for a batch."""
        B = x1.shape[0]

        # Apply classifier-free guidance dropout
        cond_clone = condition.clone()
        is_dropped = torch.rand(B, device=self.device) < class_dropout_ratio
        cond_clone[is_dropped] = 0

        # Sample timesteps from custom distribution
        num_points = 1000
        x, probabilities = create_distribution(num_points, device=self.device)
        t = sample_from_distribution(x, probabilities, B)[:, None, None]
        t = t.to(x0.dtype)

        # Interpolate between noise and data
        xt = x0 * (1 - t) + x1 * t
        target_velocity = x1 - x0

        predicted_velocity = self.forward(xt, t, cond_clone)
        loss = F.mse_loss(predicted_velocity, target_velocity)
        return loss

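    # Note (added commentary): compute_loss above is the standard rectified-flow /
    # flow-matching objective. With x0 ~ N(0, I) and x1 the data, the straight-line
    # path is x_t = (1 - t) * x0 + t * x1, whose time derivative is the constant
    # velocity x1 - x0. The network is trained with MSE to predict that velocity
    # from (x_t, t, condition); zeroing the condition for ~10% of samples provides
    # the unconditional branch needed for classifier-free guidance at sampling time.
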
    def euler_cfg(
        self,
        x,
        pos_cond,
        cfg_scale=4.0,
        num_steps=100,
        skip_last_n=0,
        return_intermediates=False,
    ):
        """Euler method sampler with classifier-free guidance."""
        if return_intermediates:
            trajectories = [x.cpu()]
        else:
            trajectories = None

        neg_cond = torch.zeros_like(pos_cond)
        dt = 1.0 / num_steps
        effective_steps = num_steps - skip_last_n

        for i in tqdm(range(effective_steps), desc="Euler CFG Sampling"):
            with torch.no_grad():
                t_val = i * dt
                t = torch.ones(x.shape[0], 1).to(self.device, x.dtype) * t_val

                v_pos = self.forward(x, t, pos_cond)
                v_neg = self.forward(x, t, neg_cond)
                velocity = v_neg + cfg_scale * (v_pos - v_neg)

                x = x + velocity * dt

                if return_intermediates:
                    trajectories.append(x.cpu())

        return x, trajectories

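# Note (added commentary): euler_cfg integrates the learned ODE from t = 0 (noise)
# toward t = 1 (data) with a fixed step dt = 1 / num_steps, combining the two
# predictions with the usual classifier-free guidance rule
#   v = v_neg + cfg_scale * (v_pos - v_neg)
# so cfg_scale = 1 reduces to the conditional prediction and cfg_scale = 0 to the
# unconditional one. With the config below ([1, 30] and [3, 30]), previews are
# generated with 30 Euler steps of size dt ≈ 0.033 at guidance scales 1 and 3.
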
def setup(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)


def cleanup():
    dist.destroy_process_group()


def main(rank, world_size, training_config, model):
    latch = True
    setup(rank, world_size)
    torch.manual_seed(0)

    transform = transforms.Compose(
        [
            transforms.Resize(32),
            transforms.CenterCrop(32),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )
    dataset = datasets.CelebA(
        root="celeba/", split="train", transform=transform, download=True
    )
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    loader = DataLoader(
        dataset,
        batch_size=training_config["batch_size"],
        shuffle=False,
        sampler=sampler,
        prefetch_factor=3,
        num_workers=1,
    )

    DEVICE = rank

    # Replace standard Linear layers with ramtorch offloaded layers.
    # ramtorch stores weights on CPU (shared across workers) and streams them to the GPU on demand.
    model = replace_linear_with_ramtorch(model, rank)
    model.to(rank)

    # Shard optimizer states across workers (ZeRO-1 style).
    # Each worker only maintains optimizer states for its subset of parameters.
    all_params = list(model.parameters())
    param_groups = [{"params": all_params, "lr": training_config["lr"]}]
    rank_param_groups = create_zero_param_groups(param_groups, world_size)

    # Set up gradient sharding hooks (ZeRO-2 style).
    # Gradients are partitioned and stored (as a CPU buffer) only on the worker responsible for updating them.
    setup_grad_sharding_hooks(rank_param_groups, rank)

    optim = AdamW(rank_param_groups[rank])
    lr_scheduler = torch.optim.lr_scheduler.LinearLR(
        optim, start_factor=1e-5, end_factor=1.0, total_iters=100
    )

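    # Note (added commentary): the division of labor in the calls above, as described
    # by this example: replace_linear_with_ramtorch swaps nn.Linear layers for
    # CPU-resident ones streamed to the GPU on demand; create_zero_param_groups splits
    # the parameter groups so each rank owns one shard (ZeRO-1), which is why the
    # optimizer is built from rank_param_groups[rank] only; setup_grad_sharding_hooks
    # routes each parameter's gradient to its owning rank (ZeRO-2); and
    # broadcast_zero_params in the training loop re-syncs any parameters that are not
    # already shared through CPU memory.
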
    if rank == 0 and training_config["wandb_project"]:
        wandb.login(key=training_config["wandb_key"], host="https://api.wandb.ai")
        wandb.init(
            project=training_config["wandb_project"],
            name=training_config["preview_path"],
        )

    with profile(
        activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
        record_shapes=True,
        profile_memory=True,
        with_stack=True,
        with_flops=True,
    ) as prof:
        for epoch in range(training_config["num_epochs"]):
            torch.manual_seed(epoch)
            sampler.set_epoch(epoch)
            progress_bar = tqdm(
                total=len(loader), desc="Processing", smoothing=0.1, disable=rank != 0
            )
            for batch_idx, (real, label) in enumerate(loader):
                real = real.to(DEVICE)
                label = label.to(DEVICE)

                x1, image_shape = image_flatten(real)
                x1 = x1.requires_grad_(True)
                x0 = torch.randn_like(x1)

                with torch.autocast("cuda", torch.bfloat16):
                    loss = (
                        model.compute_loss(
                            x1=x1,
                            x0=x0,
                            condition=label,
                            class_dropout_ratio=training_config["class_dropout_ratio"],
                        )
                        / world_size
                    )

                # Synchronize before backward to ensure all workers are ready
                torch.cuda.synchronize()
                loss.backward()
                torch.cuda.synchronize()

                optim.step()
                # Broadcast non-ramtorch parameters across workers
                # ramtorch parameters are already shared via CPU memory, but standard
                # PyTorch parameters need explicit broadcasting to stay synchronized
                broadcast_zero_params(rank_param_groups)
                lr_scheduler.step()
                # Use model.zero_grad() instead of optimizer.zero_grad()
                # Each worker only handles partial gradients, so we need to zero
                # gradients at the model level to properly clear all workers' buffers
                model.zero_grad()
                torch.cuda.synchronize()

                if rank == 0:
                    progress_bar.set_description(
                        f"Epoch [{epoch}/{training_config['num_epochs']}] Step [{batch_idx}/{len(loader)}] Loss: {loss * world_size:.4f}"
                    )
                    if training_config["wandb_project"]:
                        wandb.log({"Loss": loss.detach() * world_size, "Epoch": epoch})

                if rank == 0 and batch_idx % training_config["eval_interval"] == 0:
                    with torch.no_grad():
                        z = torch.randn_like(x1)
                        with torch.autocast("cuda", torch.bfloat16):
                            fake_images_list = []
                            for cfg, steps in training_config[
                                "inference_cfg_and_steps"
                            ]:
                                fake_images_cfg, _ = model.euler_cfg(
                                    z, label, cfg, num_steps=steps
                                )
                                fake_images_list.append(fake_images_cfg)

                        real_unflattened = image_unflatten(x1, image_shape)
                        fake_images_unflattened = [
                            image_unflatten(img, image_shape)
                            for img in fake_images_list
                        ]
                        all_images = torch.cat(
                            fake_images_unflattened + [real_unflattened], dim=0
                        )

                        os.makedirs(training_config["preview_path"], exist_ok=True)
                        img_path = f"{training_config['preview_path']}/epoch_{epoch}_{batch_idx}.jpg"
                        save_image(
                            make_grid(
                                (all_images.clip(-1, 1) + 1) / 2,
                                nrow=training_config["batch_size"],
                            ),
                            img_path,
                        )
                        if training_config["wandb_project"]:
                            wandb.log({"example_image": wandb.Image(img_path)})

                if rank == 0:
                    progress_bar.update(1)

                profile_steps = 15
                if batch_idx == profile_steps and latch:
                    prof.stop()
                    prof.export_chrome_trace(f"trace_{rank}.json")
                    print(f"Stopped profiling and saved trace at step {batch_idx}")
                    latch = False

            if rank == 0:
                os.makedirs(training_config["ckpt_path"], exist_ok=True)
                torch.save(
                    model.state_dict(), f"{training_config['ckpt_path']}/{epoch}.pth"
                )

    cleanup()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    world_size = torch.cuda.device_count()

    training_config = {
        "batch_size": 64 * 4,
        "lr": 1e-4 * (4 ** (1 / 2)),
        "num_epochs": 1000,
        "eval_interval": 100,
        "preview_path": "celeba_flowmatching_DiT-B/2_ramtorch_full_integration_test",
        "wandb_project": "flowmatch_celeba_2_ramtorch_full_integration_test",
        "wandb_key": None,
        "ckpt_path": "celeba_flowmatching_DiT-B/2_ramtorch_full_integration_test",
        "class_dropout_ratio": 0.1,
        "model_config": {
            "input_dim": 3 * 4,
            "output_dim": 3 * 4,
            "dim": 768,
            "num_layers": 12,
            "num_heads": 12,
            "exp_fac": 4,
            "rope_seq_length": 64**2 + 30,
            "class_count": 40,
        },
        "inference_cfg_and_steps": [
            [1, 30],
            [3, 30],
        ],
        "model_checkpoint": None,
    }

    # Model must be instantiated before spawning GPU workers
    # ramtorch shares CPU tensors across workers, so the model needs to exist
    # in the parent process before forking to enable proper memory sharing
    model = Flow(**training_config["model_config"])
    if training_config["model_checkpoint"]:
        model.load_state_dict(torch.load(training_config["model_checkpoint"]))

    mp.spawn(main, args=(world_size, training_config, model), nprocs=world_size)
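# Note (added commentary): mp.spawn passes the process index as the first positional
# argument, so each worker runs main(rank, world_size, training_config, model) with
# rank in [0, world_size). Running the script directly (e.g. `python train.py`,
# filename assumed) spawns one process per visible CUDA device; the CelebA dataset
# is downloaded to ./celeba/ on first use.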