MP4 SHARP bug (edited to support the modern launcher and to add some status printing, so it is easier to see what's going on)
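A rough sketch of one way to launch this with the modern launcher (torchrun); the script filename, node count and GPU count below are assumptions, the only requirements in the code itself are that the launcher sets LOCAL_RANK and that the total world size is a multiple of model_parallel_size = 4:

# run on each of 2 nodes with 8 GPUs, assuming the file is saved as mp4_sharp_bug.py
torchrun --nnodes=2 --nproc_per_node=8 --rdzv_backend=c10d \
    --rdzv_endpoint=<master-host>:29500 mp4_sharp_bug.py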
import os

import torch
import torch.distributed as dist

# torchrun (the modern launcher) exports LOCAL_RANK for every worker
local_rank = int(os.environ["LOCAL_RANK"])
dist.init_process_group(backend='nccl')
torch.cuda.set_device(local_rank)
device = torch.device("cuda", local_rank)

world_size = dist.get_world_size()
model_parallel_size = 4

_DATA_PARALLEL_GROUP = None
_MODEL_PARALLEL_GROUP = None

rank = dist.get_rank()

# Build Megatron-style process groups. For example, with world_size=8 and
# model_parallel_size=4 the data parallel groups are {0,4}, {1,5}, {2,6},
# {3,7} and the model parallel groups are {0,1,2,3}, {4,5,6,7}.
for i in range(model_parallel_size):
    ranks = range(i, world_size, model_parallel_size)
    group = torch.distributed.new_group(ranks)
    if i == (rank % model_parallel_size):
        _DATA_PARALLEL_GROUP = group

for i in range(world_size // model_parallel_size):
    ranks = range(i * model_parallel_size,
                  (i + 1) * model_parallel_size)
    group = torch.distributed.new_group(ranks)
    if i == (rank // model_parallel_size):
        _MODEL_PARALLEL_GROUP = group


def get_data_parallel_group():
    """Get the data parallel group the caller rank belongs to."""
    assert _DATA_PARALLEL_GROUP is not None, \
        'data parallel group is not initialized'
    return _DATA_PARALLEL_GROUP


def get_model_parallel_group():
    """Get the model parallel group the caller rank belongs to."""
    assert _MODEL_PARALLEL_GROUP is not None, \
        'model parallel group is not initialized'
    return _MODEL_PARALLEL_GROUP


def get_model_parallel_rank():
    """Return my rank for the model parallel group."""
    return torch.distributed.get_rank(group=get_model_parallel_group())


def ag_test():
    """All-gather a large random tensor across the data parallel group."""
    src_rank = get_model_parallel_rank()
    dp_group = get_data_parallel_group()
    dp_size = dist.get_world_size(dp_group)
    # each rank contributes 268M // dp_size fp32 elements, so the gathered
    # result is roughly 1 GiB per rank
    mats = []
    for _ in range(dp_size):
        mats.append(torch.rand(1, 268 * 1024 * 1024 // dp_size, device=device))
    dist.all_gather(mats, mats[dist.get_rank(dp_group)], group=dp_group)


for i in range(100):
    if rank == 0:
        print(f"round {i}")
    ag_test()

if rank == 0:
    print("Done")