Skip to content

Instantly share code, notes, and snippets.

@fepegar
Created June 8, 2025 22:12
Show Gist options
  • Save fepegar/15ae4b68c0da68ae97aba9a2ddc0b412 to your computer and use it in GitHub Desktop.
Save fepegar/15ae4b68c0da68ae97aba9a2ddc0b412 to your computer and use it in GitHub Desktop.
run-capi.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"authorship_tag": "ABX9TyP72aZCjnR5fkDHhkgn26Gc",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/fepegar/15ae4b68c0da68ae97aba9a2ddc0b412/run-capi.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"source": [
"%pip install --quiet einops jaxtyping numpy omegaconf rich scikit-learn torch torchvision"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "G3vj01I9SjRj",
"outputId": "81fbc920-6a13-4b6a-91ed-e4e2e3617e5a"
},
"execution_count": 2,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m55.4/55.4 kB\u001b[0m \u001b[31m2.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m363.4/363.4 MB\u001b[0m \u001b[31m4.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m13.8/13.8 MB\u001b[0m \u001b[31m97.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m24.6/24.6 MB\u001b[0m \u001b[31m73.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m883.7/883.7 kB\u001b[0m \u001b[31m46.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m664.8/664.8 MB\u001b[0m \u001b[31m2.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m211.5/211.5 MB\u001b[0m \u001b[31m4.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m56.3/56.3 MB\u001b[0m \u001b[31m12.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m127.9/127.9 MB\u001b[0m \u001b[31m8.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m207.5/207.5 MB\u001b[0m \u001b[31m4.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m21.1/21.1 MB\u001b[0m \u001b[31m59.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25h"
]
}
]
},
{
"cell_type": "code",
"source": [
"import torch\n",
"\n",
"capi_vitl14_p205 = torch.hub.load('facebookresearch/capi:main', 'capi_vitl14_p205')"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "6Q2r8CuKSxd_",
"outputId": "aca470ea-8785-4471-ac0c-69adae07f99c"
},
"execution_count": 3,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"Downloading: \"https://github.com/facebookresearch/capi/zipball/main\" to /root/.cache/torch/hub/main.zip\n",
"WARNING:utils:Can't import sklearnex. If installed, that speeds up scikit-learn 10-100x\n",
"Downloading: \"https://dl.fbaipublicfiles.com/capi/capi_vitl14_p205.pth\" to /root/.cache/torch/hub/checkpoints/capi_vitl14_p205.pth\n",
"100%|██████████| 900M/900M [00:47<00:00, 20.0MB/s]\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"img = torch.zeros(1, 3, 224, 224)\n",
"global_repr, registers, feature_map = capi_vitl14_p205(img)\n",
"global_repr.shape, registers.shape, feature_map.shape"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "OETucVAkTovr",
"outputId": "b201ea10-1ec9-4fb2-b766-9823a6dae3ec"
},
"execution_count": 4,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(torch.Size([1, 1024]),\n",
" torch.Size([1, 16, 1024]),\n",
" torch.Size([1, 16, 16, 1024]))"
]
},
"metadata": {},
"execution_count": 4
}
]
},
{
"cell_type": "code",
"source": [
"%load_ext rich"
],
"metadata": {
"id": "sTZh4MX4mam9"
},
"execution_count": 6,
"outputs": []
},
{
"cell_type": "code",
"source": [
"capi_vitl14_p205"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 981
},
"id": "amKIooVSmbtH",
"outputId": "2272eb5a-433a-4974-dae1-bedb51a7b401"
},
"execution_count": 7,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [],
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
]
},
"metadata": {}
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"\n",
"\u001b[1;35mEncoderDecoder\u001b[0m\u001b[1m(\u001b[0m\n",
" \u001b[1m(\u001b[0mencoder\u001b[1m)\u001b[0m: \u001b[1;35mTransformer\u001b[0m\u001b[1m(\u001b[0m\n",
" \u001b[1m(\u001b[0mblocks\u001b[1m)\u001b[0m: \u001b[1;35mModuleList\u001b[0m\u001b[1m(\u001b[0m\n",
" \u001b[1m(\u001b[0m\u001b[1;36m0\u001b[0m-\u001b[1;36m23\u001b[0m\u001b[1m)\u001b[0m: \u001b[1;36m24\u001b[0m x \u001b[1;35mBlock\u001b[0m\u001b[1m(\u001b[0m\n",
" \u001b[1m(\u001b[0mresidual1\u001b[1m)\u001b[0m: \u001b[1;35mEfficientResidual\u001b[0m\u001b[1m(\u001b[0m\n",
" \u001b[1m(\u001b[0mnorm\u001b[1m)\u001b[0m: \u001b[1;35mRMSNorm\u001b[0m\u001b[1m(\u001b[0m\u001b[1m(\u001b[0m\u001b[1;36m1024\u001b[0m,\u001b[1m)\u001b[0m, \u001b[33meps\u001b[0m=\u001b[1;36m1e\u001b[0m\u001b[1;36m-05\u001b[0m, \u001b[33melementwise_affine\u001b[0m=\u001b[3;92mTrue\u001b[0m\u001b[1m)\u001b[0m\n",
" \u001b[1m(\u001b[0mfn\u001b[1m)\u001b[0m: \u001b[1;35mAttention\u001b[0m\u001b[1m(\u001b[0m\n",
" \u001b[1m(\u001b[0mq_proj\u001b[1m)\u001b[0m: \u001b[1;35mLinear\u001b[0m\u001b[1m(\u001b[0m\u001b[33min_features\u001b[0m=\u001b[1;36m1024\u001b[0m, \u001b[33mout_features\u001b[0m=\u001b[1;36m1024\u001b[0m, \u001b[33mbias\u001b[0m=\u001b[3;91mFalse\u001b[0m\u001b[1m)\u001b[0m\n",
" \u001b[1m(\u001b[0mk_proj\u001b[1m)\u001b[0m: \u001b[1;35mLinear\u001b[0m\u001b[1m(\u001b[0m\u001b[33min_features\u001b[0m=\u001b[1;36m1024\u001b[0m, \u001b[33mout_features\u001b[0m=\u001b[1;36m1024\u001b[0m, \u001b[33mbias\u001b[0m=\u001b[3;91mFalse\u001b[0m\u001b[1m)\u001b[0m\n",
" \u001b[1m(\u001b[0mv_proj\u001b[1m)\u001b[0m: \u001b[1;35mLinear\u001b[0m\u001b[1m(\u001b[0m\u001b[33min_features\u001b[0m=\u001b[1;36m1024\u001b[0m, \u001b[33mout_features\u001b[0m=\u001b[1;36m1024\u001b[0m, \u001b[33mbias\u001b[0m=\u001b[3;91mFalse\u001b[0m\u001b[1m)\u001b[0m\n",
" \u001b[1m(\u001b[0mproj\u001b[1m)\u001b[0m: \u001b[1;35mLinear\u001b[0m\u001b[1m(\u001b[0m\u001b[33min_features\u001b[0m=\u001b[1;36m1024\u001b[0m, \u001b[33mout_features\u001b[0m=\u001b[1;36m1024\u001b[0m, \u001b[33mbias\u001b[0m=\u001b[3;91mFalse\u001b[0m\u001b[1m)\u001b[0m\n",
" \u001b[1m(\u001b[0mrope\u001b[1m)\u001b[0m: \u001b[1;35mRope\u001b[0m\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m\n",
" \u001b[1m)\u001b[0m\n",
" \u001b[1m)\u001b[0m\n",
" \u001b[1m(\u001b[0mresidual2\u001b[1m)\u001b[0m: \u001b[1;35mEfficientResidual\u001b[0m\u001b[1m(\u001b[0m\n",
" \u001b[1m(\u001b[0mnorm\u001b[1m)\u001b[0m: \u001b[1;35mRMSNorm\u001b[0m\u001b[1m(\u001b[0m\u001b[1m(\u001b[0m\u001b[1;36m1024\u001b[0m,\u001b[1m)\u001b[0m, \u001b[33meps\u001b[0m=\u001b[1;36m1e\u001b[0m\u001b[1;36m-05\u001b[0m, \u001b[33melementwise_affine\u001b[0m=\u001b[3;92mTrue\u001b[0m\u001b[1m)\u001b[0m\n",
" \u001b[1m(\u001b[0mfn\u001b[1m)\u001b[0m: \u001b[1;35mMlp\u001b[0m\u001b[1m(\u001b[0m\n",
" \u001b[1m(\u001b[0mfc1\u001b[1m)\u001b[0m: \u001b[1;35mLinear\u001b[0m\u001b[1m(\u001b[0m\u001b[33min_features\u001b[0m=\u001b[1;36m1024\u001b[0m, \u001b[33mout_features\u001b[0m=\u001b[1;36m4096\u001b[0m, \u001b[33mbias\u001b[0m=\u001b[3;91mFalse\u001b[0m\u001b[1m)\u001b[0m\n",
" \u001b[1m(\u001b[0mact\u001b[1m)\u001b[0m: \u001b[1;35mGELU\u001b[0m\u001b[1m(\u001b[0m\u001b[33mapproximate\u001b[0m=\u001b[32m'none'\u001b[0m\u001b[1m)\u001b[0m\n",
" \u001b[1m(\u001b[0mfc2\u001b[1m)\u001b[0m: \u001b[1;35mLinear\u001b[0m\u001b[1m(\u001b[0m\u001b[33min_features\u001b[0m=\u001b[1;36m4096\u001b[0m, \u001b[33mout_features\u001b[0m=\u001b[1;36m1024\u001b[0m, \u001b[33mbias\u001b[0m=\u001b[3;91mFalse\u001b[0m\u001b[1m)\u001b[0m\n",
" \u001b[1m)\u001b[0m\n",
" \u001b[1m)\u001b[0m\n",
" \u001b[1m)\u001b[0m\n",
" \u001b[1m)\u001b[0m\n",
" \u001b[1m)\u001b[0m\n",
" \u001b[1m(\u001b[0mdecoder\u001b[1m)\u001b[0m: \u001b[1;35mTransformer\u001b[0m\u001b[1m(\u001b[0m\n",
" \u001b[1m(\u001b[0mblocks\u001b[1m)\u001b[0m: \u001b[1;35mModuleList\u001b[0m\u001b[1m(\u001b[0m\n",
" \u001b[1m(\u001b[0m\u001b[1;36m0\u001b[0m-\u001b[1;36m11\u001b[0m\u001b[1m)\u001b[0m: \u001b[1;36m12\u001b[0m x \u001b[1;35mBlock\u001b[0m\u001b[1m(\u001b[0m\n",
" \u001b[1m(\u001b[0mresidual1\u001b[1m)\u001b[0m: \u001b[1;35mEfficientResidual\u001b[0m\u001b[1m(\u001b[0m\n",
" \u001b[1m(\u001b[0mnorm\u001b[1m)\u001b[0m: \u001b[1;35mRMSNorm\u001b[0m\u001b[1m(\u001b[0m\u001b[1m(\u001b[0m\u001b[1;36m1024\u001b[0m,\u001b[1m)\u001b[0m, \u001b[33meps\u001b[0m=\u001b[1;36m1e\u001b[0m\u001b[1;36m-05\u001b[0m, \u001b[33melementwise_affine\u001b[0m=\u001b[3;92mTrue\u001b[0m\u001b[1m)\u001b[0m\n",
" \u001b[1m(\u001b[0mfn\u001b[1m)\u001b[0m: \u001b[1;35mAttention\u001b[0m\u001b[1m(\u001b[0m\n",
" \u001b[1m(\u001b[0mq_proj\u001b[1m)\u001b[0m: \u001b[1;35mLinear\u001b[0m\u001b[1m(\u001b[0m\u001b[33min_features\u001b[0m=\u001b[1;36m1024\u001b[0m, \u001b[33mout_features\u001b[0m=\u001b[1;36m1024\u001b[0m, \u001b[33mbias\u001b[0m=\u001b[3;91mFalse\u001b[0m\u001b[1m)\u001b[0m\n",
" \u001b[1m(\u001b[0mk_proj\u001b[1m)\u001b[0m: \u001b[1;35mLinear\u001b[0m\u001b[1m(\u001b[0m\u001b[33min_features\u001b[0m=\u001b[1;36m1024\u001b[0m, \u001b[33mout_features\u001b[0m=\u001b[1;36m1024\u001b[0m, \u001b[33mbias\u001b[0m=\u001b[3;91mFalse\u001b[0m\u001b[1m)\u001b[0m\n",
" \u001b[1m(\u001b[0mv_proj\u001b[1m)\u001b[0m: \u001b[1;35mLinear\u001b[0m\u001b[1m(\u001b[0m\u001b[33min_features\u001b[0m=\u001b[1;36m1024\u001b[0m, \u001b[33mout_features\u001b[0m=\u001b[1;36m1024\u001b[0m, \u001b[33mbias\u001b[0m=\u001b[3;91mFalse\u001b[0m\u001b[1m)\u001b[0m\n",
" \u001b[1m(\u001b[0mproj\u001b[1m)\u001b[0m: \u001b[1;35mLinear\u001b[0m\u001b[1m(\u001b[0m\u001b[33min_features\u001b[0m=\u001b[1;36m1024\u001b[0m, \u001b[33mout_features\u001b[0m=\u001b[1;36m1024\u001b[0m, \u001b[33mbias\u001b[0m=\u001b[3;91mFalse\u001b[0m\u001b[1m)\u001b[0m\n",
" \u001b[1m(\u001b[0mrope\u001b[1m)\u001b[0m: \u001b[1;35mRope\u001b[0m\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m\n",
" \u001b[1m)\u001b[0m\n",
" \u001b[1m)\u001b[0m\n",
" \u001b[1m(\u001b[0mresidual2\u001b[1m)\u001b[0m: \u001b[1;35mEfficientResidual\u001b[0m\u001b[1m(\u001b[0m\n",
" \u001b[1m(\u001b[0mnorm\u001b[1m)\u001b[0m: \u001b[1;35mRMSNorm\u001b[0m\u001b[1m(\u001b[0m\u001b[1m(\u001b[0m\u001b[1;36m1024\u001b[0m,\u001b[1m)\u001b[0m, \u001b[33meps\u001b[0m=\u001b[1;36m1e\u001b[0m\u001b[1;36m-05\u001b[0m, \u001b[33melementwise_affine\u001b[0m=\u001b[3;92mTrue\u001b[0m\u001b[1m)\u001b[0m\n",
" \u001b[1m(\u001b[0mfn\u001b[1m)\u001b[0m: \u001b[1;35mMlp\u001b[0m\u001b[1m(\u001b[0m\n",
" \u001b[1m(\u001b[0mfc1\u001b[1m)\u001b[0m: \u001b[1;35mLinear\u001b[0m\u001b[1m(\u001b[0m\u001b[33min_features\u001b[0m=\u001b[1;36m1024\u001b[0m, \u001b[33mout_features\u001b[0m=\u001b[1;36m4096\u001b[0m, \u001b[33mbias\u001b[0m=\u001b[3;91mFalse\u001b[0m\u001b[1m)\u001b[0m\n",
" \u001b[1m(\u001b[0mact\u001b[1m)\u001b[0m: \u001b[1;35mGELU\u001b[0m\u001b[1m(\u001b[0m\u001b[33mapproximate\u001b[0m=\u001b[32m'none'\u001b[0m\u001b[1m)\u001b[0m\n",
" \u001b[1m(\u001b[0mfc2\u001b[1m)\u001b[0m: \u001b[1;35mLinear\u001b[0m\u001b[1m(\u001b[0m\u001b[33min_features\u001b[0m=\u001b[1;36m4096\u001b[0m, \u001b[33mout_features\u001b[0m=\u001b[1;36m1024\u001b[0m, \u001b[33mbias\u001b[0m=\u001b[3;91mFalse\u001b[0m\u001b[1m)\u001b[0m\n",
" \u001b[1m)\u001b[0m\n",
" \u001b[1m)\u001b[0m\n",
" \u001b[1m)\u001b[0m\n",
" \u001b[1m)\u001b[0m\n",
" \u001b[1m)\u001b[0m\n",
" \u001b[1m(\u001b[0mpatch_embed\u001b[1m)\u001b[0m: \u001b[1;35mConv2d\u001b[0m\u001b[1m(\u001b[0m\u001b[1;36m3\u001b[0m, \u001b[1;36m1024\u001b[0m, \u001b[33mkernel_size\u001b[0m=\u001b[1m(\u001b[0m\u001b[1;36m14\u001b[0m, \u001b[1;36m14\u001b[0m\u001b[1m)\u001b[0m, \u001b[33mstride\u001b[0m=\u001b[1m(\u001b[0m\u001b[1;36m14\u001b[0m, \u001b[1;36m14\u001b[0m\u001b[1m)\u001b[0m\u001b[1m)\u001b[0m\n",
" \u001b[1m(\u001b[0menc_norm\u001b[1m)\u001b[0m: \u001b[1;35mRMSNorm\u001b[0m\u001b[1m(\u001b[0m\u001b[1m(\u001b[0m\u001b[1;36m1024\u001b[0m,\u001b[1m)\u001b[0m, \u001b[33meps\u001b[0m=\u001b[1;36m1e\u001b[0m\u001b[1;36m-05\u001b[0m, \u001b[33melementwise_affine\u001b[0m=\u001b[3;91mFalse\u001b[0m\u001b[1m)\u001b[0m\n",
" \u001b[1m(\u001b[0mdec_norm\u001b[1m)\u001b[0m: \u001b[1;35mRMSNorm\u001b[0m\u001b[1m(\u001b[0m\u001b[1m(\u001b[0m\u001b[1;36m1024\u001b[0m,\u001b[1m)\u001b[0m, \u001b[33meps\u001b[0m=\u001b[1;36m1e\u001b[0m\u001b[1;36m-05\u001b[0m, \u001b[33melementwise_affine\u001b[0m=\u001b[3;91mFalse\u001b[0m\u001b[1m)\u001b[0m\n",
"\u001b[1m)\u001b[0m"
]
},
"metadata": {},
"execution_count": 7
}
]
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment