Ross Wightman (rwightman)

Hparams for the tiny test timm models. All are pretty similar; vit3 extended the schedule from the default 1600 epochs to 1800. All are based on the MobileNetV4 template with AdamW and a reduced beta1, grinding out the training. Some of the smallest models used a bit less AA magnitude (3), while the slightly higher-capacity ones increased it to m5 or m6.
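As a rough sketch of the "AdamW with reduced beta1" part, using timm's optimizer factory; the beta, LR, and weight-decay values below are placeholders for illustration, not the values actually used:

import torch.nn as nn
from timm.optim import create_optimizer_v2

model = nn.Linear(192, 1000)  # stand-in module; any timm model works the same way
# reduced beta1 relative to the AdamW default of (0.9, 0.999); the values here are hypothetical
optimizer = create_optimizer_v2(model, opt='adamw', lr=1e-3, weight_decay=0.02, betas=(0.6, 0.95))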

@rwightman
rwightman / _README_MobileNetV4.md
Last active November 1, 2024 02:57
MobileNetV4 hparams

MobileNetV4 Hparams

The included yaml files are timm train script configs for training MobileNetV4 models in timm (see the weights on the HF Hub: https://huggingface.co/collections/timm/mobilenetv4-pretrained-weights-6669c22cda4db4244def9637).

Note the number of GPUs; it needs to be taken into consideration for global batch size equivalence and LR scaling.

Also note that some models have lr set to a non-null value; if set, that LR is used directly. Otherwise the script falls back to lr_base, and the rate actually used is calculated from lr_base_size with sqrt scaling according to the global batch size.
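A minimal sketch of that sqrt-scaling fallback (the numbers in the example comment are made up, not taken from any of the included configs):

import math

def effective_lr(lr_base: float, lr_base_size: int, batch_size: int, num_gpus: int) -> float:
    # global batch size is the per-GPU batch size times the number of GPUs
    global_batch_size = batch_size * num_gpus
    # sqrt scaling of lr_base relative to lr_base_size
    return lr_base * math.sqrt(global_batch_size / lr_base_size)

# e.g. lr_base=0.008 at lr_base_size=256 with 8 GPUs and a per-GPU batch of 256:
# effective_lr(0.008, 256, 256, 8) == 0.008 * sqrt(8) ~= 0.0226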

Models with ix in the tag use an alternative init for the MQA attention projections: Xavier (Glorot) uniform instead of the EfficientNet/MobileNet defaults. This seemed to improve the stability of the hybrid models and allowed a larger (closer to 1) beta2 for Adam; otherwise the Adam beta2, or the LR, needed to be reduced to avoid instability with the hybrids.
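For illustration, a sketch of applying Xavier (Glorot) uniform init to attention projection layers; the module-name filter below is hypothetical and not the exact rule timm uses for the ix variants:

import torch.nn as nn

def init_attn_projections_xavier(model: nn.Module) -> None:
    for name, module in model.named_modules():
        # hypothetical filter for the query/key/value/output projections of the MQA blocks
        if isinstance(module, (nn.Linear, nn.Conv2d)) and any(k in name for k in ('query', 'key', 'value', 'proj')):
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)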

@rwightman
rwightman / train-vit_base_rope_reg1-4gpu-in1k.yaml
Last active May 14, 2024 18:09
Searching for Better Vit Baselines Hparams
aa: rand-m8-inc1-mstd1.0-n4
amp: true
amp_dtype: float16
amp_impl: native
aug_repeats: 3.0
aug_splits: 0
batch_size: 256
bce_loss: false
bce_target_thresh: null
bn_eps: null
import math
import os
from collections import defaultdict
from pathlib import Path

# fast transfers using a Rust library, `pip install hf-transfer`; set the env var
# before importing huggingface_hub so the flag is picked up at import time
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

from huggingface_hub import CommitOperationAdd, preupload_lfs_files, create_commit
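The preview above cuts off here; a minimal sketch of how those imports are typically combined (the repo id, file pattern, and commit message are placeholders, not taken from the original gist):

repo_id = "your-username/your-repo"  # placeholder
operations = []
for path in Path("./checkpoints").glob("*.safetensors"):  # placeholder file pattern
    op = CommitOperationAdd(path_in_repo=path.name, path_or_fileobj=str(path))
    preupload_lfs_files(repo_id, additions=[op])  # upload the LFS payload right away, file by file
    operations.append(op)
create_commit(repo_id, operations=operations, commit_message="Add checkpoints")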
@rwightman
rwightman / timm_vit_attention_map.py
Created November 21, 2023 18:39
Extract attention maps from timm ViTs with Torch FX
import torch
import timm
from torchvision.models.feature_extraction import get_graph_node_names

timm.layers.set_fused_attn(False)  # disable F.sdpa so the softmax node is exposed in the traced graph
mm = timm.create_model('vit_medium_patch16_gap_256.sw_in12k_ft_in1k', pretrained=True).eval()
softmax_nodes = [n for n in get_graph_node_names(mm)[0] if 'softmax' in n]
ff = timm.models.create_feature_extractor(mm, softmax_nodes)
with torch.no_grad():
    # illustrative completion of the truncated preview: run a dummy 256x256 input and
    # collect a dict of {softmax node name: attention map of shape (B, num_heads, N, N)}
    attn_maps = ff(torch.randn(1, 3, 256, 256))
model image_size embed_dim gmacs macts mparams image_gmacs image_macts image_mparams text_gmacs text_macts text_mparams
ViT-S-32-alt 224 256 1.78 4.71 43.22 1.15 2.5 22.59 0.64 2.21 20.63
ViT-S-32 224 384 2.84 6.48 63.09 1.15 2.5 22.64 1.69 3.98 40.44
ViT-M-32-alt 224 384 3.69 7.31 80.07 2.0 3.34 39.63 1.69 3.98 40.44
ViT-M-32 224 512 4.98 8.64 103.12 2.0 3.34 39.69 2.98 5.3 63.43
ViT-S-16-alt 224 256 5.25 14.16 42.4 4.61 11.95 21.76 0.64 2.21 20.63
ViT-S-16 224 384 6.3 15.92 62.26 4.61 11.95 21.81 1.69 3.98 40.44
ViT-B-32-quickgelu 224 512 7.4 10.31 151.28 4.41 5.01 87.85 2.98 5.3 63.43
ViT-B-32 224 512 7.4 10.31 151.28 4.41 5.01 87.85 2.98 5.3 63.43
convnext_tiny 224 1024 7.46 18.74 92.3 4.47 13.44 28.61 2.98 5.3 63.69

Some hparams related to RegNets (and other nets) in the TPU training series: https://github.com/rwightman/pytorch-image-models/releases/tag/v0.1-tpu-weights

All models were trained on 8x TPUs, so the global batch size == batch_size * 8.

If the weight name contains ra3, it means rmsproptf + mixup + cutmix + random erasing + (usually) LR noise + rand-aug + head dropout + drop path (stochastic depth). The older ra2 scheme was very similar but had no cutmix, and rand-aug always used normal sampling (mstd0.5 or mstd1.0) for the rand-aug magnitude, whereas ra3 often (not always) uses uniform sampling (mstd101), as in the sketch below.
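A small sketch contrasting the two magnitude-sampling modes via timm's rand-aug config strings (m9 is just an example magnitude, not the value used for any particular weight):

from timm.data import create_transform

# mstd0.5 draws the magnitude from a normal distribution around m (ra2 style);
# mstd values over 100, e.g. mstd101, switch to uniform sampling in [0, m] (ra3 style)
ra2_style = create_transform(input_size=224, is_training=True, auto_augment='rand-m9-mstd0.5')
ra3_style = create_transform(input_size=224, is_training=True, auto_augment='rand-m9-mstd101')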

Some weights were trained with SGD + gradient clipping (cx in the name, where x is one of h, 1, 2, 3); h = amped-up augreg.

I believe the regnety_064 was very close with both the ra3 and SGD approaches; the hparams I've kept were the SGD ones, but I believe the published weights were the rmsproptf ones, which edged it out by a hair.

@rwightman
rwightman / _timm_hparams.md
Last active May 30, 2023 05:18
Recent timm hparams...

A variety of hparams used recently in timm to train vit, convnext, and vit-hybrids (maxvit, coatnet).

All variations on the same theme (DeiT / Swin pretraining) but with different tweaks here and there.

These were all run on 4-8 GPU or TPU devices. They use --lr-base, which rescales the LR automatically based on the global batch size (relative to --lr-base-size), so adapting to different GPU counts will work well within a range; running at significantly lower or higher global batch sizes will require re-running an LR search.

More recently, DeiT-III has been shown to be a very compelling set of hparams for vit-like models. I've yet to do full runs myself, but their recipe can be adapted to the timm train scripts (3-Augment was added recently): https://github.com/facebookresearch/deit/blob/main/README_revenge.md

The yaml files can be used directly with the timm train script (passed via its -c/--config argument).

Hparams were run on 8x A100 for the in12k or 12k fine-tune runs and on 4x V100 for the rest, so global batch size = 320 * 4, etc.; the LR should be rescaled using a sqrt rule if changing the global batch size.

@rwightman
rwightman / vit-aot.csv
Created July 13, 2022 05:22
timm vit models, eager vs aot vs torchscript, AMP, PyTorch 1.12
model infer_samples_per_sec infer_step_time infer_batch_size infer_img_size train_samples_per_sec train_step_time train_batch_size train_img_size param_count
vit_small_patch16_224 2444.7 104.691 256 224 955.88 267.078 256 224 22.05
vit_relpos_medium_patch16_224 1107.38 231.158 256 224 502.75 253.69 128 224 38.75
vit_base_patch16_224 1013.88 252.477 256 224 358.36 356.433 128 224 86.57
vit_base_patch16_384 288.27 888.045 256 384 102.82 300.795 31 384 86.86