Configure the lit-gpt train() function to start logging metrics to Weights & Biases (W&B)

# Demonstrative changes to this file: https://github.com/Lightning-AI/lit-gpt/blob/main/finetune/lora.py
# Only setup() and train() are shown; the imports and module-level hyperparameters of the
# original script (devices, micro_batch_size, log_interval, warmup_steps, ...) are unchanged.
from lightning.pytorch.loggers import WandbLogger


def setup(
    data_dir: Path = Path("data/alpaca"),
    checkpoint_dir: Path = Path("checkpoints/stabilityai/stablelm-base-alpha-3b"),
    out_dir: Path = Path("out/lora/alpaca"),
    precision: Optional[str] = None,
    tpu: bool = False,
):
    precision = precision or get_default_supported_precision(training=True, tpu=tpu)

    fabric_devices = devices
    if fabric_devices > 1:
        if tpu:
            # For multi-host TPU training, the device count for Fabric is limited to the count on a single host.
            fabric_devices = "auto"
            strategy = XLAStrategy(sync_module_states=False)
        else:
            strategy = FSDPStrategy(
                auto_wrap_policy={Block},
                activation_checkpointing_policy={Block},
                state_dict_type="full",
                limit_all_gathers=True,
            )
    else:
        strategy = "auto"

    logger = step_csv_logger(out_dir.parent, out_dir.name, flush_logs_every_n_steps=log_interval)
    # Initialize W&B
    wandb_logger = WandbLogger(project="lit-gpt-neurips")
    # This is where Fabric is initialized
    fabric = L.Fabric(devices=fabric_devices, strategy=strategy, precision=precision, loggers=[logger, wandb_logger])
    fabric.print(hparams)
    fabric.launch(main, data_dir, checkpoint_dir, out_dir)
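
# Note: Fabric forwards every fabric.log / fabric.log_dict call to all loggers passed via
# loggers=[...] above, so the step CSV logger and the W&B run receive the same metrics.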


def train(
    fabric: L.Fabric,
    model: GPT,
    optimizer: torch.optim.Optimizer,
    train_data: List[Dict],
    val_data: List[Dict],
    checkpoint_dir: Path,
    out_dir: Path,
    speed_monitor: SpeedMonitor,
) -> None:
    tokenizer = Tokenizer(checkpoint_dir)
    max_seq_length, longest_seq_length, longest_seq_ix = get_max_seq_length(train_data)

    validate(fabric, model, val_data, tokenizer, longest_seq_length)  # sanity check

    with torch.device("meta"):
        meta_model = GPT(model.config)
        mark_only_lora_as_trainable(meta_model)
        # "estimated" is not as precise as "measured". Estimated is optimistic but widely used in the wild.
        # When comparing MFU or FLOP numbers with other projects that use estimated FLOPs,
        # consider passing `SpeedMonitor(flops_per_batch=estimated_flops)` instead
        estimated_flops = estimate_flops(meta_model) * micro_batch_size
        fabric.print(f"Estimated TFLOPs: {estimated_flops * fabric.world_size / 1e12:.2f}")
        # this assumes that all samples have a fixed length equal to the longest sequence length
        # which is most likely false during finetuning
        x = torch.randint(0, 1, (micro_batch_size, longest_seq_length))
        measured_flops = measure_flops(meta_model, x)
        fabric.print(f"Measured TFLOPs: {measured_flops * fabric.world_size / 1e12:.2f}")
        del meta_model, x

    print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB", file=sys.stderr)

    step_count = 0
    total_lengths = 0
    total_t0 = time.perf_counter()

    if fabric.device.type == "xla":
        import torch_xla.core.xla_model as xm

        xm.mark_step()
    for iter_num in range(max_iters):
        if step_count <= warmup_steps:
            # linear warmup
            lr = learning_rate * step_count / warmup_steps
            for param_group in optimizer.param_groups:
                param_group["lr"] = lr

        iter_t0 = time.perf_counter()

        input_ids, targets = get_batch(
            fabric, train_data, longest_seq_length, longest_seq_ix if iter_num == 0 else None
        )

        is_accumulating = (iter_num + 1) % gradient_accumulation_iters != 0
        with fabric.no_backward_sync(model, enabled=is_accumulating):
            logits = model(input_ids, max_seq_length=max_seq_length, lm_head_chunk_size=128)
            # shift the targets such that output n predicts token n+1
            logits[-1] = logits[-1][..., :-1, :]
            loss = chunked_cross_entropy(logits, targets[..., 1:])
            fabric.backward(loss / gradient_accumulation_iters)

        if not is_accumulating:
            optimizer.step()
            optimizer.zero_grad()
            step_count += 1
        elif fabric.device.type == "xla":
            xm.mark_step()

        t1 = time.perf_counter()
        total_lengths += input_ids.size(1)
        speed_monitor.on_train_batch_end(
            (iter_num + 1) * micro_batch_size,
            t1 - total_t0,
            # this assumes that device FLOPs are the same and that all devices have the same batch size
            fabric.world_size,
            flops_per_batch=measured_flops,
            lengths=total_lengths,
        )
        if iter_num % log_interval == 0:
            fabric.print(
                f"iter {iter_num} step {step_count}: loss {loss.item():.4f}, iter time:"
                f" {(t1 - iter_t0) * 1000:.2f}ms{' (optimizer.step)' if not is_accumulating else ''}"
            )
            # -> W&B Logging
            fabric.log_dict({
                "iter": iter_num,
                "step": step_count,
                "loss": loss.item(),
                "iter time (ms)": (t1 - iter_t0) * 1000,
            })

        if not is_accumulating and step_count % eval_interval == 0:
            t0 = time.perf_counter()
            val_loss = validate(fabric, model, val_data, tokenizer, longest_seq_length)
            t1 = time.perf_counter() - t0
            speed_monitor.eval_end(t1)
            fabric.print(f"step {iter_num}: val loss {val_loss:.4f}, val time: {t1 * 1000:.2f}ms")
            # -> W&B Logging
            fabric.log_dict({
                "val_iter": iter_num,
                "val_loss": val_loss,
                "val_time (ms)": t1 * 1000,
            })
            fabric.barrier()

        if not is_accumulating and step_count % save_interval == 0:
            checkpoint_path = out_dir / f"iter-{iter_num:06d}-ckpt.pth"
            save_lora_checkpoint(fabric, model, checkpoint_path)
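
For reference, the W&B-specific changes above boil down to three steps: import WandbLogger, pass it to L.Fabric(loggers=[...]), and call fabric.log_dict(...) inside the training loop. Below is a minimal, self-contained sketch of that pattern; the toy model, data, and the "my-project" name are illustrative and not part of lit-gpt, and it assumes wandb is installed and you are logged in (or WANDB_MODE=offline is set).

import lightning as L
import torch
from lightning.pytorch.loggers import WandbLogger

# Illustrative toy setup; only the logging calls mirror the lit-gpt changes above.
wandb_logger = WandbLogger(project="my-project")  # hypothetical project name
fabric = L.Fabric(accelerator="cpu", devices=1, loggers=[wandb_logger])
fabric.launch()

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
model, optimizer = fabric.setup(model, optimizer)

for step in range(10):
    x, y = torch.randn(8, 4), torch.randn(8, 1)
    loss = torch.nn.functional.mse_loss(model(x), y)
    fabric.backward(loss)
    optimizer.step()
    optimizer.zero_grad()
    # Same pattern as in train() above: every logged dict is forwarded to the W&B run.
    fabric.log_dict({"step": step, "loss": loss.item()})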