@chenyaofo
Created June 16, 2023 12:00
Deep Learning Engine
import sys
import pathlib
import pprint
import typing
import dataclasses
from dataclasses import dataclass

import loguru
import pyhocon
from pyhocon import ConfigFactory, ConfigTree

import torch.nn as nn
import torch.utils.collect_env as collect_env
from torch.utils.tensorboard import SummaryWriter
import torchmetrics

import deepspeed
from deepspeed import comm as dist
from deepspeed.runtime.engine import DeepSpeedEngine

import torch4x
from torch4x.typed_args import TypedArgs, add_argument
from torch4x.register import REGISTRY

def apply_modifications(modifications: typing.Sequence[str], conf: ConfigTree) -> ConfigTree:
    """Apply command-line overrides of the form ``key=value`` to an existing HOCON config tree."""
    special_cases = {
        "true": True,
        "false": False,
    }
    if modifications is not None:
        for modification in modifications:
            key, value = modification.split("=")
            if value in special_cases:
                eval_value = special_cases[value]
            else:
                try:
                    # Interpret numbers, lists, etc.; fall back to the raw string.
                    eval_value = eval(value)
                except Exception:
                    eval_value = value
            if key not in conf:
                raise ValueError(f"Key '{key}' is not in the config tree!")
            conf.put(key, eval_value)
    return conf
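
# Illustrative usage of apply_modifications; the "train.lr" / "train.amp" keys are hypothetical:
#   conf = ConfigFactory.parse_string("train { lr = 0.05, amp = false }")
#   apply_modifications(["train.lr=0.1", "train.amp=true"], conf)
#   conf.get("train.lr")   # -> 0.1  (parsed via eval)
#   conf.get("train.amp")  # -> True (handled by the special-case table)
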
@dataclass
class Args(TypedArgs):
    outdir: str = add_argument("-o", default="")
    conf: str = add_argument("-c", default="")
    modifications: typing.Sequence[str] = add_argument("-M", nargs="+", help="list of key=value config overrides")


def get_args(argv=sys.argv):
    args, _ = Args.from_known_args(argv)
    args.outdir = pathlib.Path(args.outdir)
    args.conf: ConfigTree = ConfigFactory.parse_file(args.conf)
    apply_modifications(modifications=args.modifications, conf=args.conf)
    return args
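
# Typical invocation (script name and paths are hypothetical):
#   python train.py -o outputs/exp1 -c configs/experiment.conf -M train.max_epochs=90 data.batch_size=256
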
class BaseRunner:
    def __init__(self, args=get_args()) -> None:
        self.args: Args = args
        self.conf: pyhocon.ConfigTree = args.conf
        self.outdir: pathlib.Path = args.outdir
        self.default_logname = "default.log"

        self._init_for_training()

    def _init_for_training(self):
        self._init_logging()
        torch4x.set_active_device()
        self._init_deepspeed_distributed()
        self._create_code_snapshot()
        self._logging_system_info()

    def _init_logging(self):
        # Only rank 0 writes the log file; the other ranks stay silent.
        if torch4x.is_rank_0:
            loguru.logger.add(self.outdir / self.default_logname)
        else:
            loguru.logger.remove()

    def _init_deepspeed_distributed(self):
        deepspeed.init_distributed()

    def _create_code_snapshot(
        self,
        name="code",
        include_suffix=[".py", ".conf"],
        source_directory=".",
    ):
        torch4x.create_code_snapshot(
            name=name,
            include_suffix=include_suffix,
            source_directory=source_directory,
            store_directory=self.outdir,
        )

    def _logging_system_info(self):
        loguru.logger.info("Collect envs from system:\n" + collect_env.get_pretty_env_info())
        loguru.logger.info("Args:\n" + pprint.pformat(dataclasses.asdict(self.args)))


class DefaultRunner(BaseRunner):
    def _build_model(self):
        self.model = REGISTRY.build_from(self.conf.get("model"), dict(ds_config=self.conf.get("ds_config")))

    def _build_dataloader(self):
        self.train_dataloader, self.eval_dataloader = REGISTRY.build_from(self.conf.get("data"))

    def _build_optimizer(self):
        self.model: nn.Module
        self.optimizer = REGISTRY.build_from(
            self.conf.get("optimizer"),
            dict(params=[p for p in self.model.parameters() if p.requires_grad]),
        )

    def _build_lr_scheduler(self):
        self.lr_scheduler = REGISTRY.build_from(
            self.conf.get("lr_scheduler"),
            dict(optimizer=self.optimizer),
        )

    def _build_tbwriter(self):
        self.tbwriter = torch4x.only_rank_0(SummaryWriter(self.outdir))

    def prepare_for_training(self):
        self._build_model()
        self._build_dataloader()
        self._build_optimizer()
        self._build_lr_scheduler()
        self._build_tbwriter()

        # DeepSpeed wraps the model, optimizer and LR scheduler into a single engine.
        self.model_engine, self.optimizer, _, self.lr_scheduler = deepspeed.initialize(
            model=self.model,
            optimizer=self.optimizer,
            config=self.conf.get("ds_config"),
            lr_scheduler=self.lr_scheduler,
        )
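
        # A minimal "ds_config" section for the call above might look like the following
        # HOCON sketch (illustrative values, not taken from the original gist):
        #   ds_config {
        #     train_batch_size = 256
        #     fp16 { enabled = true }
        #     zero_optimization { stage = 1 }
        #   }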

    def throughput(self):
        self.model_engine: DeepSpeedEngine
        single_node_throughput = self.model_engine.tput_timer.avg_samples_per_sec()
        return single_node_throughput * torch4x.world_size()

    def register_metrics(self):
        self.metrics = []
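        # For example (illustrative only; torchmetrics is imported above but unused in the original):
        #   self.metrics = [torchmetrics.Accuracy(task="multiclass", num_classes=1000)]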


if __name__ == "__main__":
    runner = BaseRunner("output")
    loguru.logger.info("Test")
    with loguru.logger.catch():
        a = 1 / 0