Created
June 8, 2025 22:12
-
-
Save fepegar/15ae4b68c0da68ae97aba9a2ddc0b412 to your computer and use it in GitHub Desktop.
run-capi.ipynb
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"provenance": [], | |
"authorship_tag": "ABX9TyP72aZCjnR5fkDHhkgn26Gc", | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"language_info": { | |
"name": "python" | |
} | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/fepegar/15ae4b68c0da68ae97aba9a2ddc0b412/run-capi.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"%pip install --quiet einops jaxtyping numpy omegaconf rich scikit-learn torch torchvision" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "G3vj01I9SjRj", | |
"outputId": "81fbc920-6a13-4b6a-91ed-e4e2e3617e5a" | |
}, | |
"execution_count": 2, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m55.4/55.4 kB\u001b[0m \u001b[31m2.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m363.4/363.4 MB\u001b[0m \u001b[31m4.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m13.8/13.8 MB\u001b[0m \u001b[31m97.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m24.6/24.6 MB\u001b[0m \u001b[31m73.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m883.7/883.7 kB\u001b[0m \u001b[31m46.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m664.8/664.8 MB\u001b[0m \u001b[31m2.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m211.5/211.5 MB\u001b[0m \u001b[31m4.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m56.3/56.3 MB\u001b[0m \u001b[31m12.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m127.9/127.9 MB\u001b[0m \u001b[31m8.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m207.5/207.5 MB\u001b[0m \u001b[31m4.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m21.1/21.1 MB\u001b[0m \u001b[31m59.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", | |
"\u001b[?25h" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"import torch\n", | |
"\n", | |
"capi_vitl14_p205 = torch.hub.load('facebookresearch/capi:main', 'capi_vitl14_p205')" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "6Q2r8CuKSxd_", | |
"outputId": "aca470ea-8785-4471-ac0c-69adae07f99c" | |
}, | |
"execution_count": 3, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stderr", | |
"text": [ | |
"Downloading: \"https://github.com/facebookresearch/capi/zipball/main\" to /root/.cache/torch/hub/main.zip\n", | |
"WARNING:utils:Can't import sklearnex. If installed, that speeds up scikit-learn 10-100x\n", | |
"Downloading: \"https://dl.fbaipublicfiles.com/capi/capi_vitl14_p205.pth\" to /root/.cache/torch/hub/checkpoints/capi_vitl14_p205.pth\n", | |
"100%|██████████| 900M/900M [00:47<00:00, 20.0MB/s]\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"img = torch.zeros(1, 3, 224, 224)\n", | |
"global_repr, registers, feature_map = capi_vitl14_p205(img)\n", | |
"global_repr.shape, registers.shape, feature_map.shape" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "OETucVAkTovr", | |
"outputId": "b201ea10-1ec9-4fb2-b766-9823a6dae3ec" | |
}, | |
"execution_count": 4, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"(torch.Size([1, 1024]),\n", | |
" torch.Size([1, 16, 1024]),\n", | |
" torch.Size([1, 16, 16, 1024]))" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 4 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"%load_ext rich" | |
], | |
"metadata": { | |
"id": "sTZh4MX4mam9" | |
}, | |
"execution_count": 6, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"capi_vitl14_p205" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 981 | |
}, | |
"id": "amKIooVSmbtH", | |
"outputId": "2272eb5a-433a-4974-dae1-bedb51a7b401" | |
}, | |
"execution_count": 7, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/plain": [], | |
"text/html": [ | |
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n" | |
] | |
}, | |
"metadata": {} | |
}, | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"\n", | |
"\u001b[1;35mEncoderDecoder\u001b[0m\u001b[1m(\u001b[0m\n", | |
" \u001b[1m(\u001b[0mencoder\u001b[1m)\u001b[0m: \u001b[1;35mTransformer\u001b[0m\u001b[1m(\u001b[0m\n", | |
" \u001b[1m(\u001b[0mblocks\u001b[1m)\u001b[0m: \u001b[1;35mModuleList\u001b[0m\u001b[1m(\u001b[0m\n", | |
" \u001b[1m(\u001b[0m\u001b[1;36m0\u001b[0m-\u001b[1;36m23\u001b[0m\u001b[1m)\u001b[0m: \u001b[1;36m24\u001b[0m x \u001b[1;35mBlock\u001b[0m\u001b[1m(\u001b[0m\n", | |
" \u001b[1m(\u001b[0mresidual1\u001b[1m)\u001b[0m: \u001b[1;35mEfficientResidual\u001b[0m\u001b[1m(\u001b[0m\n", | |
" \u001b[1m(\u001b[0mnorm\u001b[1m)\u001b[0m: \u001b[1;35mRMSNorm\u001b[0m\u001b[1m(\u001b[0m\u001b[1m(\u001b[0m\u001b[1;36m1024\u001b[0m,\u001b[1m)\u001b[0m, \u001b[33meps\u001b[0m=\u001b[1;36m1e\u001b[0m\u001b[1;36m-05\u001b[0m, \u001b[33melementwise_affine\u001b[0m=\u001b[3;92mTrue\u001b[0m\u001b[1m)\u001b[0m\n", | |
" \u001b[1m(\u001b[0mfn\u001b[1m)\u001b[0m: \u001b[1;35mAttention\u001b[0m\u001b[1m(\u001b[0m\n", | |
" \u001b[1m(\u001b[0mq_proj\u001b[1m)\u001b[0m: \u001b[1;35mLinear\u001b[0m\u001b[1m(\u001b[0m\u001b[33min_features\u001b[0m=\u001b[1;36m1024\u001b[0m, \u001b[33mout_features\u001b[0m=\u001b[1;36m1024\u001b[0m, \u001b[33mbias\u001b[0m=\u001b[3;91mFalse\u001b[0m\u001b[1m)\u001b[0m\n", | |
" \u001b[1m(\u001b[0mk_proj\u001b[1m)\u001b[0m: \u001b[1;35mLinear\u001b[0m\u001b[1m(\u001b[0m\u001b[33min_features\u001b[0m=\u001b[1;36m1024\u001b[0m, \u001b[33mout_features\u001b[0m=\u001b[1;36m1024\u001b[0m, \u001b[33mbias\u001b[0m=\u001b[3;91mFalse\u001b[0m\u001b[1m)\u001b[0m\n", | |
" \u001b[1m(\u001b[0mv_proj\u001b[1m)\u001b[0m: \u001b[1;35mLinear\u001b[0m\u001b[1m(\u001b[0m\u001b[33min_features\u001b[0m=\u001b[1;36m1024\u001b[0m, \u001b[33mout_features\u001b[0m=\u001b[1;36m1024\u001b[0m, \u001b[33mbias\u001b[0m=\u001b[3;91mFalse\u001b[0m\u001b[1m)\u001b[0m\n", | |
" \u001b[1m(\u001b[0mproj\u001b[1m)\u001b[0m: \u001b[1;35mLinear\u001b[0m\u001b[1m(\u001b[0m\u001b[33min_features\u001b[0m=\u001b[1;36m1024\u001b[0m, \u001b[33mout_features\u001b[0m=\u001b[1;36m1024\u001b[0m, \u001b[33mbias\u001b[0m=\u001b[3;91mFalse\u001b[0m\u001b[1m)\u001b[0m\n", | |
" \u001b[1m(\u001b[0mrope\u001b[1m)\u001b[0m: \u001b[1;35mRope\u001b[0m\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m\n", | |
" \u001b[1m)\u001b[0m\n", | |
" \u001b[1m)\u001b[0m\n", | |
" \u001b[1m(\u001b[0mresidual2\u001b[1m)\u001b[0m: \u001b[1;35mEfficientResidual\u001b[0m\u001b[1m(\u001b[0m\n", | |
" \u001b[1m(\u001b[0mnorm\u001b[1m)\u001b[0m: \u001b[1;35mRMSNorm\u001b[0m\u001b[1m(\u001b[0m\u001b[1m(\u001b[0m\u001b[1;36m1024\u001b[0m,\u001b[1m)\u001b[0m, \u001b[33meps\u001b[0m=\u001b[1;36m1e\u001b[0m\u001b[1;36m-05\u001b[0m, \u001b[33melementwise_affine\u001b[0m=\u001b[3;92mTrue\u001b[0m\u001b[1m)\u001b[0m\n", | |
" \u001b[1m(\u001b[0mfn\u001b[1m)\u001b[0m: \u001b[1;35mMlp\u001b[0m\u001b[1m(\u001b[0m\n", | |
" \u001b[1m(\u001b[0mfc1\u001b[1m)\u001b[0m: \u001b[1;35mLinear\u001b[0m\u001b[1m(\u001b[0m\u001b[33min_features\u001b[0m=\u001b[1;36m1024\u001b[0m, \u001b[33mout_features\u001b[0m=\u001b[1;36m4096\u001b[0m, \u001b[33mbias\u001b[0m=\u001b[3;91mFalse\u001b[0m\u001b[1m)\u001b[0m\n", | |
" \u001b[1m(\u001b[0mact\u001b[1m)\u001b[0m: \u001b[1;35mGELU\u001b[0m\u001b[1m(\u001b[0m\u001b[33mapproximate\u001b[0m=\u001b[32m'none'\u001b[0m\u001b[1m)\u001b[0m\n", | |
" \u001b[1m(\u001b[0mfc2\u001b[1m)\u001b[0m: \u001b[1;35mLinear\u001b[0m\u001b[1m(\u001b[0m\u001b[33min_features\u001b[0m=\u001b[1;36m4096\u001b[0m, \u001b[33mout_features\u001b[0m=\u001b[1;36m1024\u001b[0m, \u001b[33mbias\u001b[0m=\u001b[3;91mFalse\u001b[0m\u001b[1m)\u001b[0m\n", | |
" \u001b[1m)\u001b[0m\n", | |
" \u001b[1m)\u001b[0m\n", | |
" \u001b[1m)\u001b[0m\n", | |
" \u001b[1m)\u001b[0m\n", | |
" \u001b[1m)\u001b[0m\n", | |
" \u001b[1m(\u001b[0mdecoder\u001b[1m)\u001b[0m: \u001b[1;35mTransformer\u001b[0m\u001b[1m(\u001b[0m\n", | |
" \u001b[1m(\u001b[0mblocks\u001b[1m)\u001b[0m: \u001b[1;35mModuleList\u001b[0m\u001b[1m(\u001b[0m\n", | |
" \u001b[1m(\u001b[0m\u001b[1;36m0\u001b[0m-\u001b[1;36m11\u001b[0m\u001b[1m)\u001b[0m: \u001b[1;36m12\u001b[0m x \u001b[1;35mBlock\u001b[0m\u001b[1m(\u001b[0m\n", | |
" \u001b[1m(\u001b[0mresidual1\u001b[1m)\u001b[0m: \u001b[1;35mEfficientResidual\u001b[0m\u001b[1m(\u001b[0m\n", | |
" \u001b[1m(\u001b[0mnorm\u001b[1m)\u001b[0m: \u001b[1;35mRMSNorm\u001b[0m\u001b[1m(\u001b[0m\u001b[1m(\u001b[0m\u001b[1;36m1024\u001b[0m,\u001b[1m)\u001b[0m, \u001b[33meps\u001b[0m=\u001b[1;36m1e\u001b[0m\u001b[1;36m-05\u001b[0m, \u001b[33melementwise_affine\u001b[0m=\u001b[3;92mTrue\u001b[0m\u001b[1m)\u001b[0m\n", | |
" \u001b[1m(\u001b[0mfn\u001b[1m)\u001b[0m: \u001b[1;35mAttention\u001b[0m\u001b[1m(\u001b[0m\n", | |
" \u001b[1m(\u001b[0mq_proj\u001b[1m)\u001b[0m: \u001b[1;35mLinear\u001b[0m\u001b[1m(\u001b[0m\u001b[33min_features\u001b[0m=\u001b[1;36m1024\u001b[0m, \u001b[33mout_features\u001b[0m=\u001b[1;36m1024\u001b[0m, \u001b[33mbias\u001b[0m=\u001b[3;91mFalse\u001b[0m\u001b[1m)\u001b[0m\n", | |
" \u001b[1m(\u001b[0mk_proj\u001b[1m)\u001b[0m: \u001b[1;35mLinear\u001b[0m\u001b[1m(\u001b[0m\u001b[33min_features\u001b[0m=\u001b[1;36m1024\u001b[0m, \u001b[33mout_features\u001b[0m=\u001b[1;36m1024\u001b[0m, \u001b[33mbias\u001b[0m=\u001b[3;91mFalse\u001b[0m\u001b[1m)\u001b[0m\n", | |
" \u001b[1m(\u001b[0mv_proj\u001b[1m)\u001b[0m: \u001b[1;35mLinear\u001b[0m\u001b[1m(\u001b[0m\u001b[33min_features\u001b[0m=\u001b[1;36m1024\u001b[0m, \u001b[33mout_features\u001b[0m=\u001b[1;36m1024\u001b[0m, \u001b[33mbias\u001b[0m=\u001b[3;91mFalse\u001b[0m\u001b[1m)\u001b[0m\n", | |
" \u001b[1m(\u001b[0mproj\u001b[1m)\u001b[0m: \u001b[1;35mLinear\u001b[0m\u001b[1m(\u001b[0m\u001b[33min_features\u001b[0m=\u001b[1;36m1024\u001b[0m, \u001b[33mout_features\u001b[0m=\u001b[1;36m1024\u001b[0m, \u001b[33mbias\u001b[0m=\u001b[3;91mFalse\u001b[0m\u001b[1m)\u001b[0m\n", | |
" \u001b[1m(\u001b[0mrope\u001b[1m)\u001b[0m: \u001b[1;35mRope\u001b[0m\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m\n", | |
" \u001b[1m)\u001b[0m\n", | |
" \u001b[1m)\u001b[0m\n", | |
" \u001b[1m(\u001b[0mresidual2\u001b[1m)\u001b[0m: \u001b[1;35mEfficientResidual\u001b[0m\u001b[1m(\u001b[0m\n", | |
" \u001b[1m(\u001b[0mnorm\u001b[1m)\u001b[0m: \u001b[1;35mRMSNorm\u001b[0m\u001b[1m(\u001b[0m\u001b[1m(\u001b[0m\u001b[1;36m1024\u001b[0m,\u001b[1m)\u001b[0m, \u001b[33meps\u001b[0m=\u001b[1;36m1e\u001b[0m\u001b[1;36m-05\u001b[0m, \u001b[33melementwise_affine\u001b[0m=\u001b[3;92mTrue\u001b[0m\u001b[1m)\u001b[0m\n", | |
" \u001b[1m(\u001b[0mfn\u001b[1m)\u001b[0m: \u001b[1;35mMlp\u001b[0m\u001b[1m(\u001b[0m\n", | |
" \u001b[1m(\u001b[0mfc1\u001b[1m)\u001b[0m: \u001b[1;35mLinear\u001b[0m\u001b[1m(\u001b[0m\u001b[33min_features\u001b[0m=\u001b[1;36m1024\u001b[0m, \u001b[33mout_features\u001b[0m=\u001b[1;36m4096\u001b[0m, \u001b[33mbias\u001b[0m=\u001b[3;91mFalse\u001b[0m\u001b[1m)\u001b[0m\n", | |
" \u001b[1m(\u001b[0mact\u001b[1m)\u001b[0m: \u001b[1;35mGELU\u001b[0m\u001b[1m(\u001b[0m\u001b[33mapproximate\u001b[0m=\u001b[32m'none'\u001b[0m\u001b[1m)\u001b[0m\n", | |
" \u001b[1m(\u001b[0mfc2\u001b[1m)\u001b[0m: \u001b[1;35mLinear\u001b[0m\u001b[1m(\u001b[0m\u001b[33min_features\u001b[0m=\u001b[1;36m4096\u001b[0m, \u001b[33mout_features\u001b[0m=\u001b[1;36m1024\u001b[0m, \u001b[33mbias\u001b[0m=\u001b[3;91mFalse\u001b[0m\u001b[1m)\u001b[0m\n", | |
" \u001b[1m)\u001b[0m\n", | |
" \u001b[1m)\u001b[0m\n", | |
" \u001b[1m)\u001b[0m\n", | |
" \u001b[1m)\u001b[0m\n", | |
" \u001b[1m)\u001b[0m\n", | |
" \u001b[1m(\u001b[0mpatch_embed\u001b[1m)\u001b[0m: \u001b[1;35mConv2d\u001b[0m\u001b[1m(\u001b[0m\u001b[1;36m3\u001b[0m, \u001b[1;36m1024\u001b[0m, \u001b[33mkernel_size\u001b[0m=\u001b[1m(\u001b[0m\u001b[1;36m14\u001b[0m, \u001b[1;36m14\u001b[0m\u001b[1m)\u001b[0m, \u001b[33mstride\u001b[0m=\u001b[1m(\u001b[0m\u001b[1;36m14\u001b[0m, \u001b[1;36m14\u001b[0m\u001b[1m)\u001b[0m\u001b[1m)\u001b[0m\n", | |
" \u001b[1m(\u001b[0menc_norm\u001b[1m)\u001b[0m: \u001b[1;35mRMSNorm\u001b[0m\u001b[1m(\u001b[0m\u001b[1m(\u001b[0m\u001b[1;36m1024\u001b[0m,\u001b[1m)\u001b[0m, \u001b[33meps\u001b[0m=\u001b[1;36m1e\u001b[0m\u001b[1;36m-05\u001b[0m, \u001b[33melementwise_affine\u001b[0m=\u001b[3;91mFalse\u001b[0m\u001b[1m)\u001b[0m\n", | |
" \u001b[1m(\u001b[0mdec_norm\u001b[1m)\u001b[0m: \u001b[1;35mRMSNorm\u001b[0m\u001b[1m(\u001b[0m\u001b[1m(\u001b[0m\u001b[1;36m1024\u001b[0m,\u001b[1m)\u001b[0m, \u001b[33meps\u001b[0m=\u001b[1;36m1e\u001b[0m\u001b[1;36m-05\u001b[0m, \u001b[33melementwise_affine\u001b[0m=\u001b[3;91mFalse\u001b[0m\u001b[1m)\u001b[0m\n", | |
"\u001b[1m)\u001b[0m" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 7 | |
} | |
] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment