Hparams for tiny test timm models
All pretty similar; vit3 extended the schedule from the default 1600 epochs to 1800. All are based on the MobileNetV4 template w/ adamw and a reduced beta1, grinding out the training. Some of the smallest models used a bit lower AA magnitude (m3), while the slightly higher capacity ones increased it to m5 or m6.
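A minimal sketch of the adamw-with-reduced-beta1 part of that recipe; the lr, betas, and weight_decay below are placeholders, not the per-model values (those live in the yaml files):

```python
import torch
import torch.nn as nn

# Illustrative only: adamw with beta1 reduced below the 0.9 default, as described above.
model = nn.Conv2d(3, 8, 3)
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=1e-3,
    betas=(0.6, 0.99),    # torch default is (0.9, 0.999); beta1 reduced here
    weight_decay=0.05,
)
```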
The included yaml files are timm train script configs for training MobileNetV4 models in timm (weights on the HF Hub: https://huggingface.co/collections/timm/mobilenetv4-pretrained-weights-6669c22cda4db4244def9637).
Note the # of GPUs; this needs to be taken into account for global batch size equivalence and LR scaling.
Also note that some models have `lr` set to a non-null value; this LR is used directly if set. Otherwise, the script falls back to `lr_base`, and the effective rate is calculated from `lr_base_size` with sqrt scaling according to the global batch size.
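A minimal sketch of that fallback logic, assuming the sqrt rule described above (this mirrors the behaviour, not the train script's exact code):

```python
import math

def resolve_lr(lr, lr_base, lr_base_size, batch_size, num_gpus, grad_accum=1):
    """Return the LR to use: an explicit `lr` wins, otherwise sqrt-scale `lr_base`."""
    if lr is not None:
        return lr
    global_batch_size = batch_size * num_gpus * grad_accum
    # sqrt scaling of lr_base relative to the reference lr_base_size
    return lr_base * math.sqrt(global_batch_size / lr_base_size)

# e.g. batch_size=256 per GPU on 2 GPUs vs a lr_base_size=512 reference -> lr == lr_base
print(resolve_lr(lr=None, lr_base=8e-3, lr_base_size=512, batch_size=256, num_gpus=2))
```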
Models with `ix` in the tag use an alternative init for the MQA attention projections: xavier (glorot) uniform instead of the efficientnet/mobilenet defaults. This seemed to improve stability of the hybrid models and allowed a larger (closer to 1) beta2 for adam; otherwise the adam beta2 or the LR needed to be reduced to avoid instability with the hybrids.
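A hypothetical sketch of what that init swap looks like; the module name filter below is made up for illustration and is not the actual timm code path:

```python
import torch.nn as nn

def init_attn_projections(model: nn.Module, xavier: bool = True):
    """Re-init attention projection layers; illustrative only."""
    for name, m in model.named_modules():
        if not isinstance(m, (nn.Linear, nn.Conv2d)):
            continue
        # hypothetical name filter for the MQA q/k/v/output projections
        if not any(k in name for k in ('query', 'key', 'value', 'output')):
            continue
        if xavier:
            nn.init.xavier_uniform_(m.weight)  # xavier (glorot) uniform
        else:
            nn.init.kaiming_normal_(m.weight, mode='fan_out')  # efficientnet/mobilenet-style fan-out init
        if m.bias is not None:
            nn.init.zeros_(m.bias)
```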
```yaml
aa: rand-m8-inc1-mstd1.0-n4
amp: true
amp_dtype: float16
amp_impl: native
aug_repeats: 3.0
aug_splits: 0
batch_size: 256
bce_loss: false
bce_target_thresh: null
bn_eps: null
```
```python
import os
# fast transfers using a Rust library, `pip install hf-transfer`
# (set before huggingface_hub is imported; the flag is read at import time)
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

import math
from collections import defaultdict
from pathlib import Path
from huggingface_hub import CommitOperationAdd, preupload_lfs_files, create_commit
import torch
import timm
from torchvision.models.feature_extraction import get_graph_node_names

timm.layers.set_fused_attn(False)  # disable F.sdpa so the softmax nodes are exposed to tracing
mm = timm.create_model('vit_medium_patch16_gap_256.sw_in12k_ft_in1k', pretrained=True).eval()
# one softmax node per attention block in the traced graph
softmax_nodes = [n for n in get_graph_node_names(mm)[0] if 'softmax' in n]
ff = timm.models.create_feature_extractor(mm, softmax_nodes)
with torch.no_grad():
    # e.g. run a dummy image through the extractor; returns {node name: attention map}
    out = ff(torch.randn(1, 3, 256, 256))
    for name, attn in out.items():
        print(name, tuple(attn.shape))
```
model | image_size | embed_dim | gmacs | macts | mparams | image_gmacs | image_macts | image_mparams | text_gmacs | text_macts | text_mparams
---|---|---|---|---|---|---|---|---|---|---|---
ViT-S-32-alt | 224 | 256 | 1.78 | 4.71 | 43.22 | 1.15 | 2.5 | 22.59 | 0.64 | 2.21 | 20.63
ViT-S-32 | 224 | 384 | 2.84 | 6.48 | 63.09 | 1.15 | 2.5 | 22.64 | 1.69 | 3.98 | 40.44
ViT-M-32-alt | 224 | 384 | 3.69 | 7.31 | 80.07 | 2.0 | 3.34 | 39.63 | 1.69 | 3.98 | 40.44
ViT-M-32 | 224 | 512 | 4.98 | 8.64 | 103.12 | 2.0 | 3.34 | 39.69 | 2.98 | 5.3 | 63.43
ViT-S-16-alt | 224 | 256 | 5.25 | 14.16 | 42.4 | 4.61 | 11.95 | 21.76 | 0.64 | 2.21 | 20.63
ViT-S-16 | 224 | 384 | 6.3 | 15.92 | 62.26 | 4.61 | 11.95 | 21.81 | 1.69 | 3.98 | 40.44
ViT-B-32-quickgelu | 224 | 512 | 7.4 | 10.31 | 151.28 | 4.41 | 5.01 | 87.85 | 2.98 | 5.3 | 63.43
ViT-B-32 | 224 | 512 | 7.4 | 10.31 | 151.28 | 4.41 | 5.01 | 87.85 | 2.98 | 5.3 | 63.43
convnext_tiny | 224 | 1024 | 7.46 | 18.74 | 92.3 | 4.47 | 13.44 | 28.61 | 2.98 | 5.3 | 63.69
Some hparams related to RegNets (and other nets) in the TPU training series: https://github.com/rwightman/pytorch-image-models/releases/tag/v0.1-tpu-weights
All models were trained on 8x TPUs, so global batch size == batch_size * 8.
If the weight name says `ra3`, it means rmsproptf + mixup + cutmix + rand erasing + (usually) lr noise + rand-aug + head dropout + drop path (stochastic depth). The older `ra2` scheme was very similar but had no cutmix, and rand-aug always used normal sampling (`mstd0.5` or `mstd1.0`) for the rand-aug magnitude, whereas `ra3` often (not always) uses uniform sampling (`mstd101`).
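A small sketch of how those magnitude-sampling flags show up in timm rand-aug config strings; the magnitudes below are illustrative, and `img_mean` is just an approximate ImageNet mean fill color:

```python
from timm.data.auto_augment import rand_augment_transform

aa_params = dict(img_mean=(124, 116, 104))  # fill color for geometric ops

# ra2-style magnitude: gaussian-sampled around m7 with std 0.5
ra2_tfm = rand_augment_transform('rand-m7-mstd0.5', aa_params)
# ra3-style magnitude: mstd >= 100 switches to uniform sampling in [0, m]
ra3_tfm = rand_augment_transform('rand-m8-mstd101-inc1', aa_params)
print(ra2_tfm)
print(ra3_tfm)
```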
Some weights were trained with sgd + grad clipping (`cx` in the name, where x is one of h, 1, 2, 3; h = amped-up augreg).
I believe the 064 regnety was very close with both the ra3 and sgd approaches; the hparams I have kept were the sgd ones, but I believe the published weights were rmsproptf and edged it out by a hair.
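A minimal sketch of the sgd + grad clipping piece of that recipe; the lr, momentum, and clip value are illustrative, not the published hparams:

```python
import torch
import torch.nn as nn

# Illustrative only: one optimization step with SGD and global-norm gradient clipping.
model = nn.Linear(16, 10)
optimizer = torch.optim.SGD(model.parameters(), lr=0.5, momentum=0.9, nesterov=True)

x, y = torch.randn(8, 16), torch.randint(0, 10, (8,))
loss = nn.functional.cross_entropy(model(x), y)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # grad clipping before the step
optimizer.step()
optimizer.zero_grad()
```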
A variety of hparams used to train vit, convnext, vit-hybrids (maxvit, coatnet) recently in timm
All are variations on the same theme (DeiT / Swin pretraining) with different tweaks here and there.
These were all run on 4-8 GPU or TPU devices. They use `--lr-base`, which rescales the LR automatically based on global batch size (relative to `--lr-base-size`), so adapting to different GPU counts works well within a range; running at significantly lower or higher global batch sizes will require re-running an LR search.
More recently, DeiT-III has proven to be a very compelling set of hparams for vit-like models. I've yet to do full runs myself, but theirs can be adapted to the timm train scripts (3-Aug support was added recently). https://github.com/facebookresearch/deit/blob/main/README_revenge.md
The yaml files can be used directly w/ the timm train script via its `--config` / `-c` argument.
Hparams were run on 8x A100 for the in12k or in12k fine-tune runs and 4x V100 for the rest, so global batch size = 320 * 4, etc., and the LR should be rescaled using a sqrt rule if changing the global batch size.
model | infer_samples_per_sec | infer_step_time | infer_batch_size | infer_img_size | train_samples_per_sec | train_step_time | train_batch_size | train_img_size | param_count
---|---|---|---|---|---|---|---|---|---
vit_small_patch16_224 | 2444.7 | 104.691 | 256 | 224 | 955.88 | 267.078 | 256 | 224 | 22.05
vit_relpos_medium_patch16_224 | 1107.38 | 231.158 | 256 | 224 | 502.75 | 253.69 | 128 | 224 | 38.75
vit_base_patch16_224 | 1013.88 | 252.477 | 256 | 224 | 358.36 | 356.433 | 128 | 224 | 86.57
vit_base_patch16_384 | 288.27 | 888.045 | 256 | 384 | 102.82 | 300.795 | 31 | 384 | 86.86