Skip to content

Instantly share code, notes, and snippets.

@mdouze
Created December 5, 2019 13:53
Show Gist options
  • Save mdouze/94bd7a56d912a06ac4719c50fa5b01ac to your computer and use it in GitHub Desktop.
Save mdouze/94bd7a56d912a06ac4719c50fa5b01ac to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"import faiss\n",
"import numpy as np"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Convert a PQ codec to pure numpy"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
class NumpyPQCodec:
    """Pure-numpy mirror of a trained Faiss PQ codec.

    Pulls the (optional) linear pre-transform and the product-quantizer
    codebook out of a Faiss index so that sa_encode / sa_decode can be
    reproduced with plain numpy operations.
    """

    def __init__(self, index):
        """Extract transform and centroid tables from a trained index.

        index: a trained faiss.IndexPQ, optionally wrapped in a
               faiss.IndexPreTransform whose first transform is linear
               (e.g. OPQ or PCA).
        """
        assert index.is_trained

        # Peel off the linear pre-transform (OPQ / PCA), if present.
        if isinstance(index, faiss.IndexPreTransform):
            vt = faiss.downcast_VectorTransform(index.chain.at(0))
            assert isinstance(vt, faiss.LinearTransform)
            bias = faiss.vector_to_array(vt.b)
            mat = faiss.vector_to_array(vt.A).reshape(vt.d_out, vt.d_in)
            self.pre = (mat, bias)
            index = faiss.downcast_index(index.index)
        else:
            self.pre = None

        # The wrapped index must be a flat PQ; copy out its codebook.
        assert isinstance(index, faiss.IndexPQ)
        pq = index.pq
        codebook = faiss.vector_to_array(pq.centroids)
        codebook = codebook.reshape(pq.M, pq.ksub, pq.dsub)
        assert pq.nbits == 8  # one uint8 code per sub-vector
        self.centroids = codebook
        # Precomputed ||c||^2 per centroid, used by encode()'s distance trick.
        self.norm2_centroids = (codebook ** 2).sum(axis=2)

    def encode(self, x):
        """Encode vectors x of shape (n, d) into uint8 codes of shape (n, M)."""
        if self.pre is not None:
            mat, bias = self.pre
            x = x @ mat.T
            if bias.size > 0:
                x += bias

        cen = self.centroids
        M, ksub, dsub = cen.shape
        nvec = x.shape[0]
        codes = np.empty((nvec, M), dtype='uint8')
        # maybe possible to vectorize this loop...
        for m in range(M):
            # argmin_j ||x - c_j||^2 == argmin_j (||c_j||^2 - 2 x.c_j):
            # the ||x||^2 term is constant over j and can be dropped.
            sub = x[:, m * dsub:(m + 1) * dsub]
            scores = self.norm2_centroids[m] - 2 * sub @ cen[m].T
            codes[:, m] = scores.argmin(axis=1)
        return codes

    def decode(self, codes):
        """Decode uint8 codes of shape (n, M) back to approximate vectors."""
        cen = self.centroids
        M, ksub, dsub = cen.shape
        nvec, ncols = codes.shape
        assert ncols == M
        x = np.empty((nvec, M * dsub), dtype='float32')
        for m in range(M):
            # Table lookup: each code selects one sub-centroid.
            x[:, m * dsub:(m + 1) * dsub] = cen[m, codes[:, m]]
        # Invert the linear pre-transform (bias first, then the matrix).
        if self.pre is not None:
            mat, bias = self.pre
            if bias.size > 0:
                x -= bias
            x = x @ mat
        return x
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"# build the codec: reduce 128-d input with PCA (with random rotation) to 64 dims,\n",
"# then product-quantize into 32 bytes; the OPQ variant is left as an alternative\n",
"# codec = faiss.index_factory(128, \"OPQ32_128,PQ32\")\n",
"codec = faiss.index_factory(128, \"PCAR64,PQ32\")"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"# train the codec on random vectors (the factory index must be trained\n",
"# before it can encode/decode -- see the is_trained assert in NumpyPQCodec)\n",
"xt = faiss.rand((10000, 128))\n",
"codec.train(xt)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"# convert to a pure numpy codec\n",
"ncodec = NumpyPQCodec(codec)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# test encoding: the numpy codec should reproduce faiss sa_encode exactly\n",
"xb = faiss.rand((100, 128), 1234)\n",
"(ncodec.encode(xb) == codec.sa_encode(xb)).all()"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# test decoding: random codes decoded by both implementations should agree\n",
"codes = faiss.randint((100, 32), vmax=256, seed=12345).astype('uint8')\n",
"np.allclose(codec.sa_decode(codes), ncodec.decode(codes))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Same for pytorch"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"import torch"
]
},
{
"cell_type": "code",
"execution_count": 64,
"metadata": {},
"outputs": [],
"source": [
class TorchPQCodec(NumpyPQCodec):
    """PyTorch port of NumpyPQCodec: same codec, tensors on a chosen device.

    Consistency fixes vs. the numpy-era draft: __init__ now tests
    `self.pre is not None` (matching the base class) instead of relying on
    tuple truthiness, and decode() gates on `self.pre_torch` -- the attribute
    it actually unpacks -- instead of the numpy-side `self.pre`.
    """

    def __init__(self, index, dev):
        """Build the numpy codec, then mirror its tables on device `dev`."""
        NumpyPQCodec.__init__(self, index)
        # Move the transform and codebook tables to torch on the given device.
        if self.pre is not None:
            A, b = self.pre
            self.pre_torch = (torch.from_numpy(A).to(dev),
                              torch.from_numpy(b).to(dev))
        else:
            self.pre_torch = None
        self.centroids_torch = torch.from_numpy(self.centroids).to(dev)
        self.norm2_centroids_torch = torch.from_numpy(self.norm2_centroids).to(dev)

    def encode(self, x):
        """Encode a float tensor (n, d) into uint8 codes (n, M)."""
        if self.pre_torch is not None:
            A, b = self.pre_torch
            x = x @ A.t()
            if b.numel() > 0:
                x += b

        n, d = x.shape
        cen = self.centroids_torch
        M, ksub, dsub = cen.shape
        codes = torch.empty((n, M), dtype=torch.uint8, device=x.device)
        # maybe possible to vectorize this loop...
        for m in range(M):
            # per-centroid distances, ignoring the constant ||x||^2 term
            xslice = x[:, m * dsub:(m + 1) * dsub]
            dis = self.norm2_centroids_torch[m] - 2 * xslice @ cen[m].t()
            codes[:, m] = dis.argmin(dim=1)
        return codes

    def decode(self, codes):
        """Decode uint8 codes (n, M) back to approximate float vectors."""
        n, MM = codes.shape
        cen = self.centroids_torch
        M, ksub, dsub = cen.shape
        assert MM == M
        x = torch.empty((n, M * dsub), dtype=torch.float32, device=codes.device)
        for m in range(M):
            # .long(): uint8 codes must be cast to an index dtype
            xslice = cen[m, codes[:, m].long()]
            x[:, m * dsub:(m + 1) * dsub] = xslice

        # Invert the linear pre-transform (was gated on self.pre; now checks
        # the torch-side attribute it actually uses).
        if self.pre_torch is not None:
            A, b = self.pre_torch
            if b.numel() > 0:
                x -= b
            x = x @ A
        return x
]
},
{
"cell_type": "code",
"execution_count": 82,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 82,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# test encode: the torch codec should agree with faiss sa_encode\n",
"# NOTE(review): hard-codes cuda:0, so this cell needs a GPU; switch to\n",
"# torch.device('cpu') to run it on CPU\n",
"dev = torch.device('cuda:0')\n",
"tcodec = TorchPQCodec(codec, dev)\n",
"\n",
"xb = torch.rand(1000, 128).to(dev)\n",
"(tcodec.encode(xb).cpu().numpy() == codec.sa_encode(xb.cpu().numpy())).all()"
]
},
{
"cell_type": "code",
"execution_count": 83,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 83,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# test decode: torch decode should match faiss sa_decode up to float\n",
"# rounding (hence allclose with atol rather than exact equality)\n",
"codes = torch.randint(256, size=(1000, 32), dtype=torch.uint8).to(dev)\n",
"\n",
"np.allclose(\n",
" tcodec.decode(codes).cpu().numpy(),\n",
" codec.sa_decode(codes.cpu().numpy()),\n",
" atol=1e-6\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"anp_metadata": {
"path": "notebooks/PQ_codec_pytorch.ipynb"
},
"bento_stylesheets": {
"bento/extensions/flow/main.css": true,
"bento/extensions/kernel_selector/main.css": true,
"bento/extensions/kernel_ui/main.css": true,
"bento/extensions/new_kernel/main.css": true,
"bento/extensions/system_usage/main.css": true,
"bento/extensions/theme/main.css": true
},
"disseminate_notebook_info": {
"backup_notebook_id": "2541131109543807"
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment