Created
December 5, 2019 13:53
-
-
Save mdouze/94bd7a56d912a06ac4719c50fa5b01ac to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import faiss\n", | |
"import numpy as np" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Convert a PQ codec to pure numpy" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class NumpyPQCodec: \n", | |
" \n", | |
" def __init__(self, index): \n", | |
" \n", | |
" assert index.is_trained\n", | |
" \n", | |
" # handle the pretransform\n", | |
" if isinstance(index, faiss.IndexPreTransform): \n", | |
" vt = faiss.downcast_VectorTransform(index.chain.at(0)) \n", | |
" assert isinstance(vt, faiss.LinearTransform)\n", | |
" b = faiss.vector_to_array(vt.b)\n", | |
" A = faiss.vector_to_array(vt.A).reshape(vt.d_out, vt.d_in)\n", | |
" self.pre = (A, b)\n", | |
" index = faiss.downcast_index(index.index)\n", | |
" else: \n", | |
" self.pre = None\n", | |
" \n", | |
" # extract the PQ centroids\n", | |
" assert isinstance(index, faiss.IndexPQ)\n", | |
" pq = index.pq\n", | |
" cen = faiss.vector_to_array(pq.centroids)\n", | |
" cen = cen.reshape(pq.M, pq.ksub, pq.dsub)\n", | |
" assert pq.nbits == 8\n", | |
" self.centroids = cen\n", | |
" self.norm2_centroids = (cen ** 2).sum(axis=2)\n", | |
" \n", | |
" def encode(self, x): \n", | |
" if self.pre is not None: \n", | |
" A, b = self.pre\n", | |
" x = x @ A.T\n", | |
" if b.size > 0: \n", | |
" x += b\n", | |
" \n", | |
" n, d = x.shape\n", | |
" cen = self.centroids\n", | |
" M, ksub, dsub = cen.shape\n", | |
" codes = np.empty((n, M), dtype='uint8')\n", | |
" # maybe possible to vectorize this loop...\n", | |
" for m in range(M): \n", | |
" # compute all per-centroid distances, ignoring the ||x||^2 term\n", | |
" xslice = x[:, m * dsub:(m + 1) * dsub]\n", | |
" dis = self.norm2_centroids[m] - 2 * xslice @ cen[m].T \n", | |
" codes[:, m] = dis.argmin(axis=1)\n", | |
" return codes\n", | |
"\n", | |
" def decode(self, codes): \n", | |
" n, MM = codes.shape\n", | |
" cen = self.centroids\n", | |
" M, ksub, dsub = cen.shape\n", | |
" assert MM == M\n", | |
" x = np.empty((n, M * dsub), dtype='float32')\n", | |
" for m in range(M): \n", | |
" xslice = cen[m, codes[:, m]]\n", | |
" x[:, m * dsub:(m + 1) * dsub] = xslice\n", | |
" if self.pre is not None: \n", | |
" A, b = self.pre\n", | |
" if b.size > 0: \n", | |
" x -= b \n", | |
" x = x @ A \n", | |
" return x \n", | |
" " | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# codec = faiss.index_factory(128, \"OPQ32_128,PQ32\")\n", | |
"codec = faiss.index_factory(128, \"PCAR64,PQ32\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# train the codec\n", | |
"xt = faiss.rand((10000, 128))\n", | |
"codec.train(xt)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# convetert to a pure numpy codec\n", | |
"ncodec = NumpyPQCodec(codec)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"True" | |
] | |
}, | |
"execution_count": 12, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# test encoding\n", | |
"xb = faiss.rand((100, 128), 1234)\n", | |
"(ncodec.encode(xb) == codec.sa_encode(xb)).all()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"True" | |
] | |
}, | |
"execution_count": 13, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# test decoding\n", | |
"codes = faiss.randint((100, 32), vmax=256, seed=12345).astype('uint8')\n", | |
"np.allclose(codec.sa_decode(codes), ncodec.decode(codes))" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Same for pytorch" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import torch" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 64, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class TorchPQCodec(NumpyPQCodec): \n", | |
" \n", | |
" def __init__(self, index, dev): \n", | |
" NumpyPQCodec.__init__(self, index)\n", | |
" # just move everything to torch on the given device\n", | |
" if self.pre: \n", | |
" A, b = self.pre\n", | |
" self.pre_torch = (torch.from_numpy(A).to(dev), \n", | |
" torch.from_numpy(b).to(dev))\n", | |
" else: \n", | |
" self.pre_torch = None \n", | |
" self.centroids_torch = torch.from_numpy(self.centroids).to(dev)\n", | |
" self.norm2_centroids_torch = torch.from_numpy(self.norm2_centroids).to(dev)\n", | |
" \n", | |
" def encode(self, x): \n", | |
" if self.pre_torch is not None: \n", | |
" A, b = self.pre_torch\n", | |
" x = x @ A.t()\n", | |
" if b.numel() > 0: \n", | |
" x += b\n", | |
" \n", | |
" n, d = x.shape\n", | |
" cen = self.centroids_torch\n", | |
" M, ksub, dsub = cen.shape\n", | |
" codes = torch.empty((n, M), dtype=torch.uint8, device=x.device)\n", | |
" # maybe possible to vectorize this loop...\n", | |
" for m in range(M): \n", | |
" # compute all per-centroid distances, ignoring the ||x||^2 term\n", | |
" xslice = x[:, m * dsub:(m + 1) * dsub]\n", | |
" dis = self.norm2_centroids_torch[m] - 2 * xslice @ cen[m].t()\n", | |
" codes[:, m] = dis.argmin(dim=1)\n", | |
" return codes\n", | |
" \n", | |
" def decode(self, codes): \n", | |
" n, MM = codes.shape\n", | |
" cen = self.centroids_torch\n", | |
" M, ksub, dsub = cen.shape\n", | |
" assert MM == M\n", | |
" x = torch.empty((n, M * dsub), dtype=torch.float32, device=codes.device)\n", | |
" for m in range(M): \n", | |
" xslice = cen[m, codes[:, m].long()]\n", | |
" x[:, m * dsub:(m + 1) * dsub] = xslice\n", | |
" \n", | |
" if self.pre is not None: \n", | |
" A, b = self.pre_torch\n", | |
" if b.numel() > 0: \n", | |
" x -= b \n", | |
" x = x @ A \n", | |
" return x \n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 82, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"True" | |
] | |
}, | |
"execution_count": 82, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# test encode \n", | |
"dev = torch.device('cuda:0')\n", | |
"tcodec = TorchPQCodec(codec, dev)\n", | |
"\n", | |
"xb = torch.rand(1000, 128).to(dev)\n", | |
"(tcodec.encode(xb).cpu().numpy() == codec.sa_encode(xb.cpu().numpy())).all()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 83, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"True" | |
] | |
}, | |
"execution_count": 83, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# test decode\n", | |
"codes = torch.randint(256, size=(1000, 32), dtype=torch.uint8).to(dev)\n", | |
"\n", | |
"np.allclose(\n", | |
" tcodec.decode(codes).cpu().numpy(),\n", | |
" codec.sa_decode(codes.cpu().numpy()),\n", | |
" atol=1e-6\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"anp_metadata": { | |
"path": "notebooks/PQ_codec_pytorch.ipynb" | |
}, | |
"bento_stylesheets": { | |
"bento/extensions/flow/main.css": true, | |
"bento/extensions/kernel_selector/main.css": true, | |
"bento/extensions/kernel_ui/main.css": true, | |
"bento/extensions/new_kernel/main.css": true, | |
"bento/extensions/system_usage/main.css": true, | |
"bento/extensions/theme/main.css": true | |
}, | |
"disseminate_notebook_info": { | |
"backup_notebook_id": "2541131109543807" | |
}, | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.7.4" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment