@ayulockin
Last active August 18, 2023
Configure lit-gpt train function to start logging metrics to W&B
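The diff below boils down to one pattern: create a WandbLogger, pass it to Fabric through the loggers argument, and call fabric.log_dict(...) wherever a metric should be recorded. Here is a minimal, self-contained sketch of just that pattern (the project name and the toy loop are illustrative placeholders, not part of lit-gpt; wandb must be installed and authenticated, e.g. via `wandb login`):

import lightning as L
from lightning.pytorch.loggers import WandbLogger

# stand-alone sketch, independent of the lit-gpt script below
wandb_logger = WandbLogger(project="lit-gpt-neurips")  # placeholder project name
fabric = L.Fabric(devices=1, loggers=[wandb_logger])
fabric.launch()

for step in range(3):
    # every dict passed to `log_dict` is forwarded to all loggers attached to Fabric
    fabric.log_dict({"loss": 1.0 / (step + 1)}, step=step)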
# Demonstrative changes to this file: https://github.com/Lightning-AI/lit-gpt/blob/main/finetune/lora.py
from lightning.pytorch.loggers import WandbLogger


def setup(
    data_dir: Path = Path("data/alpaca"),
    checkpoint_dir: Path = Path("checkpoints/stabilityai/stablelm-base-alpha-3b"),
    out_dir: Path = Path("out/lora/alpaca"),
    precision: Optional[str] = None,
    tpu: bool = False,
):
    precision = precision or get_default_supported_precision(training=True, tpu=tpu)

    fabric_devices = devices
    if fabric_devices > 1:
        if tpu:
            # For multi-host TPU training, the device count for Fabric is limited to the count on a single host.
            fabric_devices = "auto"
            strategy = XLAStrategy(sync_module_states=False)
        else:
            strategy = FSDPStrategy(
                auto_wrap_policy={Block},
                activation_checkpointing_policy={Block},
                state_dict_type="full",
                limit_all_gathers=True,
            )
    else:
        strategy = "auto"

    logger = step_csv_logger(out_dir.parent, out_dir.name, flush_logs_every_n_steps=log_interval)
    # Initialize W&B
    wandb_logger = WandbLogger(project="lit-gpt-neurips")
    # This is where Fabric is initialized
    fabric = L.Fabric(devices=fabric_devices, strategy=strategy, precision=precision, loggers=[logger, wandb_logger])
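    # Both loggers attached above receive every metric sent through `fabric.log_dict`,
    # so the step CSV files and the W&B run record the same values.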
    fabric.print(hparams)
    fabric.launch(main, data_dir, checkpoint_dir, out_dir)

def train(
    fabric: L.Fabric,
    model: GPT,
    optimizer: torch.optim.Optimizer,
    train_data: List[Dict],
    val_data: List[Dict],
    checkpoint_dir: Path,
    out_dir: Path,
    speed_monitor: SpeedMonitor,
) -> None:
    tokenizer = Tokenizer(checkpoint_dir)
    max_seq_length, longest_seq_length, longest_seq_ix = get_max_seq_length(train_data)

    validate(fabric, model, val_data, tokenizer, longest_seq_length)  # sanity check

    with torch.device("meta"):
        meta_model = GPT(model.config)
        mark_only_lora_as_trainable(meta_model)
        # "estimated" is not as precise as "measured". Estimated is optimistic but widely used in the wild.
        # When comparing MFU or FLOP numbers with other projects that use estimated FLOPs,
        # consider passing `SpeedMonitor(flops_per_batch=estimated_flops)` instead
        estimated_flops = estimate_flops(meta_model) * micro_batch_size
        fabric.print(f"Estimated TFLOPs: {estimated_flops * fabric.world_size / 1e12:.2f}")
        # this assumes that all samples have a fixed length equal to the longest sequence length
        # which is most likely false during finetuning
        x = torch.randint(0, 1, (micro_batch_size, longest_seq_length))
        measured_flops = measure_flops(meta_model, x)
        fabric.print(f"Measured TFLOPs: {measured_flops * fabric.world_size / 1e12:.2f}")
        del meta_model, x

    print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB", file=sys.stderr)

    step_count = 0
    total_lengths = 0
    total_t0 = time.perf_counter()

    if fabric.device.type == "xla":
        import torch_xla.core.xla_model as xm

        xm.mark_step()
    for iter_num in range(max_iters):
        if step_count <= warmup_steps:
            # linear warmup
            lr = learning_rate * step_count / warmup_steps
            for param_group in optimizer.param_groups:
                param_group["lr"] = lr

        iter_t0 = time.perf_counter()

        input_ids, targets = get_batch(
            fabric, train_data, longest_seq_length, longest_seq_ix if iter_num == 0 else None
        )

        is_accumulating = (iter_num + 1) % gradient_accumulation_iters != 0
        with fabric.no_backward_sync(model, enabled=is_accumulating):
            logits = model(input_ids, max_seq_length=max_seq_length, lm_head_chunk_size=128)
            # shift the targets such that output n predicts token n+1
            logits[-1] = logits[-1][..., :-1, :]
            loss = chunked_cross_entropy(logits, targets[..., 1:])
            fabric.backward(loss / gradient_accumulation_iters)

        if not is_accumulating:
            optimizer.step()
            optimizer.zero_grad()
            step_count += 1
        elif fabric.device.type == "xla":
            xm.mark_step()
        t1 = time.perf_counter()
        total_lengths += input_ids.size(1)
        speed_monitor.on_train_batch_end(
            (iter_num + 1) * micro_batch_size,
            t1 - total_t0,
            # this assumes that device FLOPs are the same and that all devices have the same batch size
            fabric.world_size,
            flops_per_batch=measured_flops,
            lengths=total_lengths,
        )
        if iter_num % log_interval == 0:
            fabric.print(
                f"iter {iter_num} step {step_count}: loss {loss.item():.4f}, iter time:"
                f" {(t1 - iter_t0) * 1000:.2f}ms{' (optimizer.step)' if not is_accumulating else ''}"
            )
            # -> W&B Logging
            fabric.log_dict({
                "iter": iter_num,
                "step": step_count,
                "loss": loss.item(),
                "iter time (ms)": (t1 - iter_t0) * 1000,
            })
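            # Each key logged above ("iter", "step", "loss", "iter time (ms)") shows up
            # as its own chart in the W&B run, updating live as finetuning progresses.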

        if not is_accumulating and step_count % eval_interval == 0:
            t0 = time.perf_counter()
            val_loss = validate(fabric, model, val_data, tokenizer, longest_seq_length)
            t1 = time.perf_counter() - t0
            speed_monitor.eval_end(t1)
            fabric.print(f"step {iter_num}: val loss {val_loss:.4f}, val time: {t1 * 1000:.2f}ms")
            # -> W&B Logging
            fabric.log_dict({
                "val_iter": iter_num,
                "val_loss": val_loss,
                "val_time (ms)": t1 * 1000,
            })
            fabric.barrier()

        if not is_accumulating and step_count % save_interval == 0:
            checkpoint_path = out_dir / f"iter-{iter_num:06d}-ckpt.pth"
            save_lora_checkpoint(fabric, model, checkpoint_path)
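With these changes in place, running finetune/lora.py as usual should create a run in the W&B project named above ("lit-gpt-neurips") and stream the training loss, iteration time, and validation metrics to it alongside the existing CSV logs. Authenticate beforehand with `wandb login` or the WANDB_API_KEY environment variable.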