Skip to content

Instantly share code, notes, and snippets.

# uvx --with torch --with torchvision ipython
#
# Build a torchvision ResNet-50 and wrap it in a feature extractor that
# exposes the outputs of its four residual stages (layer1..layer4).
import torch  # NOTE(review): unused in this snippet; presumably for building an input tensor
from torchvision.models import get_model
# https://pytorch.org/vision/main/feature_extraction.html
from torchvision.models.feature_extraction import create_feature_extractor

# `pretrained=True` is the legacy flag (deprecated since torchvision 0.13 and
# emits a warning); `weights="DEFAULT"` is the supported way to request the
# default pretrained weights.
model = get_model('resnet50', weights="DEFAULT")

# Maps each graph node name to the key it gets in the extractor's output
# dict; here node names and output keys are identical.
return_nodes = {f'layer{i}': f'layer{i}' for i in range(1, 5)}
extractor = create_feature_extractor(model, return_nodes)
# uvx --with "git+https://github.com/travishsu/mimm.git" --with numpy ipython
#
# Load ResNet-50 through mimm (MLX backend), run a random batch through
# `model.features`, and print the shape of every returned feature map.
# NOTE: `get_model` and `model` here rebind the names used by the torchvision
# snippet earlier in this file if both are run in the same session.
import mlx.core as mx
from mimm import get_model, list_models

model = get_model('resnet50', pretrained=True)

# Random input of shape (3, 256, 256, 3) -- presumably
# (batch, height, width, channels); confirm against mimm's expected layout.
img = mx.random.normal((3, 256, 256, 3))
features = model.features(img)

# `features` is iterated as a mapping of layer name -> feature array.
for layer, feature in features.items():
    # The loop body must be indented; the original left `print` at column 0,
    # which is an IndentationError.
    print(layer, feature.shape)