Created
July 9, 2019 02:41
-
-
Save chenyaofo/ee12f04492dac29c5c84b4b6170edefc to your computer and use it in GitHub Desktop.
dist util
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import flame | |
import torch.distributed as dist | |
def init(backend="nccl", init_method="env://"):
    """Initialize ``torch.distributed`` from launcher-provided environment variables.

    Distributed mode is only attempted when both ``RANK`` and ``WORLD_SIZE``
    are present in the environment. In that case ``LOCAL_RANK``,
    ``MASTER_ADDR`` and ``MASTER_PORT`` are read as well, and
    ``dist.init_process_group`` is called with the given ``backend`` and
    ``init_method``.

    Args:
        backend: Backend name forwarded to ``dist.init_process_group``.
        init_method: Init-method URL forwarded to ``dist.init_process_group``.

    Returns:
        A 7-tuple ``(backend, init_method, rank, local_rank, world_size,
        master_addr, master_port)``. When distributed mode is not started
        (environment variables missing, or ``torch.distributed`` is
        unavailable) the tuple is ``(None, None, 0, 0, 1, None, None)``.

    Raises:
        KeyError: If ``RANK``/``WORLD_SIZE`` are set but ``LOCAL_RANK``,
            ``MASTER_ADDR`` or ``MASTER_PORT`` is missing.
        ValueError: If ``RANK``, ``LOCAL_RANK`` or ``WORLD_SIZE`` is not an
            integer string.
    """
    # Single-process defaults shared by every "not distributed" exit path.
    not_distributed = (None, None, 0, 0, 1, None, None)

    if "RANK" not in os.environ or "WORLD_SIZE" not in os.environ:
        return not_distributed

    if not dist.is_available():
        # NOTE(review): flame.logger is presumably a logging.Logger-like object.
        flame.logger.error("Fail to init distributed because torch.distributed is unavailable.")
        return not_distributed

    proc_rank = int(os.environ["RANK"])
    proc_local_rank = int(os.environ["LOCAL_RANK"])
    num_procs = int(os.environ["WORLD_SIZE"])
    master_addr = os.environ["MASTER_ADDR"]
    master_port = os.environ["MASTER_PORT"]

    dist.init_process_group(
        backend=backend,
        init_method=init_method,
        world_size=num_procs,
        rank=proc_rank,
    )

    # The original message misspelled "init_method" and printed the master
    # address under that label; report both values explicitly instead.
    flame.logger.info(
        "Init distributed mode(backend={}, init_method={}, master={}:{}, world_size={}).".format(
            backend, init_method, master_addr, master_port, num_procs
        )
    )
    return backend, init_method, proc_rank, proc_local_rank, num_procs, master_addr, master_port
# NOTE: these accessors are plain functions on purpose. ``@property`` only
# works as a class attribute; applied to a module-level function it binds the
# name to a non-callable ``property`` object that is always truthy, which made
# every guard below take the distributed branch and ``rank == 0`` always False.


def is_dist_avail_and_init():
    """Return True when torch.distributed is available and a process group is initialized."""
    return dist.is_available() and dist.is_initialized()


def rank():
    """Return ``dist.get_rank()`` when distributed is initialized, otherwise 0."""
    return dist.get_rank() if is_dist_avail_and_init() else 0


def local_rank():
    """Return the ``LOCAL_RANK`` environment variable as an int, or 0 when not distributed.

    Raises:
        KeyError: If distributed is initialized but ``LOCAL_RANK`` is unset.
    """
    return int(os.environ["LOCAL_RANK"]) if is_dist_avail_and_init() else 0


def world_size():
    """Return ``dist.get_world_size()`` when distributed is initialized, otherwise 1."""
    return dist.get_world_size() if is_dist_avail_and_init() else 1


def is_master():
    """Return True when this process has rank 0 (always True when not distributed)."""
    return rank() == 0
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment