Skip to content

Instantly share code, notes, and snippets.

@justheuristic
Last active January 21, 2019 18:03
Show Gist options
  • Save justheuristic/bef3b968a2fafcc1e8eb98021e51fd2e to your computer and use it in GitHub Desktop.
Save justheuristic/bef3b968a2fafcc1e8eb98021e51fd2e to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Minifox example\n",
"![Img](https://i.imgur.com/IkU0BxB.png?2)\n",
"\n",
"__Patch notes:__\n",
"* теперь она умеет обучаться батчами\n",
"* компоненты сортируются по MSE от наибольшей к наименьшей\n",
"* заточил код под gpu (~x1.8 на tesla m40 против старой версии, разница с cpu зависит от batch size и моделек)\n",
" * использование: ```MinFoxSolver(..., device='gpu')```\n",
"* можно подсовывать свои модельки через ```MinFoxSolver(..., make_predictor=lambda n_components: make_my_keras_model()```\n",
"* двойные градиенты пока только всё замедляют в десятки раз даже для 2-слойного mlp :(\n",
"* если в данных на самом деле нет нужного числа ортогональных компонент, оно таки может иногда порождать несколько скоррелированных из-за float32 точности на хвосте Грамма-Шмидта, при этом все реальные компоненты тоже найдёт.\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from minifox import MinFoxSolver, orthogonalize_rows, orthogonalize_columns\n",
"import matplotlib.pyplot as plt\n",
"%matplotlib inline\n",
"import numpy as np\n",
"\n",
"num_samples, true_dim, projected_dim, components = 200000, 200, 500, 100\n",
"\n",
"M = np.random.randn(num_samples, true_dim).astype('float32')\n",
"A = M.dot(np.random.randn(true_dim, projected_dim).astype('float32'))\n",
"B = M[:, :-components].dot(np.random.randn(true_dim - components, projected_dim).astype('float32'))"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"step 10; loss=18.3006; delta=inf\n",
"step 20; loss=1.3096; delta=16.9910\n",
"step 30; loss=0.5271; delta=0.7825\n",
"step 40; loss=0.5125; delta=0.0146\n",
"step 50; loss=0.5185; delta=0.0060\n",
"step 60; loss=0.5297; delta=0.0112\n",
"step 70; loss=0.5546; delta=0.0249\n",
"step 80; loss=0.5948; delta=0.0402\n",
"step 90; loss=0.6491; delta=0.0543\n",
"step 100; loss=0.7057; delta=0.0566\n",
"step 110; loss=0.7649; delta=0.0592\n",
"step 120; loss=0.8176; delta=0.0527\n",
"step 130; loss=0.8636; delta=0.0461\n",
"step 140; loss=0.9005; delta=0.0369\n",
"step 150; loss=0.9232; delta=0.0226\n",
"step 160; loss=0.9523; delta=0.0291\n",
"step 170; loss=0.9670; delta=0.0147\n",
"step 180; loss=0.9778; delta=0.0107\n",
"step 190; loss=0.9825; delta=0.0047\n",
"step 200; loss=0.9765; delta=0.0060\n",
"step 210; loss=0.9916; delta=0.0150\n",
"step 220; loss=0.9949; delta=0.0034\n",
"step 230; loss=0.9977; delta=0.0028\n",
"step 240; loss=0.9985; delta=0.0007\n",
"step 250; loss=0.9879; delta=0.0105\n",
"step 260; loss=1.0000; delta=0.0121\n",
"step 270; loss=0.9993; delta=0.0007\n",
"step 280; loss=1.0006; delta=0.0013\n",
"step 290; loss=1.0000; delta=0.0006\n",
"step 300; loss=0.9881; delta=0.0119\n",
"step 310; loss=1.0000; delta=0.0119\n",
"step 320; loss=0.9998; delta=0.0002\n",
"step 330; loss=1.0007; delta=0.0010\n",
"step 340; loss=1.0000; delta=0.0008\n",
"step 350; loss=0.9872; delta=0.0127\n",
"step 360; loss=1.0005; delta=0.0132\n",
"step 370; loss=0.9999; delta=0.0006\n",
"step 380; loss=1.0014; delta=0.0016\n",
"step 390; loss=1.0001; delta=0.0014\n",
"step 400; loss=0.9869; delta=0.0132\n",
"step 410; loss=1.0004; delta=0.0135\n",
"step 420; loss=0.9998; delta=0.0006\n",
"step 430; loss=1.0010; delta=0.0012\n",
"step 440; loss=0.9998; delta=0.0012\n",
"step 450; loss=0.9868; delta=0.0130\n",
"step 460; loss=1.0004; delta=0.0136\n",
"step 470; loss=1.0004; delta=0.0001\n",
"Done: reached target tolerance\n",
"Ordering components by loss values...\n",
"Training finished.\n",
"CPU times: user 2min, sys: 26.4 s, total: 2min 26s\n",
"Wall time: 2min 27s\n"
]
}
],
"source": [
"%%time\n",
"fox = MinFoxSolver(n_components=components, p=projected_dim, device='gpu')\n",
"fox.fit(A, B, batch_size=8192,\n",
" max_iters=1000, tolerance=1e-4,\n",
" verbose=True, report_every=10)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.image.AxesImage at 0x7fcfbc169358>"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAP4AAAD/CAYAAADRymv0AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAHrJJREFUeJzt3VtsXFl6H/rfKl6KV7FbFMWLqPu11VK3rJ7pVs7gIJPYQQYJ4EkeEthAADvxYxIbcRB4nBfDj34YBAOc4ARGHGMQGD6O7eTMPCTxxJiMAwRQq6c16lZLrZYo6kbx0hLZongtkuLKA4saTbfUoiRSVSS/H1AQa7PI+rDFf6219/pqV8o5CyFsLoVKFxBCePki+CFsQhH8EDahCH4Im1AEP4RNKIIfwib0QsFPKX0jpXQppXQ5pfRbq1VUCGFtpeddx08pFXAZP49BvIdfyjlfWr3yQghr4UVG/LdxJed8I+c8j/8P31ydskIIa+lFgr8Dtx65P1DeFkKocrVr/QQppegJDqFCcs7pcdtfJPi3seuR+73lbV/Q4lXbtKjxQMlhQ/4vD9b+NecF/Qhfr3ANz+JH1le9RM2r7Xr5tuyvnvjIF0nfeziQUtqNIfwSfvlxD3xgv9ckh33irOyeGbMaPFAjx4piCKtkT/m2bA2Cn3N+kFL6Z/iBpXMFf5Bz/vhxj51X56rdpjVJsre8b0i3W3aa0fS8JYQQntMLzbdzzv8dh5/2uAUHXLXLkG5ved9X/NhHjrmjo4qDv6fSBTyjPZUu4DnsqXQBz2FPpQtYFS9pnr3PA7VmNRjS7YLXzal30BX79Wl1/+WU8Uz2VLqAZ7Sn0gU8hz2VLuA57Kl0AavipZ5hW1Drlp3u6HBAnyMuGdVuVoMJW15mKSFsai81+FnBtGbTmo0ZM2arLOk0omDRPa/EC0AIL0HF1tRGdJrVoNOIXgO6DfnYaxH8EF6CigV/whYTtihY1G1IqwlbjZnQalqTKc14bO9BCOEFVXwR/TOvuuioq/ZrM+4rfqzHoCQa/kJYKxUP/oQt+u13w25FJftd1WlEsyl15iSLlS4xhA2navpmpzTrc8CEVjUeeNsZA3rdtMusxkqXF8KGUvERf9ly8M858TD4e13TYLbSpYWw4VRN8EkW1Sgpum2Hc06YU++wT+zXp8VEpQsMYcOomqn+snl1btplRKcjLjnqohGdZjSa1Frp8kLYEKou+FnBjCYzmoxqd0eHRQVdhtWZN2ZrrPWH8IKqLviPGtZlWpMuw3a5qduQi45G8EN4QVUd/OUmnzrzegxqMq3dqGlNJrVEk08Iz6mKTu492ZitLjjmmr1e9ZmTzuoxWOmyQli3qnrEX3Zfm/vaTGnSacROt4xrM6TbnHrz6uJKPiE8g3WVlinNrjjovOPqzDvltD2uqzdX6dJCWFfWVfAntbjskPOOqzf3MPhFpUqXFsK6si6m+j+1dCKvpOimXQoWzal31EV3bTOsK9b6Q1iBdTXiL5tT74bdTjulpOiYjxxyWYvJSpcWwrqwzkb8JVnBrEazGt21zZBuC2rtcFuDWaPaY60/hC+xLoP/qCHdJrXoMWi3G7oMu+D1CH4IX2LdB39Sq0mtikp6DGowq92oWQ0mtcQxfwiPse6Dv2zMVh85pt2odqM6jbjskL4IfghfsGGCv9zks3wBzx6Dxmz1qe1KiubUR5NPCGUbLgkTWl12yEeOaTDrlNN2u6HOfKVLC6FqbLjgT2p1xSEXvK6o5B3v2uVmubsvl28hbG4bZqr/eXPqXbfHooIFtd7woTs6DOoxpaXS5YVQURtuxF9WUnTdHu96x7w6b/jQAX3R5BOCDTziZwUlDUqK7uhwy04LavUa0GTaXdtirT9sWhs2+I8a1OO+LXoN2Kdfl2HnHY/gh01rEwQ/PWzyaTKj25A687a5a0Gt+7ZEk0/YdDbsMf7j3NXuvONu2Wm7T51wTpfhSpcVwku3CUb8n1pu8plTr9uQDnceXsm3pKikGE0+YVPYlH/l923xicM+9pom0045bZebai1UurQQXopNGfzlJp
9Ljigq+Tk/0WtAUan8IZ3R5BM2tqcGP6XUm1L6YUrpQkrpfErp18vbX00p/SCl9ElK6S9SSm1rX+7qKim6Zq/3fNUDNU4454A+TaYrXVoIa2olI/4CfjPn/Dr+Gv5pSukIvoW/zDkfxg/x22tX5tqY1eCavc5426KCN33ggD7NpipdWghr6qnBzzkP55zPlb+exMfoxTfx3fLDvou/t1ZFrpWsYE7RlGYjOl2z17w6u92wz1Wt7le6xBDWxDOd1U8p7cEJnEZnznmEpReHlNL2Va/uJVlUcNsO97xitxsO6DOh1YLaaPIJG9KKg59SasGf4TdyzpMppc+fAfuSM2I/euTrPeVbNUmmtJjSosWkLsNqPLDNXQ/UGNcWTT5hHbhevj3dioKfUqq1FPr/mHP+XnnzSEqpM+c8klLqwqdP/g1fX1Ex1eCODh94U4c7ugzrNuSSIxH8sA7s8bOD6l898ZErXc77D7iYc/7OI9u+j18tf/0r+N7nf2g9mrDFTbuN6FSw+PBSXq8a02i6vNwXwvqWcv7yNeuU0tfwv3DeT69k8a9xBv8JO3ED/zDnfO8xP5/5nVUue+21uq/TiHajmkyrN+eavfrts6Cu0uWFsAK/K+f82I+TfupUP+f8v1HzhG//wouUVc2WP6J7zFbveNfrLigpum2HLHmgRnxEd1ivNlWv/vOY0ajfPjMaZclb3jesy4Be05orXV4Iz2VTtuw+i1kN+u3zvrdkyUln7dOv0UylSwvhucWI/xRZwbx6iwqGdbnskDn19unXZtwdHbHWH9adGPFXaFHBgF5nvG1Si0MuO+qiNuOVLi2EZxYj/gplBdOaTWvWZtz2ctvC8r/j2mLkD+tGBP85fGq7OfW2+1SPQd2GfOy1CH5YNyL4z2F5qQ96DGozrt2ocW1mNJrWJJb6QjWLY/wXcM8rLjrqioNaTfiq9/QaUONBpUsL4UtF8F/AhC2uOqDfPkUlR1zSZVijGbXmo703VK2Y6q+CaU2u2m9akyR7y/uGdLtlpxlNlS4vhC+IEX8VzGh01X5nnZRkX/Weffo1mK10aSE8VgR/FWQFD9Sa1WBItwteN6feQVfs1xdX8glVJ6b6q2hBrVt2uqPDAX2OuGRUu1kNsdQXqkoEfxU92uQzZsyYrbKk04iCRfe8Ei8AoSpE8NfIiE6zGnQa0WsgmnxCVYngr5HlJp+CRd2GtJqw1ZgJraY1mdIsmnxCpcTJvTX2mVdddNRV+7UZ9xU/1mNQik/rCRUUwV9jE7bot98NuxWV7HdVpxHNptSZiyafUBEx1X9JpjTrc8CEVjUeeNsZA3rdtMusxkqXFzaZGPFfkuXgn3PiYfD3uhZNPqEiIvgvTbKo5uEFO885YU69wz6xX58WE5UuMGwiMdV/yebVuWmXEZ2OuOSoi0Z0mtEYH9oRXpoI/kuWFcxoMqPJqHZ3dFhU0GVYnXljtsZaf1hzEfwKGtZlWpMuw3a5qduQi45G8MOai+BX0HKTT515PQY1mdZu1LQmk1qiySesmTi5VwXGbHXBMdfs9arPnHRWj8FKlxU2sBjxq8B9be5rM6VJpxE73TKuzZBuc+rNq5PjNTqsovhrqiJTml1x0HnH1Zl3yml7XFdvrtKlhQ0mgl9FJrW47JDzjqs39zD4RaVKlxY2mJjqV5WlE3klRTftUrBoTr2jLrprm2FdsdYfVkWM+FVoTr0bdjvtlJKiYz5yyGUtJitdWtggYsSvQlnBrEazGt21zZBuC2rtcFuDWaPaY60/vJAIfpUb0m1Six6Ddruhy7ALXo/ghxcSwa9yk1pNalVU0mNQg1ntRs1qMKkljvnDc4ngrxNjtvrIMe1GtRvVacRlh/RF8MNzWPHJvZRSIaV0NqX0/fL9PSml0ymlyymlP04pxYvIGrqvzTX73LZDnXk73LbNXVuMK5qNK/mEZ/IsZ/V/Axcfuf97+HbO+RDu4ddWs7DweBNaXXbIR45pMOuU03a7oc58pU
sL68iKgp9S6sXfwb9/ZPPfxJ+Xv/4u/v7qlhYeZ1KrKw654HVFJe941y43y919uXwL4cutdHr+b/Cv0AYppXZ8lnNenl8OoGf1ywtPMqfedXssKlhQ6w0fuqPDoB5TWipdXqhyTx3xU0p/FyM553N+9j2i8X7RCiopum6Pd71jXp03fOiAvmjyCSuykhH/a/jFlNLfQSNa8R20pZQK5VG/F7ef/Ct+9MjXe8q38CKygpIGJUV3dLhlpwW1eg1oMu2ubbHWv+lcL9+eLuW88mPClNJfx7/MOf9iSulP8J9zzn+SUvp/8UHO+d895mcyv7Pi5wjPKmspr+j3GrDXNdOanHfcTbsrXVyoqN+Vc37szPxFevW/hd9MKV3GVvzBC/yu8NySSa2GdRsrfzJvnXnb3NXjdly9NzzWM434z/UEMeK/NFuM22pMu1Ed7ihYLDf5HKx0aaEinjziR9PNBrJ8JZ859boN6XDn4ZV8S4pKinEln4B4W+6GdN8WnzjsY69pMu2U03a5qdZCpUsLVSKCvwEtN/lcckRRyc/5iV4Dikrl1t5o8tnsYqq/gZUUXbPXvDoP1DjhnBGdbtthWnOlywsVFCP+BjarwTV7nfG2RQVv+sABfZpNVbq0UGEx4m9gWcGcogW1RnQ+HP13u6HVhDs6oslnk4rgbwKLCm7b4Z5X7HbDAX0mtFpQG8HfpCL4m0IypcWUFi0mdRlW44Ft7nqgxri2uJLPJhPH+JvMHR0+8Kbbdugy7IRzOo1UuqzwksWIv8ksf1DnooJuQw87/cZsNavBrIZo8tkE4n94kxrX5pIjPnFYsylvO2OnW2o8qHRp4SWI4G9SE7boc9AVBzWYddx5PQYVldRYEE0+G1tM9Te5GY367TOjUZa85X3DugzojSafDSxG/E1uVoN++7zvLVly0ln79Gs0U+nSwhqKEX+Tywrm1VtUMKzLZYfMqbdPvzbj0eSzQcWIH7DU5DOg1xlvm9TikMuOuqjNeKVLC2sgRvyApZF/WrNpzdqM2+5TePjvuLYY+TeQCH74gk9tN6fedp/qMajbkI+9FsHfQCL44QuWm3ygx6A249qNGtdmRqNpTeLq6utbHOOHJ7rnFRcddcVBrSZ81Xt6DUSTzwYQwQ9PNGGLqw7ot09RyRGXdBnWaEat+figznUspvrhqaY1uWq/aU2S7C3vG9Ltlp1mNFW6vPAcYsQPTzWj0VX7nXVSkn3Ve/bp12C20qWF5xTBD0+VFTxQa1aDId0ueN2cegddsV+fVvcrXWJ4RjHVDyu2oNYtO93R4YA+R1wyWv70nljqW18i+GHFHm3yGTNmzFZZ0mlEwaJ7XokXgHUigh+ey4hOsxp0GtFrIJp81pkIfnguy00+BYu6DWk1YasxE1pNazKlWTT5VK84uRdeyGdeddFRV+3XZtxX/FiPQSku5FHVIvjhhUzYot9+N+xWVLLfVZ1GNJtSZy6afKpUTPXDqpjSrM8BE1rVeOBtZwzoddMusxorXV74nBjxw6pYDv45Jx4Gf69r0eRTpSL4YZUki2qUFN22wzknzKl32Cf269NiotIFhkfEVD+sqnl1btplRKcjLjnqohGdZjTGp/VUkQh+WFVZwYwmM5qMandHh0UFXYbVmTdma6z1V4EIflgzw7pMa9Jl2C43dRty0dEIfhVY0TF+SqktpfSnKaWPU0oXUkrvpJReTSn9IKX0SUrpL1JKbWtdbFhfJmxxW69R7R6o0WRau1GdhjWbFB/aUTkrPbn3HfzXnPNreBOX8C38Zc75MH6I316bEsN6N2arC465Zq9Xfeaks3oMVrqsTe2pwU8pbcH/nXP+Q8g5L+Scx/FNfLf8sO/i761ZlWFdu69Nv31u2anenN1u6HBHi0n1StHkUwErGfH34m5K6Q9TSmdTSr+fUmpCZ855BHLOw9i+loWG9W9KsysOOu+4OvNOOW2P6+rNVbq0TWclwa/FSfzbnPNJTFma5n/+AC0O2MKXmt
TiskPOO67e3MPgF5UqXdqms5Kz+gO4lXP+cfn+n1sK/khKqTPnPJJS6qL8yQuP9aNHvt5TvoXNZ+ndeiVFN+1SsGhOvaMuumubYV2x1v9CrpdvT/fU4JeDfSuldCjnfBk/jwvl26/i9/Ar+N6Tf8vXV1RM2Bzm1Ltht2FdjrromI8M6jGpJYL/Qvb42UH1r574yJWu4/86/iilVId+/GPU4D+llP4JbuAfPkelYRPKCmY1mtXorm2GdFtQa4fbGswa1R5r/WtsRcHPOX+Arz7mW7+wuuWEzWZIt0ktegza7YYuwy54PYK/xqJzL1TUpFaTWhWV9BjUYFa7UbMaYuq/hiL4oSqM2eojx7QbLXf3jbjskL4I/pqI4IeqcF+b+9oeXsCzx6AxW31qu5KiOfVyvIt81cSeDFVlQqvLDvnIMQ1mnXLabjfUma90aRtKBD9UlUmtrjjkgtcVlbzjXbvcLHf3ZdEntjpiqh+q0px61+2xqGBBrTd86I4Og3pMaal0eetejPihKpUUXbfHu94xr84bPnRAnxaTlS5tQ4gRP1SlrKCkQUnRHR1u2WlBrV4Dmky7a1us9b+ACH6oeoN63LdFrwH79Osy7LzjEfwXEMEPVS49bPJpMqPbkDrztrlrQa37tkSTz3OIY/ywbtzV7rzjbtlpu0+dcE6X4UqXtS7FiB/WjeUmnzn1ug3pcOfhlXxLikqK0eSzQrGXwrpz3xafOOxjr2ky7ZTTdrmp1kKlS1s3Ivhh3Vlu8rnkiKKSn/MTvQYUH16/L5p8niam+mHdKim6Zq95dR6occI5IzrdtsO05kqXV9VixA/r1qwG1+x1xtsWFbzpAwf0aTZV6dKqXoz4Yd3KCuYULag1ovPh6L/bDa0m3NERa/1PEMEP696igtt2uOcVu91wQJ8JrRbURvCfIIIfNoBkSospLVpM6jKsxgPb3PVAjXFt0eTzOXGMHzaUOzp84E237dBl2AnndBqpdFlVJ0b8sKFM2GLCFosKug3Zaky7UWO2mtVgVkM0+YgRP2xQ49pccsQnDms25W1n7HRLjQeVLq0qRPDDhjRhiz4HXXFQg1nHnddjUFFJjQWbvcknpvphQ5vRqN8+MxplyVveN6zLgN5N3eQTI37Y0GY16LfP+96SJSedtU+/RjOVLq2iYsQPG1pWMK/eooJhXS47ZE69ffq1Gd+0TT4x4odNYVHBgF5nvG1Si0MuO+qiNuOVLq0iYsQPm0JWMK3ZtGZtxm0vf6r78r/j2jbVyB/BD5vOp7abU2+7T/UY1G3Ix16L4IewkS03+UCPQW3GtRs1rs2MRtOakCpb5BqLY/ywad3ziouOuuKgVhO+6j29BjZFk08EP2xaE7a46oB++xSVHHFJl2GNZtSaL1/NZ2OKqX7Y9KY1uWq/aU2S7C3vG9Ltlp1mNFW6vDURI37Y9GY0umq/s05Ksq96zz79GsxWurQ1E8EPm15W8ECtWQ2GdLvgdXPqHXTFfn1a3a90iasupvohlC2odctOd3Q4oM8Rl4xqN6thwy31rWjETyn9i5TSRymlD1NKf5RSqk8p7UkpnU4pXU4p/XFKKV5Ewrq23OQzpt2YrcZslSWdRux0c0ON/E8NfkqpB/8cJ3POb1iaJfwyfg/fzjkfwj382loWGsLLNKLTOSeM6NRrwBs+1OFOpctaNSs9xq9Bc3lUb8Qg/gb+vPz97+Lvr355IVTGhC0GytP+RQWtJmw1psOnmk1a7+/nf2rwc86D+DZu4jbGcRb3cs7LC50D6FmrIkOolM+86qKjrtqvzbiv+LEeg9I6D/5Tj8tTSq/gm9htKfR/im8829P86JGv95RvIVS/5fbeSS2+4sf26Teh1ZBupfI1/avnGn7Xy7enW8kJuV9Af855DFJK/wVfwysppUJ51O+1NBt4gq+vqJgQqtWUZn0OmNCqxgNvO2NAr5t2mdVY6fLK9vjZQfWvnvjIlbxU3cSplFJDSinh53EB/xP/oPyYX8
H3nqPSENaF5eCfc+Jh8Pe6tm6bfFZyjH8Gf4af4ANLb1v6fXwLv5lSuoyt+IM1rDOECksW1Sgpum2Hc06YU++wT+zXp8VEpQt8JinntT1JkVLK/M6aPkcIL0uyqMHswzf1vObjh0t/g3ZUurzP+V0558e+vziabkJ4BlnBjCYzmoxqf7jc12VYnXljtq6LLr8IfgjPaViXaU26DNvlpm5DLjoawQ9hI1te6qszr8egJtPajZrWZFKLKc2q9Uo+1bIAGcK6NWarC465Zq9Xfeaks3oMVrqsLxUjfggv6L4297WZ0lR+Q88t49oM6Tan3ry6KmryWVJd1YSwjk1pdsVB5x1XZ94pp+1xXb25Spf2BRH8EFbJpBaXHXLecfXmHga/qFTp0r4gpvohrJqlE3klRTftUrBoTr2jLrprm2FdJrVWuMYlMeKHsMrm1Ltht9NOKSk65iOHXNZistKlPRQjfgirLCuY1WhWo7u2GdJtQa0dbmswa1R7xdf6I/ghrKEh3Sa16DFotxu6DLvg9Qh+CBvZpFaTWhWV9BjUYFa7UbMaTGqp2DF/BD+El2DMVh85pt2odqM6jbjskL4Ifggb13KTz6wGnUb0GDRmq09tV1I0p/6lNvnEWf0QXqIJrS475CPHNJh1ymm73VBn/qXWEcEP4SWa1OqKQy54XVHJO961y81yd1/2sq7eG1P9ECpgTr3r9lhUsKDWGz50R4dBPaa0rPnzx4gfQgWUFF23x7veMa/OGz50QN9La/KJET+ECsgKShqUFN3R4ZadFtTqNaDJtLu2relafwQ/hAob1OO+LXoN2Kdfl2HnHY/gh7BxpYdNPk1mdBtSZ942dy2odd+WNWnyiWP8EKrEXe3OO+6Wnbb71AnndBlek+eKET+EKrHc5DOnXrchHe48vJJvSVFJcdWafGLED6HK3LfFJw772GuaTDvltF1uqrWwas8RwQ+hyiw3+VxyRFHJz/mJXgOKSpJFq9HkE1P9EKpUSdE1e82r80CNE84Z0em2HaY1v9DvjhE/hCo1q8E1e53xtkUFb/rAAX2aTb3w744RP4QqlRXMKVpQa0Tnw9F/txtaTbij47nX+iP4IVS5RQW37XDPK3a74YA+E1otqI3gh7BxJVNaTGnRYlKXYTUe2OauB2qMa3vmJp84xg9hHbmjwwfedNsOXYadcE6nkWf+PTHih7COLH9Q56KCbkO2GtNu1JitZjWY1bCiJp8Y8UNYh8a1ueSITxzWbMrbztjplhoPVvTzEfwQ1qEJW/Q56IqDGsw67rweg4pKaix4WpNPTPVDWMdmNOq3z4xGWfKW9w3rMqDX9Jf83Esa8a+/nKdZVdcrXcAzul7pAp7D9UoX8ByuV7qAnzGrQb993veWLDnprH36NZr50p+L4D/R9UoX8IyuV7qA53C90gU8h+uVLuBnZAXz6k1rMqzLZYfMqbdP/5f+XBzjh7ABLCoY0OuMt01qccjlL318BD+EDSArmNbsrg7z6mxx/0sfn3Je2+t4p5RezoXCQwhfkHNOj9u+5sEPIVSfmOqHsAlF8EPYhNY0+Cmlb6SULqWULqeUfmstn+t5pZR6U0o/TCldSCmdTyn9enn7qymlH6SUPkkp/UVKqa3StT4qpVRIKZ1NKX2/fH9PSul0eV//cUqpqpqzUkptKaU/TSl9XN7X76yDffwvUkofpZQ+TCn9UUqpvtr380qtWfBTSgX8P/jbeB2/nFI6slbP9wIW8Js559fx1/BPy3V+C3+Zcz6MH+K3K1jj4/wGLj5y//fw7ZzzIdzDr1Wkqif7Dv5rzvk1vIlLqngfp5R68M9xMuf8hqUu119W/ft5ZXLOa3LDKfy3R+5/C7+1Vs+3inX///gFS3+YneVtXbhU6doeqbEX/wNfx/fL2+6g8Mi+/++VrvORerfg6mO2V/M+7sENvGop9N/H38Kn1bqfn+W2llP9Hbj1yP2B8raqlVLagxM4bekPcgRyzsPYXrnKvuDf4F
8pvxMjpdSOz3LOi+XvD1j6w60We3E3pfSH5cOT308pNanifZxzHsS3cRO3MY6zuFfF+3nF4uReWUqpBX+G38g5T/ri25uqYt0zpfR3MZJzPodH12gfu15bJWpxEv8253wSU5ZmgFW5jyGl9Aq+id2Wwt2Mb1S0qFW0lsG/jV2P3O8tb6s65RM0f4b/mHP+XnnzSEqps/z9LktTvGrwNfxiSqkff4y/aen4ua18XoXq29cDuJVz/nH5/p9beiGo1n3M0uFef855LOf8AP/F0r5/pYr384qtZfDfw4GU0u6UUj1+ydJxUjX6D7iYc/7OI9u+j18tf/0r+N7nf6gScs7/Oue8K+e8z9I+/WHO+R/hf+IflB9WNfVCeTp/K6V0qLzp53FBle7jsps4lVJqSCklP625avfzM1njEyTfwCe4gm9V+oTGE2r8Gh7gHH5i6TjuG9iKvyzX/wO8UulaH1P7X/fTk3t78S4u409QV+n6Plfrm5YGg3P4z2ir9n2M38HH+BDfRV217+eV3qJlN4RNKE7uhbAJRfBD2IQi+CFsQhH8EDahCH4Im1AEP4RNKIIfwiYUwQ9hE/o/8wGCp2j2ioMAAAAASUVORK5CYII=\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x7fcfbc9eca90>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"vectors = fox.predict(A=A)\n",
"\n",
"cosine = lambda a, b: (a * b).sum(-1) / (a * a).sum(-1) ** 0.5 / (b * b).sum(-1) ** 0.5\n",
"correlation_map = cosine(vectors.T[:, None, :], vectors.T[None, :, :])\n",
"\n",
"# orthogonality check\n",
"plt.imshow(correlation_map)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"coefficient of determination from known cols: 0.0007731837967437752\n",
"coefficient of determination from unknown cols: 0.9987369410432851\n"
]
}
],
"source": [
"from sklearn.linear_model import LinearRegression\n",
"regression = LinearRegression().fit(M[:, :-components], vectors)\n",
"print(\"coefficient of determination from known cols:\", regression.score(M[:, :-components], vectors))\n",
"regression = LinearRegression().fit(M[:, -components:], vectors)\n",
"print(\"coefficient of determination from unknown cols:\", regression.score(M[:, -components:], vectors))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
import numpy as np
from contextlib import contextmanager
from uuid import uuid4
import tensorflow as tf
L = tf.contrib.keras.layers
class MinFoxSolver:
    """Adversarial solver: learns components f(A) that a predictor g(B) cannot reproduce."""

    def __init__(self, n_components, p, p_b=None,
                 gen_optimizer=tf.train.AdamOptimizer(5e-4), pred_optimizer=tf.train.AdamOptimizer(5e-4),
                 make_generator=lambda n_components: L.Dense(n_components, name='he_who_generates_unpredictable'),
                 make_predictor=lambda n_components: L.Dense(n_components, name='he_who_predicts_generated_variable'),
                 sess=None, device=None,
                 ):
        """
        Given two matrices A and B, predict a variable f(A) that is impossible to predict from matrix B
        :param n_components: number of target components f(A) to learn
        :param p: last dimension of A
        :param p_b: dimension of B, default p_b = p
        :param gen_optimizer: tf optimizer used for the generator f(A)
        :param pred_optimizer: tf optimizer used for the predictor g(B)
        :param make_generator: callback to create a keras model for target variable generator given A
        :param make_predictor: callback to create a keras model for target variable predictor given B
        :param sess: tensorflow session. If not specified, uses default session or creates new one.
        :param device: 'cpu', 'gpu' or a specific tf device to run on.
        /* A tiny little fox */
        """
        # NOTE(review): default optimizer instances are created once at import time and
        # shared across all MinFoxSolver instances — verify this sharing is intended.
        # NOTE(review): `config` is built below but never passed to tf.Session, so the
        # device_count override appears to have no effect — confirm.
        config = tf.ConfigProto(device_count={'GPU': int(device != 'cpu')}) if device is not None else tf.ConfigProto()
        self.session = sess = sess or tf.get_default_session() or tf.Session()
        self.n_components = n_components
        self.gen_optimizer, self.pred_optimizer = gen_optimizer, pred_optimizer
        # A unique uuid4-named variable scope isolates this solver's graph nodes so
        # several solvers can coexist in one graph.
        with sess.as_default(), sess.graph.as_default(), tf.device(device), tf.variable_scope(str(uuid4())) as self.scope:
            A = self.A = tf.placeholder(tf.float32, [None, p])
            B = self.B = tf.placeholder(tf.float32, [None, p_b or p])
            self.generator = make_generator(n_components)
            self.predictor = make_predictor(n_components)
            prediction = self.predictor(B)
            target_raw = self.generator(A)
            # orthogonalize target and scale to unit norm
            target = orthogonalize_columns(target_raw)
            # rescale so each column has squared norm equal to batch size
            # (i.e. unit variance per sample on average)
            target *= tf.sqrt(tf.to_float(tf.shape(target)[0]))
            # per-sample, per-component loss; mean gives the scalar training loss
            self.loss_values = self.compute_loss(target, prediction)
            self.loss = tf.reduce_mean(self.loss_values)
            # regularizer keeps raw generator output close to its orthogonalized version
            self.reg = tf.reduce_mean(tf.squared_difference(target, target_raw))
            with tf.variable_scope('gen_optimizer') as gen_optimizer_scope:
                # generator maximizes predictor loss (adversarial) while minimizing reg
                self.update_gen = gen_optimizer.minimize(-self.loss + self.reg,
                                                         var_list=self.generator.trainable_variables)
            with tf.variable_scope('pred_optimizer') as pred_optimizer_scope:
                self.update_pred = pred_optimizer.minimize(self.loss,
                                                           var_list=self.predictor.trainable_variables)
                # reset op covers predictor weights plus optimizer slot variables
                # created inside this scope
                pred_state = self.predictor.trainable_variables
                pred_state += tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=tf.get_variable_scope().name)
                self.reset_pred = tf.variables_initializer(pred_state)
            self.prediction, self.target = prediction, target
            self.target_raw = target_raw

    def compute_loss(self, target, prediction):
        """ Return loss function for each sample and for component, output.shape == target.shape """
        return tf.squared_difference(target, prediction)

    def fit(self, A, B, max_iters=10 ** 4, tolerance=1e-4, batch_size=None, pred_steps=5, gen_steps=1,
            warm_start=False, reset_predictor=False, reorder=True, verbose=False, report_every=100):
        """
        Trains the fox
        :param pred_steps: predictor g(B) training iterations per one training step
        :param gen_steps: generator f(A) training iterations per one training step
        :param max_iters: maximum number of optimization steps till termination
        :param tolerance: terminates if loss difference between 10-iteration cycles reaches this value
            set to 0 to iterate for max_steps
        :param reset_predictor: if True, resets predictor network after every step
        :param reorder: if True, reorders components from highest loss to lowest
        """
        # NOTE(review): `warm_start` is accepted but never read in this method — confirm
        # whether it was meant to skip re-initialization.
        sess = self.session
        step = 0
        with sess.as_default(), sess.graph.as_default():
            initialize_uninitialized_variables(sess)
            prev_loss = float('inf')
            for batch_a, batch_b in iterate_minibatches(A, B, batch_size, cycle=True, shuffle=True):
                step += 1
                if step > max_iters: break
                feed = {self.A: batch_a, self.B: batch_b}
                # train predictor
                for j in range(pred_steps):
                    sess.run(self.update_pred, feed)
                # eval loss and metrics
                if step % report_every == 0:
                    loss_t = sess.run(self.loss, feed)
                    if verbose:
                        print("step %i; loss=%.4f; delta=%.4f" % (step, loss_t, abs(prev_loss - loss_t)))
                    if abs(prev_loss - loss_t) < tolerance:
                        if verbose: print("Done: reached target tolerance")
                        break
                    prev_loss = loss_t
                # update generator
                for j in range(gen_steps):
                    sess.run(self.update_gen, feed)
                if reset_predictor:
                    sess.run(self.reset_pred)
            else:
                # NOTE(review): this for-else is unreachable — the minibatch iterator
                # cycles forever, so the loop can only exit via break; confirm the
                # "reached max steps" message was meant to fire on the max_iters break.
                if verbose:
                    print("Done: reached max steps")
            # record components ordered by their loss value (highest to lowest)
            if reorder:
                if verbose:
                    print("Ordering components by loss values...")
                self.loss_per_component = np.zeros([self.n_components])
                # single deterministic pass over the full data to accumulate per-component loss
                for batch_a, batch_b in iterate_minibatches(A, B, batch_size, cycle=False, shuffle=False):
                    batch_loss_values = sess.run(self.loss_values, {self.A: batch_a, self.B: batch_b})
                    # ^-- [batch_size, n_components]
                    self.loss_per_component += batch_loss_values.sum(0)
                self.component_order = np.argsort(-self.loss_per_component)
                self.loss_per_component_ordered = self.loss_per_component[self.component_order]
            else:
                self.component_order = np.arange(self.n_components)
        if verbose:
            print("Training finished.")
        return self

    def predict(self, A=None, B=None, ordered=True, raw=False):
        """
        Computes components for new data: f(A) if A is given, else the predictor output g(B).
        :param ordered: if True, columns are permuted by self.component_order (set by fit)
        :param raw: if True (with A), returns the generator output before orthogonalization
        """
        assert (A is None) != (B is None), "Please use either predict(A=...) or predict(B=...)"
        sess = self.session
        with sess.as_default(), sess.graph.as_default():
            if A is not None:
                if not raw:
                    out = sess.run(self.target, {self.A: A})
                else:
                    out = sess.run(self.target_raw, {self.A: A})
            else:
                out = sess.run(self.prediction, {self.B: B})
        if ordered:
            out = out[:, self.component_order]
        return out

    def get_weights(self):
        """Return current generator and predictor weights as numpy arrays."""
        return self.session.run({'generator': self.generator.trainable_variables,
                                 'predictor': self.predictor.trainable_variables})
def orthogonalize_rows(matrix):
    """
    Gram-Schmidt orthogonalizer applied to the rows of a matrix; source: https://bit.ly/2FMOp40
    :param matrix: 2d float tensor [nrow, ncol]
    :returns: row-orthogonalized matrix [nrow, ncol] s.t.
        * output[i, :].dot(output[j, :]) ~= 0 for all i != j
        * norm(output[i, :]) == 1 for all i
    """
    first_row = matrix[0, :]
    basis = tf.expand_dims(first_row / tf.norm(first_row), 0)  # [1, ncol]
    for row_index in range(1, matrix.shape[0]):
        candidate = tf.expand_dims(matrix[row_index, :], 0)  # [1, ncol]
        # subtract the projection of the candidate onto every basis row so far
        projection = tf.matmul(tf.matmul(candidate, basis, transpose_b=True), basis)
        residual = candidate - projection  # [1, ncol]
        basis = tf.concat([basis, residual / tf.norm(residual)], axis=0)  # [row_index + 1, ncol]
    return basis
def orthogonalize_columns(matrix):
    """
    Gram-Schmidt orthogonalizer applied to the columns of a matrix; source: https://bit.ly/2FMOp40
    :param matrix: 2d float tensor [nrow, ncol]
    :returns: column-orthogonalized matrix [nrow, ncol] s.t.
        * output[:, i].dot(output[:, j]) ~= 0 for all i != j
        * norm(output[:, j]) == 1 for all j
    """
    first_col = matrix[:, 0]
    basis = tf.expand_dims(first_col / tf.norm(first_col), 1)  # [nrow, 1]
    for col_index in range(1, matrix.shape[1]):
        candidate = tf.expand_dims(matrix[:, col_index], 1)  # [nrow, 1]
        # subtract the projection of the candidate onto every basis column so far
        projection = tf.matmul(basis, tf.matmul(basis, candidate, transpose_a=True))
        residual = candidate - projection  # [nrow, 1]
        basis = tf.concat([basis, residual / tf.norm(residual)], axis=1)  # [nrow, col_index + 1]
    return basis
