Created January 24, 2019 09:52
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### Doublefox example\n", | |
"![Img](https://i.imgur.com/IkU0BxB.png?2)\n", | |
"\n", | |
"__Patch notes:__\n", | |
"* Добавились двойные градиенты\n", | |
" * `MinFoxSolver(..., double_grad_steps)` - двойные градиенты + обычный minmax update\n", | |
" * `DoubleFoxSolver(...)` - только двойные градиенты и ничего лишнего\n", | |
"* Заменил while_loop на unroll, *внезапно* оно заработало (видимо баг в tf1.3 control flow ops)\n", | |
"* Предиктор обновляется с sgd; если использовать adam, работают нестабильно\n", | |
"\n", | |
"\n", | |
"* Теперь можно сочетать двойные градиенты и обычный minmax update. По моим приборам это сильно ускоряет сходимость если predictor многослойный.\n", | |
" * `MinFoxSolver(..., double_grad_steps=N).fit(..., pred_steps=1)` или больше\n", | |
" * предиктор будет на каждом шаге приходить в чуть более точное начальное приближение\n", | |
"\n", | |
"* Tech stuff\n", | |
" * появился параметр min_iters - гарантированное минимальное число итераций без проверки tolerance\n", | |
" * теперь генератор/предиктор обязаны быть моделями tfnn, а не keras. С керасом двойные градиенты не завелись.\n", | |
" * если предиктор не обучается, выводится лосс только для генератора (в точке \"после double grad обучения предиктора\"), если обучается - оба лосса.\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import matplotlib.pyplot as plt\n", | |
"%matplotlib inline\n", | |
"import numpy as np\n", | |
"\n", | |
"num_samples, true_dim, projected_dim, components = 2000, 20, 50, 10\n", | |
"\n", | |
"M = np.random.randn(num_samples, true_dim).astype('float32')\n", | |
"A = M.dot(np.random.randn(true_dim, projected_dim).astype('float32'))\n", | |
"B = M[:, :-components].dot(np.random.randn(true_dim - components, projected_dim).astype('float32'))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from minifox import MinFoxSolver, DoubleFoxSolver" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": { | |
"scrolled": true | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"step 10; loss=0.6451; delta=inf\n", | |
"step 20; loss=0.6223; delta=0.0228\n", | |
"step 30; loss=0.5927; delta=0.0296\n", | |
"step 40; loss=0.5926; delta=0.0001\n", | |
"step 50; loss=0.6496; delta=0.0570\n", | |
"step 60; loss=0.5735; delta=0.0762\n", | |
"step 70; loss=0.6472; delta=0.0737\n", | |
"step 80; loss=0.6011; delta=0.0461\n", | |
"step 90; loss=0.6390; delta=0.0379\n", | |
"step 100; loss=0.6055; delta=0.0336\n", | |
"step 110; loss=0.6315; delta=0.0260\n", | |
"step 120; loss=0.7394; delta=0.1080\n", | |
"step 130; loss=0.6500; delta=0.0894\n", | |
"step 140; loss=0.7142; delta=0.0642\n", | |
"step 150; loss=0.7392; delta=0.0251\n", | |
"step 160; loss=0.8023; delta=0.0630\n", | |
"step 170; loss=0.8075; delta=0.0052\n", | |
"step 180; loss=0.8172; delta=0.0097\n", | |
"step 190; loss=0.9127; delta=0.0956\n", | |
"step 200; loss=0.7946; delta=0.1181\n", | |
"step 210; loss=0.8122; delta=0.0176\n", | |
"step 220; loss=0.8569; delta=0.0447\n", | |
"step 230; loss=0.9695; delta=0.1126\n", | |
"step 240; loss=1.0084; delta=0.0389\n", | |
"step 250; loss=0.8318; delta=0.1766\n", | |
"step 260; loss=0.9548; delta=0.1230\n", | |
"step 270; loss=0.9658; delta=0.0110\n", | |
"step 280; loss=0.9637; delta=0.0021\n", | |
"step 290; loss=1.0542; delta=0.0906\n", | |
"step 300; loss=0.8983; delta=0.1559\n", | |
"step 310; loss=1.0395; delta=0.1412\n", | |
"step 320; loss=0.9934; delta=0.0461\n", | |
"step 330; loss=0.9783; delta=0.0151\n", | |
"step 340; loss=1.1710; delta=0.1927\n", | |
"step 350; loss=1.0778; delta=0.0931\n", | |
"step 360; loss=1.0159; delta=0.0620\n", | |
"step 370; loss=1.1092; delta=0.0933\n", | |
"step 380; loss=1.0549; delta=0.0543\n", | |
"step 390; loss=1.0246; delta=0.0303\n", | |
"step 400; loss=0.9764; delta=0.0482\n", | |
"step 410; loss=1.0605; delta=0.0842\n", | |
"step 420; loss=1.0842; delta=0.0236\n", | |
"step 430; loss=1.0785; delta=0.0057\n", | |
"step 440; loss=1.0400; delta=0.0385\n", | |
"step 450; loss=1.0236; delta=0.0164\n", | |
"step 460; loss=1.1403; delta=0.1167\n", | |
"step 470; loss=1.0166; delta=0.1238\n", | |
"step 480; loss=1.0628; delta=0.0463\n", | |
"step 490; loss=1.0529; delta=0.0099\n", | |
"step 500; loss=1.2188; delta=0.1659\n", | |
"step 510; loss=1.0563; delta=0.1625\n", | |
"step 520; loss=1.1194; delta=0.0631\n", | |
"step 530; loss=1.0821; delta=0.0373\n", | |
"step 540; loss=1.0635; delta=0.0186\n", | |
"step 550; loss=1.0930; delta=0.0295\n", | |
"step 560; loss=1.0918; delta=0.0011\n", | |
"step 570; loss=1.1540; delta=0.0621\n", | |
"step 580; loss=1.1254; delta=0.0286\n", | |
"step 590; loss=1.0980; delta=0.0274\n", | |
"step 600; loss=1.1050; delta=0.0070\n", | |
"step 610; loss=1.0342; delta=0.0708\n", | |
"step 620; loss=1.0396; delta=0.0054\n", | |
"step 630; loss=1.0488; delta=0.0092\n", | |
"step 640; loss=1.1031; delta=0.0543\n", | |
"step 650; loss=1.0539; delta=0.0492\n", | |
"step 660; loss=1.0860; delta=0.0321\n", | |
"step 670; loss=1.0799; delta=0.0060\n", | |
"step 680; loss=1.0947; delta=0.0148\n", | |
"step 690; loss=1.0840; delta=0.0107\n", | |
"step 700; loss=1.0348; delta=0.0492\n", | |
"step 710; loss=1.0513; delta=0.0165\n", | |
"step 720; loss=1.1529; delta=0.1016\n", | |
"step 730; loss=1.1441; delta=0.0087\n", | |
"step 740; loss=1.1120; delta=0.0321\n", | |
"step 750; loss=1.0421; delta=0.0699\n", | |
"step 760; loss=1.1786; delta=0.1365\n", | |
"step 770; loss=1.0416; delta=0.1370\n", | |
"step 780; loss=1.0709; delta=0.0293\n", | |
"step 790; loss=1.0629; delta=0.0080\n", | |
"step 800; loss=1.0460; delta=0.0169\n", | |
"step 810; loss=1.0882; delta=0.0421\n", | |
"step 820; loss=1.1138; delta=0.0256\n", | |
"step 830; loss=1.0700; delta=0.0438\n", | |
"step 840; loss=1.0305; delta=0.0395\n", | |
"step 850; loss=1.1153; delta=0.0847\n", | |
"step 860; loss=1.0459; delta=0.0694\n", | |
"step 870; loss=1.0974; delta=0.0515\n", | |
"step 880; loss=1.1377; delta=0.0403\n", | |
"step 890; loss=1.1408; delta=0.0031\n", | |
"step 900; loss=1.1732; delta=0.0324\n", | |
"step 910; loss=1.0606; delta=0.1126\n", | |
"step 920; loss=1.1232; delta=0.0627\n", | |
"step 930; loss=1.1775; delta=0.0543\n", | |
"step 940; loss=1.0712; delta=0.1064\n", | |
"step 950; loss=1.0966; delta=0.0254\n", | |
"step 960; loss=1.1672; delta=0.0706\n", | |
"step 970; loss=1.0579; delta=0.1093\n", | |
"step 980; loss=1.1573; delta=0.0994\n", | |
"step 990; loss=1.0828; delta=0.0745\n", | |
"step 1000; loss=1.1045; delta=0.0217\n", | |
"step 1010; loss=1.1124; delta=0.0079\n", | |
"step 1020; loss=1.0317; delta=0.0806\n", | |
"step 1030; loss=1.1844; delta=0.1527\n", | |
"step 1040; loss=1.0541; delta=0.1303\n", | |
"step 1050; loss=1.0499; delta=0.0042\n", | |
"step 1060; loss=1.1265; delta=0.0766\n", | |
"step 1070; loss=1.0569; delta=0.0696\n", | |
"step 1080; loss=1.0000; delta=0.0568\n", | |
"step 1090; loss=1.1417; delta=0.1416\n", | |
"step 1100; loss=1.0832; delta=0.0584\n", | |
"step 1110; loss=1.0832; delta=0.0001\n", | |
"Done: reached target tolerance\n", | |
"Ordering components by loss values...\n", | |
"Training finished.\n", | |
"CPU times: user 3min 19s, sys: 12 s, total: 3min 31s\n", | |
"Wall time: 2min\n" | |
] | |
} | |
], | |
"source": [ | |
"%%time\n", | |
"fox = DoubleFoxSolver(n_components=components, p=projected_dim, \n", | |
" double_grad_steps=20, double_grad_lr=0.01)\n", | |
"fox.fit(A, B,\n", | |
" batch_size=None, # None means full data\n", | |
" max_iters=10000, tolerance=1e-4,\n", | |
" verbose=True, report_every=10)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"<matplotlib.image.AxesImage at 0x7f001ef20fd0>" | |
] | |
}, | |
"execution_count": 4, | |
"metadata": {}, | |
"output_type": "execute_result" | |
}, | |
{ | |
"data": { | |
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAPgAAAD8CAYAAABaQGkdAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAACeNJREFUeJzt3N+L5XUdx/Hnq9lVWyvs1427S2sQlQRlDGIKXWjQT/KmCwODvNmbMo1AtJv+gRC7iGAxvVHyYvMiQrLox0U3W+Mq1ToVYqWbhluQhZG71ruLmWAzd853dr5fv3PePB8g7MyePb5Y5rnfc86c+aSqkNTTa+YeIGk6Bi41ZuBSYwYuNWbgUmMGLjVm4FJjBi41ZuBSY3umuNO3vGmlDh3cO/r9/vYX+0a/T2kZ/ZMXOF0vZtHtJgn80MG9/Ozhg6Pf74cvfd/o9ykto2P1w0G38yG61JiBS40ZuNSYgUuNGbjUmIFLjQ0KPMlHkvwmyRNJbp96lKRxLAw8yQrwdeCjwOXAp5NcPvUwSTs35Ap+JfBEVT1ZVaeBB4Drp50laQxDAt8PPH3Wxyc3P/c/khxOspZk7dRf/jXWPkk7MCTwV3q/6/8dxVpVR6pqtapW3/rmlZ0vk7RjQwI/CZz9xvIDwDPTzJE0piGB/xx4R5LLklwA3AB8Z9pZksaw8KfJquqlJJ8HHgZWgHuq6sTkyyTt2KAfF62qh4CHJt4iaWS+k01qzMClxgxcaszApcYMXGpskkMXf/uLfZMckPjwM4+Nfp/gYY7qyyu41JiBS40ZuNSYgUuNGbjUmIFLjRm41JiBS40ZuNSYgUuNGbjUmIFLjRm41JiBS40ZuNSYgUuNGbjUmIFLjRm41JiBS40ZuNTYJKeqTmWq00+nOK3Vk1q1G3gFlxozcKkxA5caM3CpMQOXGjNwqbGFgSc5mOTHSdaTnEhyy6sxTNLODfk++EvAl6rqeJLXA48k+UFVPT7xNkk7tPAKXlXPVtXxzV//HVgH9k89TNLObes5eJJDwBXAsSnGSBrX4LeqJnkd8G3g1qr62yv8/mHgMMBF7BttoKTzN+gKnmQvG3HfX1UPvtJtqupIVa1W1epeLhxzo6TzNORV9ADfBNar6s7pJ0kay5Ar+DXAZ4Brkzy2+d/HJt4laQQLn4NX1U+BvApbJI3Md7JJjRm41JiBS40ZuNSYgUuNLdWhi1OZ4oDEKQ5yBA9z1PZ4BZcaM3CpMQOXGjNwqTEDlxozcKkxA5caM3CpMQOXGjNwqTEDlxozcKkxA5caM3CpMQOXGjNwqTEDlxozcKkxA5caM3CpMQOXGvNU1YlMdfqpp7VqO7yCS40ZuNSYgUuNGbjUmIFLjRm41JiBS40NDjzJSpJHk3x3ykGSxrOdK/gtwPpUQySNb1DgSQ4AHwfunnaOpDENvYLfBdwG/PtcN0hyOMlakrUzvDjKOEk7szDwJJ8AnquqR7a6XVUdqarVqlrdy4WjDZR0/oZcwa8BPpnk98ADwLVJ7pt0laRRLAy8qu6oqgNVdQi4AfhRVd04+TJJO+b3waXGtvXz4FX1E+AnkyyRNDqv4FJjBi41ZuBSYwYuNWbgUmOeqrpkPK1V2+EVXGrMwKXGDFxqzMClxgxcaszApcYMXGrMwKXGDFxqzMClxgxcaszApcYMXGrMwKXGDFxqzMClxgxcaszApcYMXGrMwKXGDFxqzFNVBXhaa1dewaXGDFxqzMClxgxcaszApcYMXGpsUOBJLklyNMmvk6wn+cDUwyTt3NDvg38N+F5VfSrJBcC+CTdJGsnCwJO8Afgg8FmAqjoNnJ52lqQxDHmI/nbgFHBvkkeT3J3k4ol3SRrBkMD3AO8HvlFVVwAvALe//EZJDidZS7J2hhdHninpfAwJ/CRwsqqObX58lI3g/0dVHamq1apa3cuFY26UdJ4WBl5VfwKeTvLOzU9dBzw+6SpJoxj6KvrNwP2br6A/Cdw03SRJYxkUeFU9BqxOvEXSyHwnm9SYgUuNGbjUmIFLjRm41JiBS415qqomtUyntXY8qdUruNSYgUuNGbjUmIFLjRm41JiBS40ZuNSYgUuNGbjUmIFLjRm41JiBS40ZuNSYgUuNGbjUmIFLjRm41JiBS40ZuNSYgUuNeeiiltIUByROcZAjzHuYo1dwqTEDlxozcKkxA5caM3CpMQOXGjNwqbFBgSf5YpITSX6V5FtJLpp6mKSdWxh4kv3AF4DVqnoPsALcMPUwSTs39CH6HuC1SfYA+4BnppskaSwLA6+qPwJfBZ4CngWer6rvv/x2SQ4nWUuydoYXx18qaduGPER/I3A9cBlwKXBxkhtffruqOlJVq1W1upcLx18qaduGPET/EPC7qjpVVWeAB4Grp50laQxDAn8KuCrJviQBrgPWp50laQxDnoMfA44Cx4Ffbv6ZIxPvkjSCQT8PXlVfAb4y8RZJI/OdbFJjBi41ZuBSYwYuNWbgUmOeqiptmur00ylOa73yw/8YdDuv4FJjBi41ZuBSYwYuNWbgUmMGLjVm4FJjBi41ZuBSYwYuNWbgUmMGLjVm4FJjBi41ZuBSYwYuNWbgUmMGLjVm4FJjBi41ZuBSY6mq8e80OQX8YcBN3wL8efQB01mmvcu0FZZr727Y+raqeuuiG00S+FBJ1qpqdbYB27RMe5dpKyzX3mXa6kN0qTEDlxqbO/AjM///t2uZ9i7TVliuvUuzddbn4JKmNfcVXNKEZgs8yUeS/CbJE0lun2vHIkkOJvlxkvUkJ5LcMvemIZKsJHk0yXfn3rKVJJckOZrk15t/xx+Ye9NWknxx8+vgV0m+leSiuTdtZZbAk6wAXwc+ClwOfDrJ5XNsGeAl4EtV9W7gKuBzu3jr2W4B1uceMcDXgO9V1buA97KLNyfZD3wBWK2q9wArwA3zrtraXFfwK4EnqurJqjoNPABcP9OWLVXVs1V1fPPXf2fjC3D/vKu2luQA8HHg7rm3bCXJG4APAt8EqKrTVfXXeVcttAd4bZI9wD7gmZn3bGmuwPcDT5/18Ul2eTQASQ4BVwDH5l2y0F3AbcC/5x6ywNuBU8C9m08n7k5y8dyjzqWq/gh8FXgKeBZ4vqq+P++qrc0VeF7hc7v65fwkrwO+DdxaVX+be8+5JPkE8FxVPTL3lgH2AO8HvlFVVwAvALv59Zg3svFI8zLgUuDiJDfOu2prcwV+Ejh41scH2MUPdZLsZSPu+6vqwbn3LHAN8Mkkv2fjqc+1Se6bd9I5nQROVtV/HxEdZSP43epDwO+q6lRVnQEeBK6eedOW5gr858A7klyW5AI2Xqj4zkxbtpQkbDxHXK+qO+fes0hV3VFVB6rqEBt/rz+qql15lamqPwFPJ3nn5qeuAx6fcdIiTwFXJdm3+XVxHbv4RUHYeIj0qquql5J8HniYjVci76mqE3NsGeAa4DPAL5M8tvm5L1fVQzNu6uRm4P7Nf+ifBG6aec85VdWxJEeB42x8d+VRdvm72nwnm9SY72STGjNwqTEDlxozcKkxA5caM3CpMQOXGjNwqbH/AAcfJCfPJk4KAAAAAElFTkSuQmCC\n", | |
"text/plain": [ | |
"<matplotlib.figure.Figure at 0x7f001efd8a90>" | |
] | |
}, | |
"metadata": { | |
"needs_background": "light" | |
}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"vectors = fox.predict(A=A)\n", | |
"\n", | |
"cosine = lambda a, b: (a * b).sum(-1) / (a * a).sum(-1) ** 0.5 / (b * b).sum(-1) ** 0.5\n", | |
"correlation_map = cosine(vectors.T[:, None, :], vectors.T[None, :, :])\n", | |
"\n", | |
"# orthogonality check\n", | |
"plt.imshow(correlation_map)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"coefficient of determination from known cols: 0.00045179598686428474\n", | |
"coefficient of determination from unknown cols: 0.9942116432725838\n" | |
] | |
} | |
], | |
"source": [ | |
"from sklearn.linear_model import LinearRegression\n", | |
"regression = LinearRegression().fit(M[:, :-components], vectors)\n", | |
"print(\"coefficient of determination from known cols:\", regression.score(M[:, :-components], vectors))\n", | |
"regression = LinearRegression().fit(M[:, -components:], vectors)\n", | |
"print(\"coefficient of determination from unknown cols:\", regression.score(M[:, -components:], vectors))" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### Эмпирически \"хорошие\" параметры" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"step 10; loss(gen update)=0.5658; loss(pred update)=0.5691; delta=inf\n", | |
"step 20; loss(gen update)=0.5614; loss(pred update)=0.5631; delta=0.0044\n", | |
"step 30; loss(gen update)=0.5660; loss(pred update)=0.5676; delta=0.0046\n", | |
"step 40; loss(gen update)=0.5715; loss(pred update)=0.5730; delta=0.0055\n", | |
"step 50; loss(gen update)=0.5781; loss(pred update)=0.5797; delta=0.0066\n", | |
"step 60; loss(gen update)=0.5865; loss(pred update)=0.5881; delta=0.0084\n", | |
"step 70; loss(gen update)=0.5971; loss(pred update)=0.5989; delta=0.0106\n", | |
"step 80; loss(gen update)=0.6098; loss(pred update)=0.6117; delta=0.0127\n", | |
"step 90; loss(gen update)=0.6236; loss(pred update)=0.6255; delta=0.0138\n", | |
"step 100; loss(gen update)=0.6376; loss(pred update)=0.6395; delta=0.0140\n", | |
"step 110; loss(gen update)=0.6519; loss(pred update)=0.6537; delta=0.0142\n", | |
"step 120; loss(gen update)=0.6669; loss(pred update)=0.6688; delta=0.0150\n", | |
"step 130; loss(gen update)=0.6832; loss(pred update)=0.6852; delta=0.0163\n", | |
"step 140; loss(gen update)=0.7010; loss(pred update)=0.7031; delta=0.0178\n", | |
"step 150; loss(gen update)=0.7201; loss(pred update)=0.7223; delta=0.0191\n", | |
"step 160; loss(gen update)=0.7403; loss(pred update)=0.7425; delta=0.0202\n", | |
"step 170; loss(gen update)=0.7613; loss(pred update)=0.7635; delta=0.0210\n", | |
"step 180; loss(gen update)=0.7825; loss(pred update)=0.7847; delta=0.0212\n", | |
"step 190; loss(gen update)=0.8035; loss(pred update)=0.8057; delta=0.0210\n", | |
"step 200; loss(gen update)=0.8239; loss(pred update)=0.8260; delta=0.0204\n", | |
"step 210; loss(gen update)=0.8433; loss(pred update)=0.8452; delta=0.0194\n", | |
"step 220; loss(gen update)=0.8615; loss(pred update)=0.8633; delta=0.0182\n", | |
"step 230; loss(gen update)=0.8784; loss(pred update)=0.8800; delta=0.0169\n", | |
"step 240; loss(gen update)=0.8939; loss(pred update)=0.8955; delta=0.0156\n", | |
"step 250; loss(gen update)=0.9082; loss(pred update)=0.9096; delta=0.0143\n", | |
"step 260; loss(gen update)=0.9212; loss(pred update)=0.9225; delta=0.0130\n", | |
"step 270; loss(gen update)=0.9329; loss(pred update)=0.9340; delta=0.0117\n", | |
"step 280; loss(gen update)=0.9433; loss(pred update)=0.9443; delta=0.0104\n", | |
"step 290; loss(gen update)=0.9525; loss(pred update)=0.9534; delta=0.0092\n", | |
"step 300; loss(gen update)=0.9605; loss(pred update)=0.9613; delta=0.0080\n", | |
"step 310; loss(gen update)=0.9674; loss(pred update)=0.9680; delta=0.0069\n", | |
"step 320; loss(gen update)=0.9733; loss(pred update)=0.9738; delta=0.0059\n", | |
"step 330; loss(gen update)=0.9782; loss(pred update)=0.9786; delta=0.0049\n", | |
"step 340; loss(gen update)=0.9822; loss(pred update)=0.9826; delta=0.0041\n", | |
"step 350; loss(gen update)=0.9856; loss(pred update)=0.9859; delta=0.0033\n", | |
"step 360; loss(gen update)=0.9883; loss(pred update)=0.9886; delta=0.0027\n", | |
"step 370; loss(gen update)=0.9905; loss(pred update)=0.9907; delta=0.0022\n", | |
"step 380; loss(gen update)=0.9923; loss(pred update)=0.9925; delta=0.0018\n", | |
"step 390; loss(gen update)=0.9937; loss(pred update)=0.9939; delta=0.0014\n", | |
"step 400; loss(gen update)=0.9949; loss(pred update)=0.9950; delta=0.0012\n", | |
"step 410; loss(gen update)=0.9958; loss(pred update)=0.9959; delta=0.0009\n", | |
"step 420; loss(gen update)=0.9965; loss(pred update)=0.9966; delta=0.0007\n", | |
"step 430; loss(gen update)=0.9971; loss(pred update)=0.9972; delta=0.0006\n", | |
"step 440; loss(gen update)=0.9976; loss(pred update)=0.9977; delta=0.0005\n", | |
"step 450; loss(gen update)=0.9980; loss(pred update)=0.9981; delta=0.0004\n", | |
"step 460; loss(gen update)=0.9983; loss(pred update)=0.9984; delta=0.0003\n", | |
"step 470; loss(gen update)=0.9986; loss(pred update)=0.9986; delta=0.0003\n", | |
"step 480; loss(gen update)=0.9988; loss(pred update)=0.9989; delta=0.0002\n", | |
"step 490; loss(gen update)=0.9990; loss(pred update)=0.9990; delta=0.0002\n", | |
"step 500; loss(gen update)=0.9992; loss(pred update)=0.9992; delta=0.0002\n", | |
"step 510; loss(gen update)=0.9993; loss(pred update)=0.9993; delta=0.0001\n", | |
"step 520; loss(gen update)=0.9994; loss(pred update)=0.9994; delta=0.0001\n", | |
"step 530; loss(gen update)=0.9995; loss(pred update)=0.9995; delta=0.0001\n", | |
"Done: reached target tolerance\n", | |
"Ordering components by loss values...\n", | |
"Training finished.\n", | |
"CPU times: user 1min, sys: 2.84 s, total: 1min 3s\n", | |
"Wall time: 37 s\n" | |
] | |
} | |
], | |
"source": [ | |
"%%time\n", | |
"import tensorflow as tf\n", | |
"\n", | |
"fox = MinFoxSolver(n_components=components, p=projected_dim,\n", | |
" gen_optimizer=tf.train.AdamOptimizer(1e-3),\n", | |
" double_grad_steps=5, double_grad_lr=0.01)\n", | |
"fox.fit(A, B,\n", | |
" batch_size=None, # None means full data\n", | |
" pred_steps=5,\n", | |
" max_iters=10000, tolerance=1e-4,\n", | |
" verbose=True, report_every=10)" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.2" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
minifox.py
import numpy as np | |
from contextlib import contextmanager | |
from uuid import uuid4 | |
import tensorflow as tf | |
import tfnn.layers.basic as L | |
class MinFoxSolver: | |
def __init__(self, n_components, p, p_b=None, double_grad_steps=0, double_grad_lr=0.01, | |
gen_optimizer=tf.train.AdamOptimizer(5e-4), pred_optimizer=tf.train.AdamOptimizer(5e-4), | |
make_generator=lambda name, inp_size, out_size: L.Dense(name, inp_size, out_size, activ=lambda x: x), | |
make_predictor=lambda name, inp_size, out_size: L.Dense(name, inp_size, out_size, activ=lambda x: x, | |
matrix_initializer=tf.zeros_initializer()), | |
sess=None, device=None, | |
): | |
""" | |
Given two matrices A and B, learn a target variable f(A) that cannot be predicted from matrix B | |
:param p: last dimension of A | |
:param p_b: dimension of B, default p_b = p | |
:param gen_optimizer: tf optimizer used to train generator | |
:param pred_optimizer: tf optimizer used to train predictor out-of-graph (between iterations) | |
:param make_generator: callback that creates a tfnn model for the target variable generator given A | |
:param make_predictor: callback that creates a tfnn model for the target variable predictor given B | |
Both models returned by make_generator and make_predictor should have | |
all their trainable variables created inside name scope (first arg) | |
:param sess: tensorflow session. If not specified, uses default session or creates new one. | |
:param device: 'cpu', 'gpu' or a specific tf device to run on. | |
/* A tiiiny little fox */ | |
""" | |
config = tf.ConfigProto(device_count={'GPU': int(device != 'cpu')}) if device is not None else tf.ConfigProto() | |
self.session = sess = sess or tf.get_default_session() or tf.Session(config=config) | |
self.n_components = n_components | |
self.gen_optimizer, self.pred_optimizer = gen_optimizer, pred_optimizer | |
with sess.as_default(), sess.graph.as_default(), tf.device(device), tf.variable_scope(str(uuid4())) as self.scope: | |
A = self.A = tf.placeholder(tf.float32, [None, p]) | |
B = self.B = tf.placeholder(tf.float32, [None, p_b or p]) | |
self.generator = make_generator('generator', p, n_components) | |
self.predictor = make_predictor('predictor', p_b or p, n_components) | |
prefix = tf.get_variable_scope().name + '/' | |
self.generator_weights = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, | |
scope=prefix + 'generator') | |
self.predictor_weights = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, | |
scope=prefix + 'predictor') | |
prediction = self.predictor(B) | |
target_raw = self.generator(A) | |
# orthogonalize target and scale to unit norm | |
target = orthogonalize_columns(target_raw) | |
target *= tf.sqrt(tf.to_float(tf.shape(target)[0])) | |
self.loss_values = self.compute_loss(target, prediction) | |
self.loss = tf.reduce_mean(self.loss_values) | |
# update predictor | |
with tf.variable_scope('pred_optimizer') as pred_optimizer_scope: | |
self.update_pred = pred_optimizer.minimize(self.loss, | |
var_list=self.predictor_weights) | |
pred_state = list(self.predictor_weights) | |
pred_state += tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=tf.get_variable_scope().name) | |
self.reset_pred = tf.variables_initializer(pred_state) | |
# update generator using (possibly) ingraph-updated generator | |
if double_grad_steps == 0: | |
self.generator_loss = self.loss | |
else: | |
make_model = lambda: make_predictor('predictor', p_b or p, self.n_components) | |
get_loss = lambda predictor: tf.reduce_mean(self.compute_loss(target, predictor(B))) | |
self.updated_predictor = get_updated_model(self.predictor, make_model, get_loss, | |
n_steps=double_grad_steps, learning_rate=double_grad_lr, | |
model_variables=self.predictor_weights) | |
self.generator_loss = get_loss(self.updated_predictor) | |
self.generator_reg = tf.reduce_mean(tf.squared_difference(target, target_raw)) | |
with tf.variable_scope('gen_optimizer') as gen_optimizer_scope: | |
self.update_gen = gen_optimizer.minimize(-self.generator_loss + self.generator_reg, | |
var_list=self.generator_weights) | |
self.prediction, self.target, self.target_raw = prediction, target, target_raw | |
def compute_loss(self, target, prediction): | |
""" Return loss function for each sample and for component, output.shape == target.shape """ | |
return tf.squared_difference(target, prediction) | |
def fit(self, A, B, max_iters=10 ** 4, min_iters=0, tolerance=1e-4, batch_size=None, pred_steps=5, gen_steps=1, | |
reset_predictor=False, reorder=True, verbose=False, report_every=100): | |
""" | |
Trains the fox | |
:param pred_steps: predictor g(B) training iterations per one training step | |
:param gen_steps: generator f(A) training iterations per one training step | |
:param max_iters: maximum number of optimization steps till termination | |
:param min_iters: the guaranteed number of steps that algorithm will make without terminating by tolerance | |
:param tolerance: terminates if the generator loss changes by less than this value between consecutive reports (every report_every steps); | |
set to 0 to always run for max_iters | |
:param reset_predictor: if True, resets predictor network after every step | |
:param reorder: if True, reorders components from highest loss to lowest | |
""" | |
sess = self.session | |
step = 0 | |
assert gen_steps > 0 | |
with sess.as_default(), sess.graph.as_default(): | |
initialize_uninitialized_variables(sess) | |
prev_loss = float('inf') | |
for batch_a, batch_b in iterate_minibatches(A, B, batch_size, cycle=True, shuffle=True): | |
step += 1 | |
if step > max_iters: break | |
feed = {self.A: batch_a, self.B: batch_b} | |
# update generator | |
for j in range(gen_steps): | |
loss_t_gen, _ = sess.run([self.generator_loss, self.update_gen], feed) | |
if reset_predictor: | |
sess.run(self.reset_pred) | |
for j in range(pred_steps): | |
loss_t_pred, _ = sess.run([self.loss, self.update_pred], feed) | |
# eval loss and metrics | |
if step % report_every == 0: | |
if pred_steps == 0: | |
loss_t_pred = sess.run(self.loss, feed) | |
loss_delta = abs(prev_loss - loss_t_gen) | |
prev_loss = loss_t_gen | |
if verbose: | |
if pred_steps == 0: | |
print("step %i; loss=%.4f; delta=%.4f" % (step, loss_t_gen, loss_delta)) | |
else: | |
print("step %i; loss(gen update)=%.4f; loss(pred update)=%.4f; delta=%.4f" % ( | |
step, loss_t_gen, loss_t_pred, loss_delta)) | |
if loss_delta < tolerance and step > min_iters: | |
if verbose: print("Done: reached target tolerance") | |
break | |
else: | |
if verbose: | |
print("Done: reached max steps") | |
# record components ordered by their loss value (highest to lowest) | |
if reorder: | |
if verbose: | |
print("Ordering components by loss values...") | |
self.loss_per_component = np.zeros([self.n_components]) | |
for batch_a, batch_b in iterate_minibatches(A, B, batch_size, cycle=False, shuffle=False): | |
batch_loss_values = sess.run(self.loss_values, {self.A: batch_a, self.B: batch_b}) | |
# ^-- [batch_size, n_components] | |
self.loss_per_component += batch_loss_values.sum(0) | |
self.component_order = np.argsort(-self.loss_per_component) | |
self.loss_per_component_ordered = self.loss_per_component[self.component_order] | |
else: | |
self.component_order = np.arange(self.n_components) | |
if verbose: | |
print("Training finished.") | |
return self | |
def predict(self, A=None, B=None, ordered=True, raw=False): | |
assert (A is None) != (B is None), "Please use either predict(A=...) or predict(B=...)" | |
sess = self.session | |
with sess.as_default(), sess.graph.as_default(): | |
if A is not None: | |
if not raw: | |
out = sess.run(self.target, {self.A: A}) | |
else: | |
out = sess.run(self.target_raw, {self.A: A}) | |
else: | |
out = sess.run(self.prediction, {self.B: B}) | |
if ordered: | |
out = out[:, self.component_order] | |
return out | |
def get_weights(self): | |
return self.session.run({'generator': self.generator.trainable_variables, | |
'predictor': self.predictor.trainable_variables}) | |
class DoubleFoxSolver(MinFoxSolver): | |
def __init__(self, n_components, p, p_b=None, double_grad_steps=10, double_grad_lr=0.01, | |
gen_optimizer=tf.train.AdamOptimizer(5e-4), **kwargs): | |
""" | |
A wrapper around MinFoxSolver that trains with double gradients only, without any other predictor updates. | |
""" | |
dummy_opt = tf.train.GradientDescentOptimizer(learning_rate=0.0) | |
super().__init__(n_components, p, p_b, double_grad_steps=double_grad_steps, double_grad_lr=double_grad_lr, | |
gen_optimizer=gen_optimizer, pred_optimizer=dummy_opt, **kwargs) | |
def fit(self, A, B, max_iters=10 ** 4, min_iters=0, tolerance=1e-4, batch_size=None, | |
reset_predictor=True, reorder=True, verbose=False, report_every=100, **kwargs): | |
""" | |
A wrapper around MinFoxSolver that trains with double gradients only, without any other predictor updates. | |
""" | |
super().fit(A, B, max_iters, min_iters, tolerance, batch_size=batch_size, | |
gen_steps=1, pred_steps=0, reset_predictor=reset_predictor, | |
reorder=reorder, verbose=verbose, report_every=report_every, **kwargs) | |
DoubleFoxSolver.__init__.__doc__ += MinFoxSolver.__init__.__doc__ | |
DoubleFoxSolver.fit.__doc__ += MinFoxSolver.fit.__doc__ | |
def orthogonalize_rows(matrix): | |
""" | |
Gram-Schmidt orthogonalizer for each row of the matrix; source: https://bit.ly/2FMOp40 | |
:param matrix: 2d float tensor [nrow, ncol] | |
:returns: row-orthogonalized matrix [nrow, ncol] s.t. | |
* output[i, :].dot(output[j, :]) ~= 0 for all i != j | |
* norm(output[i, :]) == 1 for all i | |
""" | |
basis = tf.expand_dims(matrix[0, :] / tf.norm(matrix[0, :]), 0) # [1, ncol] | |
for i in range(1, matrix.shape[0]): | |
v = tf.expand_dims(matrix[i, :], 0) # [1, ncol] | |
w = v - tf.matmul(tf.matmul(v, basis, transpose_b=True), basis) # [1, ncol] | |
basis = tf.concat([basis, w / tf.norm(w)], axis=0) # [i, ncol] | |
return basis | |
def orthogonalize_columns(matrix): | |
""" | |
Gram-Schmidt orthogonalizer for each column of the matrix; source: https://bit.ly/2FMOp40 | |
:param matrix: 2d float tensor [nrow, ncol] | |
:returns: column-orthogonalized matrix [nrow, ncol] s.t. | |
* output[:, i].dot(output[:, j]) ~= 0 for all i != j | |
* norm(output[:, j]) == 1 for all j | |
""" | |
basis = tf.expand_dims(matrix[:, 0] / tf.norm(matrix[:, 0]), 1) # [nrow, 1] | |
for i in range(1, matrix.shape[1]): | |
v = tf.expand_dims(matrix[:, i], 1) # [nrow, 1] | |
w = v - tf.matmul(basis, tf.matmul(basis, v, transpose_a=True)) # [nrow, 1] | |
basis = tf.concat([basis, w / tf.norm(w)], axis=1) # [nrow, i] | |
return basis | |
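# Quick sanity check (added for illustration, not part of the original file): for a
# full-column-rank 2-D tensor X, Q = orthogonalize_columns(X) should satisfy Q^T Q ~= I, e.g.
#   X = tf.random_normal([100, 5])
#   Q = orthogonalize_columns(X)
#   gram = tf.matmul(Q, Q, transpose_a=True)  # expected to be close to tf.eye(5)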
def initialize_uninitialized_variables(sess=None, var_list=None): | |
with tf.name_scope("initialize"): | |
sess = sess or tf.get_default_session() or tf.InteractiveSession() | |
uninitialized_names = set(sess.run(tf.report_uninitialized_variables(var_list))) | |
uninitialized_vars = [] | |
for var in tf.global_variables(): | |
if var.name[:-2].encode() in uninitialized_names: | |
uninitialized_vars.append(var) | |
sess.run(tf.variables_initializer(uninitialized_vars)) | |
def iterate_minibatches(x, y, batch_size=None, cycle=False, shuffle=False): | |
indices = np.arange(len(x)) | |
while True: | |
if batch_size is not None: | |
if shuffle: | |
indices = np.random.permutation(indices) | |
for batch_start in range(0, len(x), batch_size): | |
batch_ix = indices[batch_start: batch_start + batch_size] | |
yield x[batch_ix], y[batch_ix] | |
else: | |
yield x, y | |
if not cycle: | |
break | |
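# Usage note (added for clarity): with batch_size=None the full (x, y) pair is yielded once
# (or repeatedly when cycle=True); with an integer batch_size the data is split into minibatches,
# reshuffled on every pass when shuffle=True.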
# Double grad utils | |
@contextmanager | |
def replace_variables(replacement_dict, strict=True, verbose=False, canonicalize_names=True, scope='', **kwargs): | |
""" A context that replaces all newly created tf variables using a replacement_dict{name -> value} """ | |
if canonicalize_names: | |
new_replacement_dict = {canonicalize(key): var for key, var in replacement_dict.items()} | |
assert len(new_replacement_dict) == len(replacement_dict), \ | |
"multiple variables got same canonic names, output: {}".format(new_replacement_dict) | |
replacement_dict = new_replacement_dict | |
def _custom_getter(getter, name, shape, *args, **kwargs): | |
name = canonicalize(name) if canonicalize_names else name | |
assert not strict or name in replacement_dict, "variable {} not found".format(name) | |
if name in replacement_dict: | |
if verbose: | |
print("variable {} replaced with {}".format(name, replacement_dict[name])) | |
return replacement_dict[name] | |
else: | |
if verbose: | |
print("variable {} not found, creating new".format(name)) | |
return getter(name = name, shape = shape, *args, **kwargs) | |
with tf.variable_scope(scope, custom_getter=_custom_getter, **kwargs): | |
yield | |
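# Illustrative usage (hypothetical names, mirroring get_updated_model below): rebuild a layer
# so that its weights are replaced by in-graph tensors, e.g. one SGD step away from the
# current values:
#   updated = {canonicalize(v.name): v - 0.01 * g for v, g in zip(variables, grads)}
#   with replace_variables(updated):
#       updated_layer = L.Dense('dense', inp_size, out_size)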
def canonicalize(name): | |
""" canonicalize varaible name: remove empty scopes (//) and colons """ | |
if ':' in name: | |
name = name[:name.index(':')] | |
while '//' in name: | |
name = name.replace('//', '/') | |
return name | |
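# e.g. (hypothetical name): canonicalize('scope//predictor/W:0') -> 'scope/predictor/W'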
def get_updated_model(initial_model, make_model, get_loss, n_steps=1, learning_rate=0.01, | |
model_variables=None, **kwargs): | |
""" | |
Performs in-graph SGD steps on the model. | |
:param initial_model: initial tfnn model (a sack of variables) | |
:param make_model: a function with no inputs that creates new model | |
The new model should "live" in the same variable scope as initial_model | |
:param get_loss: a function(model) -> scalar loss to be optimized | |
:param n_steps: number of gradient descent steps | |
:param learning_rate: sgd learning rate | |
:param model_variables: a list of model variables. defaults to all trainable variables in model scope | |
""" | |
model = initial_model | |
assert hasattr(model, 'scope') or model_variables is not None, \ | |
"model has no .scope, please add it or specify model_variables=[list_of_trainable_weights]" | |
model_variables = model_variables or tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, | |
scope=model.scope.name) | |
assert isinstance(model_variables, (list, tuple)) | |
variable_names = [canonicalize(var.name) for var in model_variables] | |
for step in range(n_steps): | |
grads = tf.gradients(get_loss(model), model_variables) | |
# Perform SGD update. Note: if you use adaptive optimizer (e.g. Adam), implement it HERE | |
updated_variables = { | |
name: (var - learning_rate * grad) if grad is not None else var | |
for name, var, grad in zip(variable_names, model_variables, grads) | |
} | |
with replace_variables(updated_variables, strict=True, **kwargs): | |
model = make_model() | |
model_variables = [updated_variables[name] for name in variable_names] | |
return model |
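# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (added; it mirrors the notebook above, and the values are
# illustrative). Assumes tf 1.x and tfnn are installed and this file is importable as minifox.
if __name__ == '__main__':
    num_samples, true_dim, projected_dim, components = 2000, 20, 50, 10

    # A sees all latent columns of M; B sees all but the last `components` columns,
    # so the solver should recover directions of A driven by those hidden columns.
    M = np.random.randn(num_samples, true_dim).astype('float32')
    A = M.dot(np.random.randn(true_dim, projected_dim).astype('float32'))
    B = M[:, :-components].dot(np.random.randn(true_dim - components, projected_dim).astype('float32'))

    fox = DoubleFoxSolver(n_components=components, p=projected_dim,
                          double_grad_steps=20, double_grad_lr=0.01)
    fox.fit(A, B, batch_size=None, max_iters=10000, tolerance=1e-4,
            verbose=True, report_every=10)

    vectors = fox.predict(A=A)  # [num_samples, n_components], ordered by loss
    print('components shape:', vectors.shape)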