@travishsu
Created April 24, 2025 02:36
# uvx --with torch --with torchvision ipython
from torchvision.models import get_model
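# `pretrained=True` is deprecated since torchvision 0.13; the `weights` argument
# replaces it, and 'IMAGENET1K_V1' selects the same checkpoint the old flag loaded.
model = get_model('resnet50', weights='IMAGENET1K_V1')
model.eval()  # inference mode, so BatchNorm uses its running statistics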
# https://pytorch.org/vision/main/feature_extraction.html
from torchvision.models.feature_extraction import create_feature_extractor
return_nodes = {f'layer{i}': f'layer{i}' for i in range(1, 5)}
extractor = create_feature_extractor(model, return_nodes)
import torch
img = torch.randn((3, 3, 256, 256))
features = extractor(img)
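# Print the output shape of each residual stage.
# For a (3, 3, 256, 256) input: (3, 256, 64, 64), (3, 512, 32, 32),
# (3, 1024, 16, 16), (3, 2048, 8, 8).
for i in range(1, 5):
    print(i, features[f'layer{i}'].shape)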