def initialize_uninitialized_variables(sess=None, var_list=None):
    """
    Run initializers only for the global variables that are not yet initialized.
    :param sess: tensorflow session; defaults to the current default session or a new one
    :param var_list: restrict the uninitialized-variable report to these variables
    """
    with tf.name_scope("initialize"):
        sess = sess or tf.get_default_session() or tf.InteractiveSession()
        # report_uninitialized_variables yields byte-string names without the ':0' suffix
        not_ready = set(sess.run(tf.report_uninitialized_variables(var_list)))
        pending = [variable for variable in tf.global_variables()
                   if variable.name[:-2].encode() in not_ready]
        sess.run(tf.variables_initializer(pending))
def iterate_minibatches(x, y, batch_size=None, cycle=False, shuffle=False):
    """
    Yield aligned (x, y) minibatch pairs.
    :param batch_size: rows per batch; None yields the full arrays as one batch
    :param cycle: if True, iterate over the data forever
    :param shuffle: if True, re-permute the row order before each pass (batched mode only)
    """
    order = np.arange(len(x))
    while True:
        if batch_size is None:
            # full-batch mode: hand back the arrays untouched
            yield x, y
        else:
            if shuffle:
                order = np.random.permutation(order)
            for start in range(0, len(x), batch_size):
                chosen = order[start: start + batch_size]
                yield x[chosen], y[chosen]
        if not cycle:
            return
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment