import os
import sys
import torch._dynamo.compiled_autograd

with open(sys.argv[0]) as f:
    code = f.read()  # read the code of this file ASAP, for logging

import uuid
import glob
import time
from dataclasses import dataclass
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import torch.distributed as dist
import torch._inductor.config as config
from torch.nn.parallel import DistributedDataParallel as DDP
from collections import deque
from typing import cast

# -----------------------------------------------------------------------------
# Muon optimizer

def zeropower_via_svd(G, steps=None):
    U, S, V = G.svd()
    return U @ V.T

def zeropower_via_newtonschulz5(G, steps=10, eps=1e-7):
    """
    Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
    quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
    of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
    zero even beyond the point where the iteration no longer converges all the way to one everywhere
    on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
    where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
    performance at all relative to UV^T, where USV^T = G is the SVD.
    """
    assert len(G.shape) == 2
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.bfloat16()
    X /= X.norm() + eps  # ensure top singular value <= 1
    if G.size(0) > G.size(1):
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        B = A @ X
        X = a * X + b * B + c * A @ B
    if G.size(0) > G.size(1):
        X = X.T
    return X
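# A minimal sanity check of the iteration above (illustrative only, not executed as
# part of training): for a random matrix G, five Newton-Schulz steps should push the
# singular values of the result into roughly the Uniform(0.5, 1.5) band described in
# the docstring, i.e. the output behaves like an approximately orthogonalized G.
#   G = torch.randn(512, 768)
#   X = zeropower_via_newtonschulz5(G, steps=5)
#   print(torch.linalg.svdvals(X.float()))  # values clustered around ~1, not exactly 1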
zeropower_backends = dict(
    svd=zeropower_via_svd, newtonschulz5=zeropower_via_newtonschulz5
)

config.triton.cudagraph_support_input_mutation = True

@torch.compile(mode="max-autotune-no-cudagraphs", fullgraph=True)
def compute_updates(
    params_shard: list[torch.Tensor],
    bufs_shard: list[torch.Tensor],
    slice_list_shard: list[slice],
    scale_list_shard: list[float],
    params: list[torch.Tensor],
    slice_list: list[slice],
    momentum: float,
):
    n = sum(p.numel() for p in params)
    updates_full_flat = torch.zeros(n, dtype=torch.bfloat16, device="cuda")
    # assert all(p.grad is not None for p in self.params_shard)
    grads_shard = [cast(torch.Tensor, p.grad) for p in params_shard]
    torch._foreach_mul_(bufs_shard, momentum)
    torch._foreach_add_(bufs_shard, grads_shard)
    # avoid mutating inputs from eager
    vs = torch._foreach_add(grads_shard, bufs_shard, alpha=momentum)
    update_views_shard = [updates_full_flat[s] for s in slice_list_shard]
    for u, s, v in zip(update_views_shard, scale_list_shard, vs):
        torch.mul(zeropower_via_newtonschulz5(v, steps=5).flatten(), s, out=u)
    # sync updates across devices. we are not memory-constrained so can do this simple deserialization
    dist.all_reduce(updates_full_flat, op=dist.ReduceOp.SUM)
    return [updates_full_flat[s].view_as(p) for s, p in zip(slice_list, params)]
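# How the sharded update above works: every rank owns a disjoint subset of the
# parameters, runs momentum + Newton-Schulz only for that subset, and writes the
# scaled results into its own slices of a zero-initialized flat buffer covering
# *all* parameters. Because the non-owned slices stay zero, a single all-reduce SUM
# reconstructs the full flat update on every rank, which is then re-viewed per parameter.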
class Muon(torch.optim.Optimizer):
    def __init__(
        self,
        params,
        lr: float | torch.Tensor = 0.02,
        momentum=0.95,
        nesterov=True,
        backend="newtonschulz5",
        backend_steps=5,
    ):
        assert nesterov
        defaults = dict(
            lr=lr,
            momentum=momentum,
            nesterov=nesterov,
            backend=backend,
            backend_steps=backend_steps,
        )
        super().__init__(params, defaults)
        assert len(self.param_groups) == 1
        group = self.param_groups[0]
        self.params_shard: list[torch.Tensor] = []
        self.momentum_buffer_list_shard: list[torch.Tensor] = []
        self.slice_list: list[slice] = []
        self.slice_list_shard: list[slice] = []
        self.scale_list_shard: list[float] = []
        offset = 0
        for i, p in enumerate(group["params"]):
            assert isinstance(p, torch.Tensor)
            _slice = slice(offset, offset + p.numel())
            self.slice_list.append(_slice)
            if i % int(os.environ["WORLD_SIZE"]) == int(os.environ["RANK"]):
                self.params_shard.append(p)
                buf = torch.zeros_like(p)
                torch._dynamo.mark_static_address(buf)
                self.momentum_buffer_list_shard.append(buf)
                self.state[p]["momentum_buffer"] = buf
                self.slice_list_shard.append(_slice)
                self.scale_list_shard.append(max(1, p.size(0) / p.size(1)) ** 0.5)
            offset += p.numel()

    # Tensor LR is slower than excluding lr from the compiled function
    @torch.no_grad()
    def step(self):
        group = self.param_groups[0]
        update_views = compute_updates(
            self.params_shard,
            self.momentum_buffer_list_shard,
            self.slice_list_shard,
            self.scale_list_shard,
            group["params"],
            self.slice_list,
            group["momentum"],
        )
        # apply updates
        torch._foreach_add_(group["params"], update_views, alpha=-group["lr"])
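# Usage sketch for Muon (illustrative only; the actual instantiation happens in the
# optimizer section further below): parameters are assigned to ranks round-robin by
# index, so each rank stores momentum buffers for only about 1/WORLD_SIZE of them.
#   opt = Muon(model.transformer.h.parameters(), lr=0.02, momentum=0.95)
#   loss.backward(); opt.step(); model.zero_grad(set_to_none=True)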
# -----------------------------------------------------------------------------
# PyTorch nn.Module definitions for the GPT-2 model

class Rotary(torch.nn.Module):
    def __init__(self, dim, base=10000):
        super().__init__()
        self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.seq_len_cached = None
        self.cos_cached = None
        self.sin_cached = None

    def forward(self, x):
        seq_len = x.shape[1]
        if seq_len != self.seq_len_cached:
            self.seq_len_cached = seq_len
            t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
            freqs = torch.outer(t, self.inv_freq).to(x.device)
            self.cos_cached = freqs.cos().bfloat16()
            self.sin_cached = freqs.sin().bfloat16()
        return self.cos_cached[None, :, None, :], self.sin_cached[None, :, None, :]
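# The cos/sin tables above are recomputed only when the sequence length changes and are
# cached in bfloat16; the [None, :, None, :] indexing broadcasts them over the batch and
# head dimensions of the (B, T, n_head, head_dim) query/key tensors.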
def apply_rotary_emb(x, cos, sin):
    assert x.ndim == 4  # multihead attention
    d = x.shape[3] // 2
    x1 = x[..., :d]
    x2 = x[..., d:]
    y1 = x1 * cos + x2 * sin
    y2 = x1 * (-sin) + x2 * cos
    return torch.cat([y1, y2], 3).type_as(x)
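# Note on the rotation convention above: the head dimension is split into two halves
# (x1, x2) rather than interleaved even/odd pairs, and each pair (x1[i], x2[i]) is
# rotated by the angle t * inv_freq[i]. This "half-split" RoPE layout is equivalent to
# the interleaved one as long as queries and keys both use it consistently.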
class CausalSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.head_dim = self.n_embd // self.n_head
        assert self.n_embd % self.n_head == 0
        self.c_q = nn.Linear(self.n_embd, self.n_embd, bias=False)
        self.c_k = nn.Linear(self.n_embd, self.n_embd, bias=False)
        self.c_v = nn.Linear(self.n_embd, self.n_embd, bias=False)
        # output projection
        self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
        self.c_proj.weight.data.zero_()  # zero init suggested by @Grad62304977
        self.rotary = Rotary(self.head_dim)

    def forward(self, x):
        B, T, C = (
            x.size()
        )  # batch size, sequence length, embedding dimensionality (n_embd)
        q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
        k = self.c_k(x).view(B, T, self.n_head, self.head_dim)
        v = self.c_v(x).view(B, T, self.n_head, self.head_dim)
        cos, sin = self.rotary(q)
        q, k = (
            F.rms_norm(q, (q.size(-1),)),
            F.rms_norm(k, (k.size(-1),)),
        )  # QK norm suggested by @Grad62304977
        q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin)
        y = F.scaled_dot_product_attention(
            q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True
        )
        y = (
            y.transpose(1, 2).contiguous().view_as(x)
        )  # re-assemble all head outputs side by side
        y = self.c_proj(y)
        return y

class MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False)
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False)
        self.c_proj.weight.data.zero_()  # zero init suggested by @Grad62304977

    def forward(self, x):
        x = self.c_fc(x)
        x = F.relu(
            x
        ).square()  # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977
        x = self.c_proj(x)
        return x

class Block(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.attn = CausalSelfAttention(config)
        self.mlp = MLP(config)

    def forward(self, x):
        x = x + self.attn(F.rms_norm(x, (x.size(-1),)))
        x = x + self.mlp(F.rms_norm(x, (x.size(-1),)))
        return x
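# Block uses the pre-norm residual pattern: RMSNorm is applied to each sublayer's input
# (attention, then MLP) and the sublayer output is added back to the residual stream.
# Combined with the zero-initialized c_proj weights above, every block starts out as an
# identity map at initialization.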
# -----------------------------------------------------------------------------
# The main GPT-2 model

@dataclass
class GPTConfig:
    vocab_size: int = 50304
    n_layer: int = 12
    n_head: int = 6  # head dim 128 suggested by @Grad62304977
    n_embd: int = 768

class GPT(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.transformer = nn.ModuleDict(
            dict(
                wte=nn.Embedding(config.vocab_size, config.n_embd),
                h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            )
        )
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.lm_head.weight.data.zero_()

    def forward(self, idx, targets=None, return_logits=True):
        # forward the GPT model itself
        x = self.transformer.wte(idx)  # token embeddings of shape (b, t, n_embd)
        x = F.rms_norm(x, (x.size(-1),))
        for block in self.transformer.h:
            x = block(x)
        x = F.rms_norm(x, (x.size(-1),))
        if targets is not None:
            # if we are given some desired targets also calculate the loss
            logits = self.lm_head(x)
            logits = logits.float()  # use tf32/fp32 for logits
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1
            )
        else:
            # inference-time mini-optimization: only forward the lm_head on the very last position
            logits = self.lm_head(
                x[:, [-1], :]
            )  # note: using list [-1] to preserve the time dim
            logits = logits.float()  # use tf32/fp32 for logits
            loss = None
        # there are performance reasons why not returning logits is prudent, if not needed
        if not return_logits:
            logits = None
        return logits, loss
# -----------------------------------------------------------------------------
# Our own simple Distributed Data Loader

def _peek_data_shard(filename):
    # only reads the header, returns header data
    with open(filename, "rb") as f:
        # first read the header, which is 256 int32 integers (4 bytes each)
        header = np.frombuffer(f.read(256 * 4), dtype=np.int32)
    if header[0] != 20240520:
        print("ERROR: magic number mismatch in the data .bin file!")
        print("---> HINT: Are you passing in a correct file with --input_bin?")
        print(
            "---> HINT: Dataset encoding changed recently, re-run data prepro or refer again to README"
        )
        print(
            "---> HINT: For example re-run: `python dev/data/tinyshakespeare.py`, then re-try"
        )
        exit(1)
    assert header[1] == 1, "unsupported version"
    ntok = header[2]  # number of tokens (claimed)
    return ntok  # for now just return the number of tokens

def _load_data_shard(filename):
    with open(filename, "rb") as f:
        # first read the header, which is 256 int32 integers (4 bytes each)
        header = np.frombuffer(f.read(256 * 4), dtype=np.int32)
        assert header[0] == 20240520, "magic number mismatch in the data .bin file"
        assert header[1] == 1, "unsupported version"
        ntok = header[2]  # number of tokens (claimed)
        # the rest of it are tokens, stored as uint16
        tokens = np.frombuffer(f.read(), dtype=np.uint16)
    assert len(tokens) == ntok, "number of tokens read does not match header?"
    return tokens
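# Shard file layout (as read above): a 256 * int32 header whose first three entries are
# the magic number 20240520, the format version (1), and the token count, followed by the
# tokens themselves stored as uint16 (GPT-2 vocabulary ids fit in 16 bits).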
class DistributedDataLoader:
    def __init__(self, filename_pattern, B, T, process_rank, num_processes):
        self.process_rank = process_rank
        self.num_processes = num_processes
        self.B = B
        self.T = T
        # glob files that match the pattern
        self.files = sorted(glob.glob(filename_pattern))
        assert (
            len(self.files) > 0
        ), f"did not find any files that match the pattern {filename_pattern}"
        # load and validate all data shards, count number of tokens in total
        ntok_total = 0
        for fname in self.files:
            shard_ntok = _peek_data_shard(fname)
            assert shard_ntok >= num_processes * B * T + 1
            ntok_total += int(shard_ntok)
        self.ntok_total = ntok_total
        # kick things off
        self.reset()

    def reset(self):
        self.current_shard = 0
        self.current_position = self.process_rank * self.B * self.T
        self.tokens = _load_data_shard(self.files[self.current_shard])

    def advance(self):  # advance to next data shard
        self.current_shard = (self.current_shard + 1) % len(self.files)
        self.current_position = self.process_rank * self.B * self.T
        self.tokens = _load_data_shard(self.files[self.current_shard])

    def next_batch(self):
        B = self.B
        T = self.T
        buf = self.tokens[self.current_position : self.current_position + B * T + 1]
        buf = torch.tensor(buf.astype(np.int32), dtype=torch.long)
        x = (buf[:-1]).view(B, T)  # inputs
        y = (buf[1:]).view(B, T)  # targets
        # advance current position and load next shard if necessary
        self.current_position += B * T * self.num_processes
        if self.current_position + (B * T * self.num_processes + 1) > len(self.tokens):
            self.advance()
        return x.cuda(), y.cuda()
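# Sharding scheme of the loader above: rank r starts at offset r * B * T within each shard
# and every rank advances by B * T * num_processes per batch, so the ranks read disjoint,
# interleaved windows of the token stream. The extra "+ 1" token lets inputs and targets be
# produced as a one-token-shifted pair from the same buffer.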
# -----------------------------------------------------------------------------
# int main

@dataclass
class Hyperparameters:
    # data hyperparams
    input_bin: str = "data/fineweb10B/fineweb_train_*.bin"  # input .bin to train on
    input_val_bin: str = (
        "data/fineweb10B/fineweb_val_*.bin"  # input .bin to eval validation loss on
    )
    # optimization hyperparams
    batch_size: int = 8 * 64  # batch size, in sequences, across all devices
    device_batch_size: int = 64  # batch size, in sequences, per device
    sequence_length: int = 1024  # sequence length, in tokens
    num_iterations: int = 4578  # number of iterations to run
    warmup_iters: int = 0
    warmdown_iters: int = 1308  # number of iterations of linear warmup/warmdown for triangular or trapezoidal schedule
    weight_decay: float = 0
    # evaluation and logging hyperparams
    val_loss_every: int = (
        0  # every how many steps to evaluate val loss? 0 for only at the end
    )
    val_tokens: int = 10485760  # how many tokens of validation data? it's important to keep this fixed for consistent comparisons
    save_every: int = (
        0  # every how many steps to save the checkpoint? 0 for only at the end
    )
args = Hyperparameters()

# set up DDP (distributed data parallel). torchrun sets this env variable
assert torch.cuda.is_available()
dist.init_process_group(backend="nccl")
ddp_rank = int(os.environ["RANK"])
ddp_local_rank = int(os.environ["LOCAL_RANK"])
ddp_world_size = int(os.environ["WORLD_SIZE"])
device = f"cuda:{ddp_local_rank}"
torch.cuda.set_device(device)
print(f"using device: {device}")
master_process = ddp_rank == 0  # this process will do logging, checkpointing etc.

# convenience variables
B, T = args.device_batch_size, args.sequence_length
# calculate the number of steps to take in the val loop.
assert args.val_tokens % (B * T * ddp_world_size) == 0
val_steps = args.val_tokens // (B * T * ddp_world_size)
# calculate the steps of gradient accumulation required to attain the desired global batch size.
assert args.batch_size % (B * ddp_world_size) == 0
train_accumulation_steps = args.batch_size // (B * ddp_world_size)
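# Worked example with the defaults above on 8 GPUs: val_steps = 10485760 / (64 * 1024 * 8)
# = 20 validation batches per rank, and train_accumulation_steps = 512 / (64 * 8) = 1.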
# load tokens
train_loader = DistributedDataLoader(args.input_bin, B, T, ddp_rank, ddp_world_size)
val_loader = DistributedDataLoader(args.input_val_bin, B, T, ddp_rank, ddp_world_size)
if master_process:
    print(
        f"Training DataLoader: total number of tokens: {train_loader.ntok_total} across {len(train_loader.files)} files"
    )
    print(
        f"Validation DataLoader: total number of tokens: {val_loader.ntok_total} across {len(val_loader.files)} files"
    )
torch._logging.set_logs(recompiles=True)
x, y = train_loader.next_batch()

# there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency. suggested to me by @Grad62304977.
# this originates from Karpathy's experiments.
num_vocab = 50304
model = GPT(GPTConfig(vocab_size=num_vocab, n_layer=12, n_head=6, n_embd=768))
model = model.cuda()
if hasattr(config, "coordinate_descent_tuning"):
    config.coordinate_descent_tuning = True  # suggested by @Chillee
model = torch.compile(model)
# here we wrap model into DDP container
model = DDP(model, device_ids=[ddp_local_rank])
raw_model = model.module  # always contains the "raw" unwrapped model
ctx = torch.autocast(device_type="cuda", dtype=torch.bfloat16)

# init the optimizer(s)
optimizer1 = torch.optim.Adam(
    [raw_model.transformer.wte.weight], lr=0.3, betas=(0.9, 0.95), fused=True
)
optimizer2 = torch.optim.Adam(
    [raw_model.lm_head.weight], lr=0.003, betas=(0.9, 0.95), fused=True
)
optimizer3 = Muon(raw_model.transformer.h.parameters(), lr=0.02, momentum=0.95)
optimizers = [optimizer1, optimizer2, optimizer3]

# learning rate decay scheduler (linear warmup and warmdown)
def get_lr(it):
    assert it <= args.num_iterations
    # 1) linear warmup for warmup_iters steps
    if it < args.warmup_iters:
        return (it + 1) / args.warmup_iters
    # 2) constant lr for a while
    elif it < args.num_iterations - args.warmdown_iters:
        return 1.0
    # 3) linear warmdown
    else:
        decay_ratio = (args.num_iterations - it) / args.warmdown_iters
        return decay_ratio
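# With the defaults (warmup_iters=0, num_iterations=4578, warmdown_iters=1308) this is a
# trapezoid with no ramp-up: the multiplier is 1.0 up to step 3270 and then decays linearly,
# reaching 0 at step 4578; it scales the base LR of all three optimizers below.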
schedulers = [torch.optim.lr_scheduler.LambdaLR(opt, get_lr) for opt in optimizers]

# begin logging
if master_process:
    run_id = str(uuid.uuid4())
    logdir = "logs/%s/" % run_id
    os.makedirs(logdir, exist_ok=True)
    logfile = "logs/%s.txt" % run_id
    # create the log file
    with open(logfile, "w") as f:
        # begin the log by printing this file (the Python code)
        f.write("=" * 100 + "\n")
        f.write(code)
        f.write("=" * 100 + "\n")
        # log information about the hardware/software environment this is running on
        # and print the full `nvidia-smi` to file
        f.write(
            f"Running pytorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}\nnvidia-smi:\n"
        )
        import subprocess

        result = subprocess.run(
            ["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
        )
        f.write(f"{result.stdout}\n")
        f.write("=" * 100 + "\n")

training_time_ms = 0
# start the clock
torch.cuda.synchronize()
t0 = time.time()
# begin training
train_loader.reset()
step_time_ms_window = deque(maxlen=100)
end_of_last_step = t0
torch._dynamo.config.compiled_autograd = True
# torch._logging.set_logs(recompiles=True)

@torch.compile
def fwd_bwd(model, x, y):
    with ctx:
        _, loss = model(x, y, return_logits=False)
        train_loss = loss.detach()
    loss.backward()
    return train_loss
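# fwd_bwd compiles the forward and backward pass together: with
# torch._dynamo.config.compiled_autograd enabled above, loss.backward() inside the compiled
# function is captured by compiled autograd rather than running eagerly, so the whole step
# (autocast forward + backward) executes as compiled code.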
for step in range(args.num_iterations + 1):
    last_step = step == args.num_iterations
    # This effectively ignores timing first 10 steps, which are slower for weird reasons.
    # Alternately, and slightly more correctly in terms of benchmarking, we could do 10
    # steps with dummy data first, and then re-initialize the model and reset the loader.
    if step == 10:
        training_time_ms = 0
        t0 = time.time()
    timed_steps = (
        float("nan") if step <= 11 else (step - 10) + 1
    )  # <= 11 to avoid bug in val

    # once in a while evaluate the validation dataset
    if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0):
        # stop the clock
        torch.cuda.synchronize()
        training_time_ms += 1000 * (time.time() - t0)
        # run validation batches
        model.eval()
        val_loader.reset()
        val_loss = 0.0
        for _ in range(val_steps):
            x_val, y_val = val_loader.next_batch()
            with ctx:  # of course, we'd like to use no_grad() here too, but that creates a torch.compile error for some reason
                _, loss = model(x_val, y_val, return_logits=False)
                val_loss += loss.detach()
                del loss
        dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
        val_loss /= val_steps
        # log val loss to console and to logfile
        if master_process:
            print(
                f"step:{step}/{args.num_iterations} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/(timed_steps-1):.2f}ms"
            )
            with open(logfile, "a") as f:
                f.write(
                    f"step:{step}/{args.num_iterations} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/(timed_steps-1):.2f}ms\n"
                )
        # start the clock again
        torch.cuda.synchronize()
        t0 = time.time()

    if master_process and (
        last_step or (args.save_every > 0 and step % args.save_every == 0)
    ):
        # stop the clock
        torch.cuda.synchronize()
        training_time_ms += 1000 * (time.time() - t0)
        # save the state of the training process
        log = dict(
            step=step,
            code=code,
            model=raw_model.state_dict(),
            optimizers=[opt.state_dict() for opt in optimizers],
        )
        torch.save(log, "logs/%s/state_step%06d.pt" % (run_id, step))
        # start the clock again
        torch.cuda.synchronize()
        t0 = time.time()

    # bit confusing: we want to make sure to eval on 0th iteration
    # but also after the very last iteration. so we loop for step <= num_iterations
    # instead of just < num_iterations (one extra due to <=), only to do
    # the validation/sampling one last time, and then we break right here as we're done.
    if last_step:
        break

    # --------------- TRAINING SECTION BEGIN -----------------
    model.train()
    assert train_accumulation_steps == 1
    # forward & backward pass
    train_loss = fwd_bwd(model, x, y)
    # advance the dataset for the next batch
    x, y = train_loader.next_batch()
    # step the optimizers and schedulers
    for opt, sched in zip(optimizers, schedulers):
        opt.step()
        sched.step()
    # null the gradients
    model.zero_grad(set_to_none=True)
    # --------------- TRAINING SECTION END -------------------

    # everything that follows now is just diagnostics, prints, logging, etc.
    # dist.all_reduce(train_loss, op=dist.ReduceOp.AVG) # all-reducing the training loss would be more correct in terms of logging, but slower
    if master_process:
        end_of_step = time.time()
        step_time_ms_window.append(1000 * (end_of_step - end_of_last_step))
        end_of_last_step = end_of_step
        window_mid = sorted(step_time_ms_window)[20:-20]
        if window_mid:
            window_average = sum(window_mid) / len(window_mid)
        else:
            window_average = float("nan")
        approx_time = training_time_ms + 1000 * (end_of_step - t0)
        print(
            f"step:{step+1}/{args.num_iterations} train_loss:{train_loss.item():.4f} train_time:{approx_time:.0f}ms step_avg:{approx_time/timed_steps:.2f}ms"
            f" window_avg:{window_average:.2f}ms"
        )
        with open(logfile, "a") as f:
            f.write(
                f"step:{step+1}/{args.num_iterations} train_loss:{train_loss.item():.4f} train_time:{approx_time:.0f}ms step_avg:{approx_time/timed_steps:.2f}ms"
                f" window_avg:{window_average:.2f}ms"
                "\n"
            )

if master_process:
    print(
        f"peak memory consumption: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB"
    )

# -------------------------------------------------------------------------
# clean up nice
dist.destroy_process_group()