Deep Learning Engine
import sys
import typing
import pathlib
import pprint
import dataclasses
from dataclasses import dataclass

import loguru
import pyhocon
from pyhocon import ConfigFactory, ConfigTree

import deepspeed
from deepspeed import comm as dist
from deepspeed.runtime.engine import DeepSpeedEngine

import torch.nn as nn
import torch.utils.collect_env as collect_env
from torch.utils.tensorboard import SummaryWriter
import torchmetrics

import torch4x
from torch4x.typed_args import TypedArgs, add_argument
from torch4x.register import REGISTRY
def apply_modifications(modifications: typing.Sequence[str], conf: ConfigTree):
    # Map the literals "true"/"false" to booleans; eval() would not accept them.
    special_cases = {
        "true": True,
        "false": False,
    }
    if modifications is not None:
        for modification in modifications:
            key, value = modification.split("=")
            if value in special_cases:
                eval_value = special_cases[value]
            else:
                try:
                    # Try to interpret the value as a Python literal (int, float, ...).
                    eval_value = eval(value)
                except Exception:
                    # Fall back to treating the value as a plain string.
                    eval_value = value
            if key not in conf:
                raise ValueError(f"Key '{key}' is not in the config tree!")
            conf.put(key, eval_value)
    return conf
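
# Illustration only (not part of the original gist): a self-contained check of
# the override mechanism above, assuming the keys already exist in the tree.
def example_overrides():
    demo = ConfigFactory.parse_string("optimizer { lr = 0.1, nesterov = false }")
    apply_modifications(["optimizer.lr=0.05", "optimizer.nesterov=true"], demo)
    assert demo.get("optimizer.lr") == 0.05        # eval()'d to a float
    assert demo.get("optimizer.nesterov") is True  # special-cased literal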
@dataclass
class Args(TypedArgs):
    outdir: str = add_argument("-o", default="")
    conf: str = add_argument("-c", default="")
    modifications: typing.Sequence[str] = add_argument(
        "-M", nargs="+", help="list of key=value config overrides"
    )


def get_args(argv=sys.argv):
    args, _ = Args.from_known_args(argv)
    args.outdir = pathlib.Path(args.outdir)
    args.conf: ConfigTree = ConfigFactory.parse_file(args.conf)
    apply_modifications(modifications=args.modifications, conf=args.conf)
    return args
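
# Example invocation (a sketch; the flag names follow the Args dataclass above,
# and the script/config names here are hypothetical):
#
#   python train.py -o output/exp1 -c configs/train.conf \
#       -M optimizer.lr=0.05 model.pretrained=false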
class BaseRunner:
    def __init__(self, args=get_args()) -> None:
        self.args: Args = args
        self.conf: pyhocon.ConfigTree = args.conf
        self.outdir: pathlib.Path = args.outdir
        self.default_logname = "default.log"

        self._init_for_training()

    def _init_for_training(self):
        self._init_logging()
        torch4x.set_active_device()
        self._init_deepspeed_distributed()
        self._create_code_snapshot()
        self._logging_system_info()

    def _init_logging(self):
        # Only rank 0 writes a log file; the other ranks drop their handlers.
        if torch4x.is_rank_0:
            loguru.logger.add(self.outdir / self.default_logname)
        else:
            loguru.logger.remove()

    def _init_deepspeed_distributed(self):
        deepspeed.init_distributed()

    def _create_code_snapshot(
        self,
        name="code",
        include_suffix=(".py", ".conf"),
        source_directory=".",
    ):
        # Snapshot the source tree next to the run outputs for reproducibility.
        torch4x.create_code_snapshot(
            name=name,
            include_suffix=include_suffix,
            source_directory=source_directory,
            store_directory=self.outdir,
        )

    def _logging_system_info(self):
        loguru.logger.info("Collect envs from system:\n" + collect_env.get_pretty_env_info())
        loguru.logger.info("Args:\n" + pprint.pformat(dataclasses.asdict(self.args)))
class DefaultRunner(BaseRunner):
    def _build_model(self):
        self.model = REGISTRY.build_from(
            self.conf.get("model"),
            dict(ds_config=self.conf.get("ds_config")),
        )

    def _build_dataloader(self):
        self.train_dataloader, self.eval_dataloader = REGISTRY.build_from(self.conf.get("data"))

    def _build_optimizer(self):
        self.model: nn.Module
        # Only parameters that require gradients are handed to the optimizer.
        self.optimizer = REGISTRY.build_from(
            self.conf.get("optimizer"),
            dict(params=[p for p in self.model.parameters() if p.requires_grad]),
        )

    def _build_lr_scheduler(self):
        self.lr_scheduler = REGISTRY.build_from(
            self.conf.get("lr_scheduler"),
            dict(optimizer=self.optimizer),
        )

    def _build_tbwriter(self):
        # TensorBoard writes only happen on rank 0.
        self.tbwriter = torch4x.only_rank_0(SummaryWriter(self.outdir))

    def prepare_for_training(self):
        self._build_model()
        self._build_dataloader()
        self._build_optimizer()
        self._build_lr_scheduler()
        self._build_tbwriter()
        # Wrap model, optimizer and scheduler into a DeepSpeed engine.
        self.model_engine, self.optimizer, _, self.lr_scheduler = deepspeed.initialize(
            model=self.model,
            optimizer=self.optimizer,
            config=self.conf.get("ds_config"),
            lr_scheduler=self.lr_scheduler,
        )
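
    # The sections consumed above ("model", "data", "optimizer", "lr_scheduler",
    # "ds_config") are assumed to name builders in torch4x's REGISTRY; the exact
    # schema is not shown in the gist. A HOCON file along these lines is a
    # sketch, not the author's actual config:
    #
    #   model { ... }
    #   data { ... }
    #   optimizer { ... }
    #   lr_scheduler { ... }
    #   ds_config { train_batch_size = 256, fp16 { enabled = true } }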
    def throughput(self):
        self.model_engine: DeepSpeedEngine
        # Per-node samples/sec from DeepSpeed's throughput timer, scaled to the
        # whole job by the world size.
        single_node_throughput = self.model_engine.tput_timer.avg_samples_per_sec()
        return single_node_throughput * torch4x.world_size()

    def register_metrics(self):
        self.metrics = []
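
# A minimal sketch (not in the original gist) of the loop prepare_for_training()
# enables: the DeepSpeedEngine owns backward() and step(), covering loss scaling,
# gradient accumulation and the lr_scheduler. How a batch unpacks depends on the
# registered dataloader, so the (inputs, targets) shape here is an assumption.
def example_train_epoch(runner: DefaultRunner):
    runner.model_engine.train()
    for inputs, targets in runner.train_dataloader:
        inputs = inputs.to(runner.model_engine.device)
        targets = targets.to(runner.model_engine.device)
        loss = nn.functional.cross_entropy(runner.model_engine(inputs), targets)
        runner.model_engine.backward(loss)  # replaces loss.backward()
        runner.model_engine.step()          # optimizer + scheduler step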
if __name__ == "__main__":
    # Smoke test: build a runner from the command-line args and verify that
    # loguru catches and logs the ZeroDivisionError instead of crashing.
    runner = BaseRunner()
    loguru.logger.info("Test")
    with loguru.logger.catch():
        a = 1 / 0