Ross Wightman rwightman

@rwightman
rwightman / median_pool.py
Last active August 13, 2024 10:57
PyTorch MedianPool (MedianFilter)
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.utils import _pair, _quadruple
class MedianPool2d(nn.Module):
    """ Median pool (usable as median filter when stride=1) module.
@rwightman
rwightman / seq_stroke_net.py
Last active September 24, 2021 09:31
Two RNN (1d CNN + LSTM) models for the Kaggle QuickDraw Challenge.
''' Two sample RNN (1d CNN + LSTM) networks used in the Kaggle
QuickDraw Challenge (https://www.kaggle.com/c/quickdraw-doodle-recognition)
Both of these networks expect a tuple input with first element being the sequences
and second being the sequence lengths (typical sorted packed format). The sequence tensor
should adhere to the following shape: (batch_size, channels, seq_len).
Where channels consists of stroke [x, y, t, end]. End indicates whether the
stroke is the last in a segment (pen up). It could easily be changed to start
(pen down) or combo of both.
@rwightman
rwightman / triplet_loss.py
Last active November 21, 2023 10:31
Hacky PyTorch Batch-Hard Triplet Loss and PK samplers
import torch
from torch import nn
import torch.nn.functional as F
from collections import OrderedDict
import math
def pdist(v):
    dist = torch.norm(v[:, None] - v, dim=2, p=2)
    return dist
@rwightman
rwightman / effresnetcomparison.ipynb
Created July 1, 2019 21:23
EffResNetComparison
@rwightman
rwightman / image_folder_tar.py
Created July 24, 2019 05:01
PyTorch ImageFolder style dataset for reading directly from tarfile
import torch.utils.data as data
import os
import re
import torch
import tarfile
from PIL import Image
IMG_EXTENSIONS = ['.png', '.jpg', '.jpeg']
@rwightman
rwightman / bench_by_infer.csv
Created March 6, 2021 06:22
PyTorch Bench (1.8, 1.7.1, NGC 21.02, NGC 20.12)
model,gpu,env,cl,infer_samples_per_sec,infer_step_time,infer_batch_size,train_samples_per_sec,train_step_time,train_batch_size,param_count,img_size
efficientnet_b0,rtx3090,ngc2102,True,7179.22,0.139,512,1628.51,0.609,256,5.29,224
efficientnet_b0,rtx3090,ngc2012,True,6527.77,0.153,512,1504.58,0.654,256,5.29,224
efficientnet_b0,v100_32,ngc2102,True,6496.56,0.154,512,1556.66,0.638,512,5.29,224
efficientnet_b0,rtx3090,1.7.1cu11.0,True,6020.3,0.166,512,1266.03,0.785,512,5.29,224
efficientnet_b0,rtx3090,1.8cu11.1,True,5979.7,0.167,512,1286.76,0.775,512,5.29,224
efficientnet_b0,v100_32,ngc2012,True,5666.05,0.176,512,1459.05,0.676,512,5.29,224
efficientnet_b0,v100_32,1.8cu11.1,True,5529.09,0.181,512,1444.02,0.688,512,5.29,224
efficientnet_b0,v100_32,1.7.1cu11.0,True,5526.07,0.181,512,1425.38,0.691,512,5.29,224
efficientnet_b0,titanrtx,ngc2102,True,5118.38,0.195,512,1156.83,0.862,512,5.29,224
img_size,ngc2102 nocl,ngc2102 cl,pt-181 nocl,pt-181 cl,ngc2012 nocl,ngc2012 cl,ngc2103 nocl,ngc2103 cl
128,3323.06,1180.6,3494.51,3561.77,3616.33,3534.56,3585.48,3609.14
132,3114.9,1199.74,3037.34,3519.81,3336.5,3460.3,3357.58,3508.03
136,3272.05,1204.48,2995.19,3574,3227.72,3435.07,3328.46,3424.46
140,3200.35,1207.76,2803.09,3587.26,3185.1,3415.24,3221.43,3471.07
144,3194.24,1220.19,2973.52,3683.47,3205.12,3420.44,3220.51,3454.69
148,2942.87,1218.09,2573.74,2900.56,2895.25,3431.24,2964.88,3508.71
152,2886.33,1191.09,2557.25,3043.76,2854.47,3518.21,2986.86,3500.32
156,2879.16,1190.3,2652.08,2945.3,2807.7,3538.78,2952.99,3497.47
160,2654.9,1213.99,2711.74,2822.56,2748.02,3536.74,2834.89,3504.37
@rwightman
rwightman / timm_unet.py
Created April 15, 2021 19:12
An example U-Net using timm features_only functionality.
""" A simple U-Net w/ timm backbone encoder
Based off an old version of Unet in https://github.com/qubvel/segmentation_models.pytorch
Hacked together by Ross Wightman
"""
from typing import Optional, List
import torch
@rwightman
rwightman / effres-agc.yaml
Last active June 24, 2021 23:51
timm config for training an nfnet, load with --config arg, override batch size, lr for your number of GPUs/dist nodes
aa: rand-m6-n5-inc1-mstd1.0
amp: false
apex_amp: false
aug_splits: 0
batch_size: 256
bn_eps: null
bn_momentum: null
bn_tf: false
channels_last: false
checkpoint_hist: 10
@rwightman
rwightman / MLP_hparams.md
Last active June 28, 2021 12:48
MLP model training hparams w/ timm bits and PyTorch XLA on TPU VM

Using TPU VM instance w/ pre-alpha timm bits setup as per: https://github.com/rwightman/pytorch-image-models/tree/bits_and_tpu/timm/bits#readme

python3 launch_xla.py --num-devices 8 train.py gs://my-imagenet --config hparams.yaml

Note that the config yaml files contain args that are unused or inactive, either because other code overrides them or because of the current state of the training code. The bits code is under heavy development, so these configs will likely need a specific revision (currently https://github.com/rwightman/pytorch-image-models/commit/5e95ced5a7763541f7219f35fd155e3fbfe66e8b)

The gMlp hparams are the last (latest) in the series and likely will produce better results than the earlier gmixer / resmlp variants...

A note on adapting the LR to a different batch size: AdamW is being used here, and I use sqrt scaling of the learning rate w.r.t. the (global) batch size. I typically use linear LR scaling w/ SGD or RMSProp for most from-scratch training.
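
As a rough illustration of the two scaling rules mentioned above, here is a minimal sketch. The helper name `scale_lr` and the reference values (base LR of 1e-3 at a global batch size of 256) are hypothetical and not taken from the hparams files; substitute the values from whichever config you start from.

```python
import math


def scale_lr(base_lr: float, base_batch_size: int, global_batch_size: int,
             rule: str = "sqrt") -> float:
    """Scale a reference LR to a new global batch size.

    Hypothetical helper for illustration only; it is not part of timm.
    rule="sqrt"   -> lr * sqrt(ratio), the scaling described above for AdamW
    rule="linear" -> lr * ratio, the scaling described above for SGD / RMSProp
    """
    ratio = global_batch_size / base_batch_size
    if rule == "sqrt":
        return base_lr * math.sqrt(ratio)
    if rule == "linear":
        return base_lr * ratio
    raise ValueError(f"unknown scaling rule: {rule!r}")


# Assumed reference point: lr=1e-3 at global batch size 256 (made-up numbers).
# Global batch size = per-device batch size * number of devices, e.g. 128 * 8.
adamw_lr = scale_lr(1e-3, 256, 128 * 8, rule="sqrt")    # ratio 4 -> lr doubles
sgd_lr = scale_lr(1e-3, 256, 128 * 8, rule="linear")    # ratio 4 -> lr quadruples
```

With sqrt scaling, quadrupling the global batch size doubles the LR, whereas linear scaling quadruples it. The global batch size should count all devices, not just the per-device batch.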