Skip to content

Instantly share code, notes, and snippets.

@wanchaol
Created April 15, 2024 23:32
Show Gist options
  • Save wanchaol/bcf32b87edda474ea47267861291ad37 to your computer and use it in GitHub Desktop.
Save wanchaol/bcf32b87edda474ea47267861291ad37 to your computer and use it in GitHub Desktop.
import copy
import os
from dataclasses import dataclass
from typing import Callable, Optional, Tuple

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dist_cp
import torch.nn as nn
from torch.distributed._tensor import init_device_mesh
from torch.distributed._tensor import distribute_tensor, DTensor, Shard
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType
# Total number of ranks in the job. This script is meant to be launched with
# torchrun (see the save/load commands further down), which exports WORLD_SIZE
# to every worker process.
try:
    world_size = int(os.environ["WORLD_SIZE"])
except KeyError as err:
    raise RuntimeError(
        "WORLD_SIZE is not set; launch this script with torchrun, e.g. "
        "`torchrun --standalone --nproc_per_node=48 test_state_dict.py`"
    ) from err
class TestUneven(nn.Module):
    """Toy module that owns a single randomly initialised 2-D parameter.

    The module has no real computation: ``forward`` takes no inputs and just
    returns the parameter. It exists only to give the distributed checkpoint
    save/load repro a large tensor to shard.

    Args:
        num_rows: Size of dim 0 of ``uneven_param``. The default, 172032, is
            the shape the repro was written against.
        num_cols: Size of dim 1 of ``uneven_param`` (default 6144).

    With the defaults and the default float32 dtype, the parameter holds
    ~1.06e9 elements (roughly 4 GB), so pass smaller sizes for quick local
    experiments.
    """

    def __init__(self, num_rows: int = 172032, num_cols: int = 6144) -> None:
        super().__init__()
        self.uneven_param = torch.nn.Parameter(torch.randn(num_rows, num_cols))

    def forward(self) -> torch.Tensor:
        """Return ``uneven_param`` itself (the same object, not a copy)."""
        return self.uneven_param
# Tensor-parallel degree; every remaining rank goes to the data-parallel dim.
# With the launch commands below this gives a (6, 8) mesh when saving with 48
# ranks and a (5, 8) mesh when loading with 40 ranks.
TP_SIZE = 8
if world_size < TP_SIZE or world_size % TP_SIZE != 0:
    raise ValueError(
        f"WORLD_SIZE={world_size} must be a positive multiple of {TP_SIZE} "
        "to build the (dp, tp) device mesh"
    )

# Run the script twice: once to save, then (the default) to load with a
# different world size, e.g.
#   save: CHECKPOINT_MODE=save torchrun --standalone --nproc_per_node=48 ~/local/scratch/test_state_dict.py
#   load: torchrun --standalone --nproc_per_node=40 ~/local/scratch/test_state_dict.py
# Validated up front so a typo fails before the expensive setup below.
CHECKPOINT_MODE = os.environ.get("CHECKPOINT_MODE", "load")
if CHECKPOINT_MODE not in ("save", "load"):
    raise ValueError(
        f"CHECKPOINT_MODE must be 'save' or 'load', got {CHECKPOINT_MODE!r}"
    )
CHECKPOINT_DIR = os.environ.get("CHECKPOINT_DIR", "/tmp/test_uneven_save_load")

mesh_2d = init_device_mesh(
    "cpu", (world_size // TP_SIZE, TP_SIZE), mesh_dim_names=("dp", "tp")
)
model = TestUneven()
tp_mesh = mesh_2d["tp"]
dp_mesh = mesh_2d["dp"]

# Shard the parameter along dim 0 across the TP ranks (172032 / 8 = 21504
# rows per TP rank), then wrap the model with FSDP over the DP sub-mesh.
model.uneven_param = torch.nn.Parameter(
    distribute_tensor(model.uneven_param, tp_mesh, [Shard(0)])
)
model_2d = FSDP(
    model, device_mesh=dp_mesh, use_orig_params=True, device_id=torch.device("cpu")
)
FSDP.set_state_dict_type(model_2d, StateDictType.SHARDED_STATE_DICT)
state_dict_2d = model_2d.state_dict()

if CHECKPOINT_MODE == "save":
    dist_cp.save(
        state_dict=state_dict_2d,
        storage_writer=dist_cp.FileSystemWriter(path=CHECKPOINT_DIR),
        planner=dist_cp.DefaultSavePlanner(),
    )
else:
    # Load in place into the freshly built (differently sharded) state dict.
    dist_cp.load(
        state_dict=state_dict_2d,
        storage_reader=dist_cp.FileSystemReader(CHECKPOINT_DIR),
        planner=dist_cp.DefaultLoadPlanner(),
    )

# Tear down the default process group set up for the device mesh.
if dist.is_initialized():
    dist.destroy_process_group()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment