Ross Wightman (rwightman)

@rwightman
rwightman / effresnetcomparison.ipynb
Created July 1, 2019 21:23
EffResNetComparison
rwightman / triplet_loss.py
Last active November 21, 2023 10:31
Hacky PyTorch Batch-Hard Triplet Loss and PK samplers
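```python
import torch
from torch import nn
import torch.nn.functional as F
from collections import OrderedDict
import math


def pdist(v):
    # Pairwise Euclidean (L2) distances between the rows of v: (N, D) -> (N, N).
    # v[:, None] - v broadcasts to (N, N, D); the norm is taken over the feature dim.
    dist = torch.norm(v[:, None] - v, dim=2, p=2)
    return dist
```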
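The preview stops at `pdist`. As a rough, framework-free illustration of what the title refers to, the sketch below computes a batch-hard triplet loss (for each anchor, use the farthest same-label sample and the nearest different-label sample) and draws one P x K batch (P classes, K samples per class). It is plain Python rather than PyTorch so it runs anywhere. The names `pairwise_dist`, `batch_hard_triplet_loss` and `pk_batch`, the default `margin=0.2`, and the choice to skip anchors that have no positive are assumptions for illustration, not the gist's code.

```python
import math
import random
from collections import defaultdict


def pairwise_dist(vectors):
    """Euclidean distance between every pair of rows (pure-Python analogue of pdist)."""
    return [[math.dist(a, b) for b in vectors] for a in vectors]


def batch_hard_triplet_loss(embeddings, labels, margin=0.2):
    """Mean over anchors of max(0, hardest_positive - hardest_negative + margin).

    Hypothetical helper: anchors without any positive or negative are skipped.
    """
    dist = pairwise_dist(embeddings)
    losses = []
    for i, label in enumerate(labels):
        pos = [dist[i][j] for j, l in enumerate(labels) if l == label and j != i]
        neg = [dist[i][j] for j, l in enumerate(labels) if l != label]
        if not pos or not neg:
            continue  # no valid triplet for this anchor in the batch
        losses.append(max(0.0, max(pos) - min(neg) + margin))
    return sum(losses) / len(losses) if losses else 0.0


def pk_batch(labels, p, k, rng=random):
    """Indices for one P x K batch: P distinct classes, K distinct samples each.

    Hypothetical helper: only classes with at least k samples are eligible.
    """
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)
    eligible = [c for c, idxs in by_class.items() if len(idxs) >= k]
    batch = []
    for c in rng.sample(eligible, p):
        batch.extend(rng.sample(by_class[c], k))
    return batch
```

In a P x K batch with K >= 2, every anchor has K - 1 positives and (P - 1) * K negatives, which is why batch-hard mining is usually paired with this kind of sampler.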
rwightman / seq_stroke_net.py
Last active September 24, 2021 09:31
Two RNN (1d CNN + LSTM) models for the Kaggle QuickDraw Challenge.
```python
''' Two sample RNN (1d CNN + LSTM) networks used in the Kaggle
QuickDraw Challenge (https://www.kaggle.com/c/quickdraw-doodle-recognition).

Both networks expect a tuple input whose first element is the sequence tensor
and whose second element is the sequence lengths (the usual sorted, packed format).
The sequence tensor should have shape (batch_size, channels, seq_len),
where the channels are the stroke features [x, y, t, end]. The end flag indicates
whether the point is the last in its stroke segment (pen up). It could easily be
changed to mark the start (pen down) or a combination of both.
```
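To make the expected input concrete, here is a plain-Python sketch that turns drawings into the (batch_size, channels, seq_len) layout with [x, y, t, end] channels and returns the lengths sorted in decreasing order. It assumes each stroke arrives as parallel `xs`, `ys`, `ts` lists (similar to the raw QuickDraw ndjson strokes) and pads with zeros; the helpers `drawing_to_points` and `make_batch` are hypothetical, and a real pipeline would convert the result to tensors.

```python
def drawing_to_points(drawing):
    """Flatten strokes into per-point [x, y, t, end] rows; end=1 marks the last point (pen up).

    Assumption: each stroke is a (xs, ys, ts) triple of equal-length lists.
    """
    points = []
    for xs, ys, ts in drawing:
        n = len(xs)
        for i in range(n):
            points.append([xs[i], ys[i], ts[i], 1 if i == n - 1 else 0])
    return points


def make_batch(drawings, pad=0):
    """Return (batch, lengths).

    batch is nested lists shaped (batch_size, 4, max_len), sorted by length, longest first.
    """
    seqs = sorted((drawing_to_points(d) for d in drawings), key=len, reverse=True)
    lengths = [len(s) for s in seqs]
    max_len = lengths[0] if lengths else 0
    batch = []
    for seq in seqs:
        # Pad to max_len with all-`pad` points (rows are only read, never mutated).
        padded = seq + [[pad] * 4] * (max_len - len(seq))
        # Transpose (seq_len, channels) -> (channels, seq_len).
        batch.append([list(ch) for ch in zip(*padded)])
    return batch, lengths
```

Sorting by decreasing length matches what `torch.nn.utils.rnn.pack_padded_sequence` requires when `enforce_sorted=True`, which is its default.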
rwightman / median_pool.py
Last active August 13, 2024 10:57
PyTorch MedianPool (MedianFilter)
```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.utils import _pair, _quadruple


class MedianPool2d(nn.Module):
    """ Median pool (usable as median filter when stride=1) module.
```