@justheuristic
Created January 24, 2019 09:52
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Doublefox example\n",
"![Img](https://i.imgur.com/IkU0BxB.png?2)\n",
"\n",
"__Patch notes:__\n",
"* Добавились двойные градиенты\n",
" * `MinFoxSolver(..., double_grad_steps)` - двойные градиенты + обычный minmax update\n",
" * `DoubleFoxSolver(...)` - только двойные градиенты и ничего лишнего\n",
"* Заменил while_loop на unroll, *внезапно* оно заработало (видимо баг в tf1.3 control flow ops)\n",
"* Предиктор обновляется с sgd; если использовать adam, работают нестабильно\n",
"\n",
"\n",
"* Теперь можно сочетать двойные градиенты и обычный minmax update. По моим приборам это сильно ускоряет сходимость если predictor многослойный.\n",
" * `MinFoxSolver(..., double_grad_steps=N).fit(..., pred_steps=1)` или больше\n",
" * предиктор будет на каждом шаге приходить в чуть более точное начальное приближение\n",
"\n",
"* Tech stuff\n",
" * появился параметр min_iters - гарантированное минимальное число итераций без проверки tolerance\n",
" * теперь генератор/предиктор обязаны быть моделями tfnn, а не keras. С керасом двойные градиенты не завелись.\n",
" * если предиктор не обучается, выводится лосс только для генератора (в точке \"после double grad обучения предиктора\"), если обучается - оба лосса.\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"%matplotlib inline\n",
"import numpy as np\n",
"\n",
"num_samples, true_dim, projected_dim, components = 2000, 20, 50, 10\n",
"\n",
"M = np.random.randn(num_samples, true_dim).astype('float32')\n",
"A = M.dot(np.random.randn(true_dim, projected_dim).astype('float32'))\n",
"B = M[:, :-components].dot(np.random.randn(true_dim - components, projected_dim).astype('float32'))"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from minifox import MinFoxSolver, DoubleFoxSolver"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"step 10; loss=0.6451; delta=inf\n",
"step 20; loss=0.6223; delta=0.0228\n",
"step 30; loss=0.5927; delta=0.0296\n",
"step 40; loss=0.5926; delta=0.0001\n",
"step 50; loss=0.6496; delta=0.0570\n",
"step 60; loss=0.5735; delta=0.0762\n",
"step 70; loss=0.6472; delta=0.0737\n",
"step 80; loss=0.6011; delta=0.0461\n",
"step 90; loss=0.6390; delta=0.0379\n",
"step 100; loss=0.6055; delta=0.0336\n",
"step 110; loss=0.6315; delta=0.0260\n",
"step 120; loss=0.7394; delta=0.1080\n",
"step 130; loss=0.6500; delta=0.0894\n",
"step 140; loss=0.7142; delta=0.0642\n",
"step 150; loss=0.7392; delta=0.0251\n",
"step 160; loss=0.8023; delta=0.0630\n",
"step 170; loss=0.8075; delta=0.0052\n",
"step 180; loss=0.8172; delta=0.0097\n",
"step 190; loss=0.9127; delta=0.0956\n",
"step 200; loss=0.7946; delta=0.1181\n",
"step 210; loss=0.8122; delta=0.0176\n",
"step 220; loss=0.8569; delta=0.0447\n",
"step 230; loss=0.9695; delta=0.1126\n",
"step 240; loss=1.0084; delta=0.0389\n",
"step 250; loss=0.8318; delta=0.1766\n",
"step 260; loss=0.9548; delta=0.1230\n",
"step 270; loss=0.9658; delta=0.0110\n",
"step 280; loss=0.9637; delta=0.0021\n",
"step 290; loss=1.0542; delta=0.0906\n",
"step 300; loss=0.8983; delta=0.1559\n",
"step 310; loss=1.0395; delta=0.1412\n",
"step 320; loss=0.9934; delta=0.0461\n",
"step 330; loss=0.9783; delta=0.0151\n",
"step 340; loss=1.1710; delta=0.1927\n",
"step 350; loss=1.0778; delta=0.0931\n",
"step 360; loss=1.0159; delta=0.0620\n",
"step 370; loss=1.1092; delta=0.0933\n",
"step 380; loss=1.0549; delta=0.0543\n",
"step 390; loss=1.0246; delta=0.0303\n",
"step 400; loss=0.9764; delta=0.0482\n",
"step 410; loss=1.0605; delta=0.0842\n",
"step 420; loss=1.0842; delta=0.0236\n",
"step 430; loss=1.0785; delta=0.0057\n",
"step 440; loss=1.0400; delta=0.0385\n",
"step 450; loss=1.0236; delta=0.0164\n",
"step 460; loss=1.1403; delta=0.1167\n",
"step 470; loss=1.0166; delta=0.1238\n",
"step 480; loss=1.0628; delta=0.0463\n",
"step 490; loss=1.0529; delta=0.0099\n",
"step 500; loss=1.2188; delta=0.1659\n",
"step 510; loss=1.0563; delta=0.1625\n",
"step 520; loss=1.1194; delta=0.0631\n",
"step 530; loss=1.0821; delta=0.0373\n",
"step 540; loss=1.0635; delta=0.0186\n",
"step 550; loss=1.0930; delta=0.0295\n",
"step 560; loss=1.0918; delta=0.0011\n",
"step 570; loss=1.1540; delta=0.0621\n",
"step 580; loss=1.1254; delta=0.0286\n",
"step 590; loss=1.0980; delta=0.0274\n",
"step 600; loss=1.1050; delta=0.0070\n",
"step 610; loss=1.0342; delta=0.0708\n",
"step 620; loss=1.0396; delta=0.0054\n",
"step 630; loss=1.0488; delta=0.0092\n",
"step 640; loss=1.1031; delta=0.0543\n",
"step 650; loss=1.0539; delta=0.0492\n",
"step 660; loss=1.0860; delta=0.0321\n",
"step 670; loss=1.0799; delta=0.0060\n",
"step 680; loss=1.0947; delta=0.0148\n",
"step 690; loss=1.0840; delta=0.0107\n",
"step 700; loss=1.0348; delta=0.0492\n",
"step 710; loss=1.0513; delta=0.0165\n",
"step 720; loss=1.1529; delta=0.1016\n",
"step 730; loss=1.1441; delta=0.0087\n",
"step 740; loss=1.1120; delta=0.0321\n",
"step 750; loss=1.0421; delta=0.0699\n",
"step 760; loss=1.1786; delta=0.1365\n",
"step 770; loss=1.0416; delta=0.1370\n",
"step 780; loss=1.0709; delta=0.0293\n",
"step 790; loss=1.0629; delta=0.0080\n",
"step 800; loss=1.0460; delta=0.0169\n",
"step 810; loss=1.0882; delta=0.0421\n",
"step 820; loss=1.1138; delta=0.0256\n",
"step 830; loss=1.0700; delta=0.0438\n",
"step 840; loss=1.0305; delta=0.0395\n",
"step 850; loss=1.1153; delta=0.0847\n",
"step 860; loss=1.0459; delta=0.0694\n",
"step 870; loss=1.0974; delta=0.0515\n",
"step 880; loss=1.1377; delta=0.0403\n",
"step 890; loss=1.1408; delta=0.0031\n",
"step 900; loss=1.1732; delta=0.0324\n",
"step 910; loss=1.0606; delta=0.1126\n",
"step 920; loss=1.1232; delta=0.0627\n",
"step 930; loss=1.1775; delta=0.0543\n",
"step 940; loss=1.0712; delta=0.1064\n",
"step 950; loss=1.0966; delta=0.0254\n",
"step 960; loss=1.1672; delta=0.0706\n",
"step 970; loss=1.0579; delta=0.1093\n",
"step 980; loss=1.1573; delta=0.0994\n",
"step 990; loss=1.0828; delta=0.0745\n",
"step 1000; loss=1.1045; delta=0.0217\n",
"step 1010; loss=1.1124; delta=0.0079\n",
"step 1020; loss=1.0317; delta=0.0806\n",
"step 1030; loss=1.1844; delta=0.1527\n",
"step 1040; loss=1.0541; delta=0.1303\n",
"step 1050; loss=1.0499; delta=0.0042\n",
"step 1060; loss=1.1265; delta=0.0766\n",
"step 1070; loss=1.0569; delta=0.0696\n",
"step 1080; loss=1.0000; delta=0.0568\n",
"step 1090; loss=1.1417; delta=0.1416\n",
"step 1100; loss=1.0832; delta=0.0584\n",
"step 1110; loss=1.0832; delta=0.0001\n",
"Done: reached target tolerance\n",
"Ordering components by loss values...\n",
"Training finished.\n",
"CPU times: user 3min 19s, sys: 12 s, total: 3min 31s\n",
"Wall time: 2min\n"
]
}
],
"source": [
"%%time\n",
"fox = DoubleFoxSolver(n_components=components, p=projected_dim, \n",
" double_grad_steps=20, double_grad_lr=0.01)\n",
"fox.fit(A, B,\n",
" batch_size=None, # None means full data\n",
" max_iters=10000, tolerance=1e-4,\n",
" verbose=True, report_every=10)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.image.AxesImage at 0x7f001ef20fd0>"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAPgAAAD8CAYAAABaQGkdAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAACeNJREFUeJzt3N+L5XUdx/Hnq9lVWyvs1427S2sQlQRlDGIKXWjQT/KmCwODvNmbMo1AtJv+gRC7iGAxvVHyYvMiQrLox0U3W+Mq1ToVYqWbhluQhZG71ruLmWAzd853dr5fv3PePB8g7MyePb5Y5rnfc86c+aSqkNTTa+YeIGk6Bi41ZuBSYwYuNWbgUmMGLjVm4FJjBi41ZuBSY3umuNO3vGmlDh3cO/r9/vYX+0a/T2kZ/ZMXOF0vZtHtJgn80MG9/Ozhg6Pf74cvfd/o9ykto2P1w0G38yG61JiBS40ZuNSYgUuNGbjUmIFLjQ0KPMlHkvwmyRNJbp96lKRxLAw8yQrwdeCjwOXAp5NcPvUwSTs35Ap+JfBEVT1ZVaeBB4Drp50laQxDAt8PPH3Wxyc3P/c/khxOspZk7dRf/jXWPkk7MCTwV3q/6/8dxVpVR6pqtapW3/rmlZ0vk7RjQwI/CZz9xvIDwDPTzJE0piGB/xx4R5LLklwA3AB8Z9pZksaw8KfJquqlJJ8HHgZWgHuq6sTkyyTt2KAfF62qh4CHJt4iaWS+k01qzMClxgxcaszApcYMXGpskkMXf/uLfZMckPjwM4+Nfp/gYY7qyyu41JiBS40ZuNSYgUuNGbjUmIFLjRm41JiBS40ZuNSYgUuNGbjUmIFLjRm41JiBS40ZuNSYgUuNGbjUmIFLjRm41JiBS40ZuNTYJKeqTmWq00+nOK3Vk1q1G3gFlxozcKkxA5caM3CpMQOXGjNwqbGFgSc5mOTHSdaTnEhyy6sxTNLODfk++EvAl6rqeJLXA48k+UFVPT7xNkk7tPAKXlXPVtXxzV//HVgH9k89TNLObes5eJJDwBXAsSnGSBrX4LeqJnkd8G3g1qr62yv8/mHgMMBF7BttoKTzN+gKnmQvG3HfX1UPvtJtqupIVa1W1epeLhxzo6TzNORV9ADfBNar6s7pJ0kay5Ar+DXAZ4Brkzy2+d/HJt4laQQLn4NX1U+BvApbJI3Md7JJjRm41JiBS40ZuNSYgUuNLdWhi1OZ4oDEKQ5yBA9z1PZ4BZcaM3CpMQOXGjNwqTEDlxozcKkxA5caM3CpMQOXGjNwqTEDlxozcKkxA5caM3CpMQOXGjNwqTEDlxozcKkxA5caM3CpMQOXGvNU1YlMdfqpp7VqO7yCS40ZuNSYgUuNGbjUmIFLjRm41JiBS40NDjzJSpJHk3x3ykGSxrOdK/gtwPpUQySNb1DgSQ4AHwfunnaOpDENvYLfBdwG/PtcN0hyOMlakrUzvDjKOEk7szDwJJ8AnquqR7a6XVUdqarVqlrdy4WjDZR0/oZcwa8BPpnk98ADwLVJ7pt0laRRLAy8qu6oqgNVdQi4AfhRVd04+TJJO+b3waXGtvXz4FX1E+AnkyyRNDqv4FJjBi41ZuBSYwYuNWbgUmOeqrpkPK1V2+EVXGrMwKXGDFxqzMClxgxcaszApcYMXGrMwKXGDFxqzMClxgxcaszApcYMXGrMwKXGDFxqzMClxgxcaszApcYMXGrMwKXGDFxqzFNVBXhaa1dewaXGDFxqzMClxgxcaszApcYMXGpsUOBJLklyNMmvk6wn+cDUwyTt3NDvg38N+F5VfSrJBcC+CTdJGsnCwJO8Afgg8FmAqjoNnJ52lqQxDHmI/nbgFHBvkkeT3J3k4ol3SRrBkMD3AO8HvlFVVwAvALe//EZJDidZS7J2hhdHninpfAwJ/CRwsqqObX58lI3g/0dVHamq1apa3cuFY26UdJ4WBl5VfwKeTvLOzU9dBzw+6SpJoxj6KvrNwP2br6A/Cdw03SRJYxkUeFU9BqxOvEXSyHwnm9SYgUuNGbjUmIFLjRm41JiBS415qqomtUyntXY8qdUruNSYgUuNGbjUmIFLjRm41JiBS40ZuNSYgUuNGbjUmIFLjRm41JiBS40ZuNSYgUuNGbjUmIFLjRm41JiBS40ZuNSYgUuNeeiiltIUByROcZAjzHuYo1dwqTEDlxozcKkxA5caM3CpMQOXGjNwqbFBgSf5YpITSX6V5FtJLpp6mKSdWxh4kv3AF4DVqnoPsALcMPUwSTs39CH6HuC1SfYA+4BnppskaSwLA6+qPwJfBZ4CngWer6rvv/x2SQ4nWUuydoYXx18qaduGPER/I3A9cBlwKXBxkhtffruqOlJVq1W1upcLx18qaduGPET/EPC7qjpVVWeAB4Grp50laQxDAn8KuCrJviQBrgPWp50laQxDnoMfA44Cx4Ffbv6ZIxPvkjSCQT8PXlVfAb4y8RZJI/OdbFJjBi41ZuBSYwYuNWbgUmOeqiptmur00ylOa73yw/8YdDuv4FJjBi41ZuBSYwYuNWbgUmMGLjVm4FJjBi41ZuBSYwYuNWbgUmMGLjVm4FJjBi41ZuBSYwYuNWbgUmMGLjVm4FJjBi41ZuBSY6mq8e80OQX8YcBN3wL8efQB01mmvcu0FZZr727Y+raqeuuiG00S+FBJ1qpqdbYB27RMe5dpKyzX3mXa6kN0qTEDlxqbO/AjM///t2uZ9i7TVliuvUuzddbn4JKmNfcVXNKEZgs8yUeS/CbJE0lun2vHIkkOJvlxkvUkJ5LcMvemIZKsJHk0yXfn3rKVJJckOZrk15t/xx+Ye9NWknxx8+vgV0m+leSiuTdtZZbAk6wAXwc+ClwOfDrJ5XNsGeAl4EtV9W7gKuBzu3jr2W4B1uceMcDXgO9V1buA97KLNyfZD3wBWK2q9wArwA3zrtraXFfwK4EnqurJqjoNPABcP9OWLVXVs1V1fPPXf2fjC3D/vKu2luQA8HHg7rm3bCXJG4APAt8EqKrTVfXXeVcttAd4bZI9wD7gmZn3bGmuwPcDT5/18Ul2eTQASQ4BVwDH5l2y0F3AbcC/5x6ywNuBU8C9m08n7k5y8dyjzqWq/gh8FXgKeBZ4vqq+P++qrc0VeF7hc7v65fwkrwO+DdxaVX+be8+5JPkE8FxVPTL3lgH2AO8HvlFVVwAvALv59Zg3svFI8zLgUuDiJDfOu2prcwV+Ejh41scH2MUPdZLsZSPu+6vqwbn3LHAN8Mkkv2fjqc+1Se6bd9I5nQROVtV/HxEdZSP43epDwO+q6lRVnQEeBK6eedOW5gr858A7klyW5AI2Xqj4zkxbtpQkbDxHXK+qO+fes0hV3VFVB6rqEBt/rz+qql15lamqPwFPJ3nn5qeuAx6fcdIiTwFXJdm3+XVxHbv4RUHYeIj0qquql5J8HniYjVci76mqE3NsGeAa4DPAL5M8tvm5L1fVQzNu6uRm4P7Nf+ifBG6aec85VdWxJEeB42x8d+VRdvm72nwnm9SY72STGjNwqTEDlxozcKkxA5caM3CpMQOXGjNwqbH/AAcfJCfPJk4KAAAAAElFTkSuQmCC\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x7f001efd8a90>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"vectors = fox.predict(A=A)\n",
"\n",
"cosine = lambda a, b: (a * b).sum(-1) / (a * a).sum(-1) ** 0.5 / (b * b).sum(-1) ** 0.5\n",
"correlation_map = cosine(vectors.T[:, None, :], vectors.T[None, :, :])\n",
"\n",
"# orthogonality check\n",
"plt.imshow(correlation_map)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"coefficient of determination from known cols: 0.00045179598686428474\n",
"coefficient of determination from unknown cols: 0.9942116432725838\n"
]
}
],
"source": [
"from sklearn.linear_model import LinearRegression\n",
"regression = LinearRegression().fit(M[:, :-components], vectors)\n",
"print(\"coefficient of determination from known cols:\", regression.score(M[:, :-components], vectors))\n",
"regression = LinearRegression().fit(M[:, -components:], vectors)\n",
"print(\"coefficient of determination from unknown cols:\", regression.score(M[:, -components:], vectors))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Эмпирически \"хорошие\" параметры"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"step 10; loss(gen update)=0.5658; loss(pred update)=0.5691; delta=inf\n",
"step 20; loss(gen update)=0.5614; loss(pred update)=0.5631; delta=0.0044\n",
"step 30; loss(gen update)=0.5660; loss(pred update)=0.5676; delta=0.0046\n",
"step 40; loss(gen update)=0.5715; loss(pred update)=0.5730; delta=0.0055\n",
"step 50; loss(gen update)=0.5781; loss(pred update)=0.5797; delta=0.0066\n",
"step 60; loss(gen update)=0.5865; loss(pred update)=0.5881; delta=0.0084\n",
"step 70; loss(gen update)=0.5971; loss(pred update)=0.5989; delta=0.0106\n",
"step 80; loss(gen update)=0.6098; loss(pred update)=0.6117; delta=0.0127\n",
"step 90; loss(gen update)=0.6236; loss(pred update)=0.6255; delta=0.0138\n",
"step 100; loss(gen update)=0.6376; loss(pred update)=0.6395; delta=0.0140\n",
"step 110; loss(gen update)=0.6519; loss(pred update)=0.6537; delta=0.0142\n",
"step 120; loss(gen update)=0.6669; loss(pred update)=0.6688; delta=0.0150\n",
"step 130; loss(gen update)=0.6832; loss(pred update)=0.6852; delta=0.0163\n",
"step 140; loss(gen update)=0.7010; loss(pred update)=0.7031; delta=0.0178\n",
"step 150; loss(gen update)=0.7201; loss(pred update)=0.7223; delta=0.0191\n",
"step 160; loss(gen update)=0.7403; loss(pred update)=0.7425; delta=0.0202\n",
"step 170; loss(gen update)=0.7613; loss(pred update)=0.7635; delta=0.0210\n",
"step 180; loss(gen update)=0.7825; loss(pred update)=0.7847; delta=0.0212\n",
"step 190; loss(gen update)=0.8035; loss(pred update)=0.8057; delta=0.0210\n",
"step 200; loss(gen update)=0.8239; loss(pred update)=0.8260; delta=0.0204\n",
"step 210; loss(gen update)=0.8433; loss(pred update)=0.8452; delta=0.0194\n",
"step 220; loss(gen update)=0.8615; loss(pred update)=0.8633; delta=0.0182\n",
"step 230; loss(gen update)=0.8784; loss(pred update)=0.8800; delta=0.0169\n",
"step 240; loss(gen update)=0.8939; loss(pred update)=0.8955; delta=0.0156\n",
"step 250; loss(gen update)=0.9082; loss(pred update)=0.9096; delta=0.0143\n",
"step 260; loss(gen update)=0.9212; loss(pred update)=0.9225; delta=0.0130\n",
"step 270; loss(gen update)=0.9329; loss(pred update)=0.9340; delta=0.0117\n",
"step 280; loss(gen update)=0.9433; loss(pred update)=0.9443; delta=0.0104\n",
"step 290; loss(gen update)=0.9525; loss(pred update)=0.9534; delta=0.0092\n",
"step 300; loss(gen update)=0.9605; loss(pred update)=0.9613; delta=0.0080\n",
"step 310; loss(gen update)=0.9674; loss(pred update)=0.9680; delta=0.0069\n",
"step 320; loss(gen update)=0.9733; loss(pred update)=0.9738; delta=0.0059\n",
"step 330; loss(gen update)=0.9782; loss(pred update)=0.9786; delta=0.0049\n",
"step 340; loss(gen update)=0.9822; loss(pred update)=0.9826; delta=0.0041\n",
"step 350; loss(gen update)=0.9856; loss(pred update)=0.9859; delta=0.0033\n",
"step 360; loss(gen update)=0.9883; loss(pred update)=0.9886; delta=0.0027\n",
"step 370; loss(gen update)=0.9905; loss(pred update)=0.9907; delta=0.0022\n",
"step 380; loss(gen update)=0.9923; loss(pred update)=0.9925; delta=0.0018\n",
"step 390; loss(gen update)=0.9937; loss(pred update)=0.9939; delta=0.0014\n",
"step 400; loss(gen update)=0.9949; loss(pred update)=0.9950; delta=0.0012\n",
"step 410; loss(gen update)=0.9958; loss(pred update)=0.9959; delta=0.0009\n",
"step 420; loss(gen update)=0.9965; loss(pred update)=0.9966; delta=0.0007\n",
"step 430; loss(gen update)=0.9971; loss(pred update)=0.9972; delta=0.0006\n",
"step 440; loss(gen update)=0.9976; loss(pred update)=0.9977; delta=0.0005\n",
"step 450; loss(gen update)=0.9980; loss(pred update)=0.9981; delta=0.0004\n",
"step 460; loss(gen update)=0.9983; loss(pred update)=0.9984; delta=0.0003\n",
"step 470; loss(gen update)=0.9986; loss(pred update)=0.9986; delta=0.0003\n",
"step 480; loss(gen update)=0.9988; loss(pred update)=0.9989; delta=0.0002\n",
"step 490; loss(gen update)=0.9990; loss(pred update)=0.9990; delta=0.0002\n",
"step 500; loss(gen update)=0.9992; loss(pred update)=0.9992; delta=0.0002\n",
"step 510; loss(gen update)=0.9993; loss(pred update)=0.9993; delta=0.0001\n",
"step 520; loss(gen update)=0.9994; loss(pred update)=0.9994; delta=0.0001\n",
"step 530; loss(gen update)=0.9995; loss(pred update)=0.9995; delta=0.0001\n",
"Done: reached target tolerance\n",
"Ordering components by loss values...\n",
"Training finished.\n",
"CPU times: user 1min, sys: 2.84 s, total: 1min 3s\n",
"Wall time: 37 s\n"
]
}
],
"source": [
"%%time\n",
"import tensorflow as tf\n",
"\n",
"fox = MinFoxSolver(n_components=components, p=projected_dim,\n",
" gen_optimizer=tf.train.AdamOptimizer(1e-3),\n",
" double_grad_steps=5, double_grad_lr=0.01)\n",
"fox.fit(A, B,\n",
" batch_size=None, # None means full data\n",
" pred_steps=5,\n",
" max_iters=10000, tolerance=1e-4,\n",
" verbose=True, report_every=10)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
import numpy as np
from contextlib import contextmanager
from uuid import uuid4
import tensorflow as tf
import tfnn.layers.basic as L
class MinFoxSolver:
    def __init__(self, n_components, p, p_b=None, double_grad_steps=0, double_grad_lr=0.01,
                 gen_optimizer=tf.train.AdamOptimizer(5e-4), pred_optimizer=tf.train.AdamOptimizer(5e-4),
                 make_generator=lambda name, inp_size, out_size: L.Dense(name, inp_size, out_size, activ=lambda x: x),
                 make_predictor=lambda name, inp_size, out_size: L.Dense(name, inp_size, out_size, activ=lambda x: x,
                                                                         matrix_initializer=tf.zeros_initializer()),
                 sess=None, device=None,
                 ):
"""
Given two matrices A and B, predict a variable f(A) that is impossible to predict from matrix B
:param p: last dimension of A
:param p_b: dimension of B, default p_b = p
:param gen_optimizer: tf optimizer used to train generator
:param pred_optimizer: tf optimizer used to train predictor out-of-graph (between iterations)
:param make_generator: callback to create a keras model for target variable generator given A
:param make_predictor: callback to create a keras model for target variable predictor given B
Both models returned by make_generator and make_predictor should have
all their trainable variables created inside name scope (first arg)
:param sess: tensorflow session. If not specified, uses default session or creates new one.
:param device: 'cpu', 'gpu' or a specific tf device to run on.
/* Маааленькая лисёнка */
"""
        config = tf.ConfigProto(device_count={'GPU': int(device != 'cpu')}) if device is not None else tf.ConfigProto()
        self.session = sess = sess or tf.get_default_session() or tf.Session(config=config)
        self.n_components = n_components
        self.gen_optimizer, self.pred_optimizer = gen_optimizer, pred_optimizer

        with sess.as_default(), sess.graph.as_default(), tf.device(device), tf.variable_scope(str(uuid4())) as self.scope:
            A = self.A = tf.placeholder(tf.float32, [None, p])
            B = self.B = tf.placeholder(tf.float32, [None, p_b or p])

            self.generator = make_generator('generator', p, n_components)
            self.predictor = make_predictor('predictor', p_b or p, n_components)

            prefix = tf.get_variable_scope().name + '/'
            self.generator_weights = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                                       scope=prefix + 'generator')
            self.predictor_weights = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                                       scope=prefix + 'predictor')

            prediction = self.predictor(B)
            target_raw = self.generator(A)

            # orthogonalize target and scale to unit norm
            target = orthogonalize_columns(target_raw)
            target *= tf.sqrt(tf.to_float(tf.shape(target)[0]))

            self.loss_values = self.compute_loss(target, prediction)
            self.loss = tf.reduce_mean(self.loss_values)

            # update predictor
            with tf.variable_scope('pred_optimizer') as pred_optimizer_scope:
                self.update_pred = pred_optimizer.minimize(self.loss,
                                                           var_list=self.predictor_weights)
                pred_state = list(self.predictor_weights)
                pred_state += tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=tf.get_variable_scope().name)
                self.reset_pred = tf.variables_initializer(pred_state)

            # update generator using the (possibly) ingraph-updated predictor
            if double_grad_steps == 0:
                self.generator_loss = self.loss
            else:
                make_model = lambda: make_predictor('predictor', p_b or p, self.n_components)
                get_loss = lambda predictor: tf.reduce_mean(self.compute_loss(target, predictor(B)))
                self.updated_predictor = get_updated_model(self.predictor, make_model, get_loss,
                                                           n_steps=double_grad_steps, learning_rate=double_grad_lr,
                                                           model_variables=self.predictor_weights)
                self.generator_loss = get_loss(self.updated_predictor)

            self.generator_reg = tf.reduce_mean(tf.squared_difference(target, target_raw))

            with tf.variable_scope('gen_optimizer') as gen_optimizer_scope:
                self.update_gen = gen_optimizer.minimize(-self.generator_loss + self.generator_reg,
                                                         var_list=self.generator_weights)

            self.prediction, self.target, self.target_raw = prediction, target, target_raw
    def compute_loss(self, target, prediction):
        """ Return the loss for each sample and each component, output.shape == target.shape """
        return tf.squared_difference(target, prediction)
    def fit(self, A, B, max_iters=10 ** 4, min_iters=0, tolerance=1e-4, batch_size=None, pred_steps=5, gen_steps=1,
            reset_predictor=False, reorder=True, verbose=False, report_every=100):
        """
        Trains the fox
        :param pred_steps: predictor g(B) training iterations per training step
        :param gen_steps: generator f(A) training iterations per training step
        :param max_iters: maximum number of optimization steps before termination
        :param min_iters: the guaranteed number of steps the algorithm will make without terminating by tolerance
        :param tolerance: terminates if the loss difference between consecutive report_every-step cycles
            falls below this value; set to 0 to iterate for max_iters steps
        :param reset_predictor: if True, resets the predictor network after every step
        :param reorder: if True, reorders components from highest loss to lowest
        """
        sess = self.session
        step = 0
        assert gen_steps > 0

        with sess.as_default(), sess.graph.as_default():
            initialize_uninitialized_variables(sess)
            prev_loss = float('inf')

            for batch_a, batch_b in iterate_minibatches(A, B, batch_size, cycle=True, shuffle=True):
                step += 1
                if step > max_iters: break
                feed = {self.A: batch_a, self.B: batch_b}

                # update generator
                for j in range(gen_steps):
                    loss_t_gen, _ = sess.run([self.generator_loss, self.update_gen], feed)

                if reset_predictor:
                    sess.run(self.reset_pred)
                for j in range(pred_steps):
                    loss_t_pred, _ = sess.run([self.loss, self.update_pred], feed)

                # eval loss and metrics
                if step % report_every == 0:
                    if pred_steps == 0:
                        loss_t_pred = sess.run(self.loss, feed)
                    loss_delta = abs(prev_loss - loss_t_gen)
                    prev_loss = loss_t_gen

                    if verbose:
                        if pred_steps == 0:
                            print("step %i; loss=%.4f; delta=%.4f" % (step, loss_t_gen, loss_delta))
                        else:
                            print("step %i; loss(gen update)=%.4f; loss(pred update)=%.4f; delta=%.4f" % (
                                step, loss_t_gen, loss_t_pred, loss_delta))

                    if loss_delta < tolerance and step > min_iters:
                        if verbose: print("Done: reached target tolerance")
                        break
            else:
                if verbose:
                    print("Done: reached max steps")

            # record components ordered by their loss value (highest to lowest)
            if reorder:
                if verbose:
                    print("Ordering components by loss values...")
                self.loss_per_component = np.zeros([self.n_components])

                for batch_a, batch_b in iterate_minibatches(A, B, batch_size, cycle=False, shuffle=False):
                    batch_loss_values = sess.run(self.loss_values, {self.A: batch_a, self.B: batch_b})
                    # ^-- [batch_size, n_components]
                    self.loss_per_component += batch_loss_values.sum(0)

                self.component_order = np.argsort(-self.loss_per_component)
                self.loss_per_component_ordered = self.loss_per_component[self.component_order]
            else:
                self.component_order = np.arange(self.n_components)

        if verbose:
            print("Training finished.")
        return self
    def predict(self, A=None, B=None, ordered=True, raw=False):
        assert (A is None) != (B is None), "Please use either predict(A=...) or predict(B=...)"
        sess = self.session
        with sess.as_default(), sess.graph.as_default():
            if A is not None:
                if not raw:
                    out = sess.run(self.target, {self.A: A})
                else:
                    out = sess.run(self.target_raw, {self.A: A})
            else:
                out = sess.run(self.prediction, {self.B: B})
            if ordered:
                out = out[:, self.component_order]
        return out
    def get_weights(self):
        return self.session.run({'generator': self.generator.trainable_variables,
                                 'predictor': self.predictor.trainable_variables})
class DoubleFoxSolver(MinFoxSolver):
    def __init__(self, n_components, p, p_b=None, double_grad_steps=10, double_grad_lr=0.01,
                 gen_optimizer=tf.train.AdamOptimizer(5e-4), **kwargs):
        """
        A wrapper for MinFoxSolver that works by double gradient without any other predictor updates.
        """
        dummy_opt = tf.train.GradientDescentOptimizer(learning_rate=0.0)
        super().__init__(n_components, p, p_b, double_grad_steps=double_grad_steps, double_grad_lr=double_grad_lr,
                         gen_optimizer=gen_optimizer, pred_optimizer=dummy_opt, **kwargs)

    def fit(self, A, B, max_iters=10 ** 4, min_iters=0, tolerance=1e-4, batch_size=None,
            reset_predictor=True, reorder=True, verbose=False, report_every=100, **kwargs):
        """
        A wrapper for MinFoxSolver that works by double gradient without any other predictor updates.
        """
        return super().fit(A, B, max_iters, min_iters, tolerance, batch_size=batch_size,
                           gen_steps=1, pred_steps=0, reset_predictor=reset_predictor,
                           reorder=reorder, verbose=verbose, report_every=report_every, **kwargs)


DoubleFoxSolver.__init__.__doc__ += MinFoxSolver.__init__.__doc__
DoubleFoxSolver.fit.__doc__ += MinFoxSolver.fit.__doc__
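

# Minimal usage sketch (not part of the original module; it mirrors the notebook above): find
# n_components directions of A that cannot be linearly predicted from B. Shapes and hyperparameters
# below are illustrative only.
def _example_doublefox_usage():
    m = np.random.randn(2000, 20).astype('float32')
    a = m.dot(np.random.randn(20, 50).astype('float32'))           # sees all latent factors
    b = m[:, :-10].dot(np.random.randn(10, 50).astype('float32'))  # misses the last 10 factors
    fox = DoubleFoxSolver(n_components=10, p=50, double_grad_steps=20, double_grad_lr=0.01)
    fox.fit(a, b, batch_size=None, max_iters=10000, tolerance=1e-4, verbose=True, report_every=10)
    return fox.predict(A=a)  # [2000, 10], components ordered from highest to lowest loss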
def orthogonalize_rows(matrix):
    """
    Gram-Schmidt orthogonalizer for each row of matrix; source: https://bit.ly/2FMOp40
    :param matrix: 2d float tensor [nrow, ncol]
    :returns: row-orthogonalized matrix [nrow, ncol] s.t.
        * output[i, :].dot(output[j, :]) ~= 0 for all i != j
        * norm(output[i, :]) == 1 for all i
    """
    basis = tf.expand_dims(matrix[0, :] / tf.norm(matrix[0, :]), 0)  # [1, ncol]
    for i in range(1, matrix.shape[0]):
        v = tf.expand_dims(matrix[i, :], 0)  # [1, ncol]
        w = v - tf.matmul(tf.matmul(v, basis, transpose_b=True), basis)  # [1, ncol]
        basis = tf.concat([basis, w / tf.norm(w)], axis=0)  # [i, ncol]
    return basis
def orthogonalize_columns(matrix):
    """
    Gram-Schmidt orthogonalizer for each column of matrix; source: https://bit.ly/2FMOp40
    :param matrix: 2d float tensor [nrow, ncol]
    :returns: column-orthogonalized matrix [nrow, ncol] s.t.
        * output[:, i].dot(output[:, j]) ~= 0 for all i != j
        * norm(output[:, j]) == 1 for all j
    """
    basis = tf.expand_dims(matrix[:, 0] / tf.norm(matrix[:, 0]), 1)  # [nrow, 1]
    for i in range(1, matrix.shape[1]):
        v = tf.expand_dims(matrix[:, i], 1)  # [nrow, 1]
        w = v - tf.matmul(basis, tf.matmul(basis, v, transpose_a=True))  # [nrow, 1]
        basis = tf.concat([basis, w / tf.norm(w)], axis=1)  # [nrow, i]
    return basis
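

# Sanity-check sketch for the orthogonalizer (not used by the solver): the Gram matrix of the
# orthogonalized columns should be close to the identity. Shapes are illustrative.
def _check_column_orthonormality():
    x = tf.constant(np.random.randn(100, 5).astype('float32'))
    q = orthogonalize_columns(x)
    gram = tf.matmul(q, q, transpose_a=True)  # expected to be approximately eye(5)
    with tf.Session() as s:
        assert np.allclose(s.run(gram), np.eye(5), atol=1e-3)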
def initialize_uninitialized_variables(sess=None, var_list=None):
    with tf.name_scope("initialize"):
        sess = sess or tf.get_default_session() or tf.InteractiveSession()
        uninitialized_names = set(sess.run(tf.report_uninitialized_variables(var_list)))
        uninitialized_vars = []
        for var in tf.global_variables():
            if var.name[:-2].encode() in uninitialized_names:
                uninitialized_vars.append(var)
        sess.run(tf.variables_initializer(uninitialized_vars))
def iterate_minibatches(x, y, batch_size=None, cycle=False, shuffle=False):
    indices = np.arange(len(x))
    while True:
        if batch_size is not None:
            if shuffle:
                indices = np.random.permutation(indices)
            for batch_start in range(0, len(x), batch_size):
                batch_ix = indices[batch_start: batch_start + batch_size]
                yield x[batch_ix], y[batch_ix]
        else:
            yield x, y
        if not cycle:
            break
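

# Illustration of the iterator's contract (toy arrays, hypothetical sizes): with batch_size=None it
# yields the full arrays once per cycle, which is how fit(batch_size=None) does full-batch updates.
def _example_iterate_minibatches():
    x, y = np.arange(10).reshape(5, 2), np.arange(5)
    for bx, by in iterate_minibatches(x, y, batch_size=2, cycle=False, shuffle=False):
        print(bx.shape, by.shape)  # -> (2, 2) (2,), (2, 2) (2,), (1, 2) (1,)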
# Double grad utils

@contextmanager
def replace_variables(replacement_dict, strict=True, verbose=False, canonicalize_names=True, scope='', **kwargs):
    """ A context that replaces all newly created tf variables using a replacement_dict {name -> value} """
    if canonicalize_names:
        new_replacement_dict = {canonicalize(key): var for key, var in replacement_dict.items()}
        assert len(new_replacement_dict) == len(replacement_dict), \
            "multiple variables got the same canonical name, output: {}".format(new_replacement_dict)
        replacement_dict = new_replacement_dict

    def _custom_getter(getter, name, shape, *args, **kwargs):
        name = canonicalize(name) if canonicalize_names else name
        assert not strict or name in replacement_dict, "variable {} not found".format(name)
        if name in replacement_dict:
            if verbose:
                print("variable {} replaced with {}".format(name, replacement_dict[name]))
            return replacement_dict[name]
        else:
            if verbose:
                print("variable {} not found, creating new".format(name))
            return getter(name=name, shape=shape, *args, **kwargs)

    with tf.variable_scope(scope, custom_getter=_custom_getter, **kwargs):
        yield
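

# Sketch of the intended use (illustrative names, not part of the original module): create a
# variable, then rebuild the same graph piece with a substituted tensor. Inside the context,
# tf.get_variable returns the replacement instead of creating (or reusing) the variable; this is
# the mechanism get_updated_model below relies on.
def _example_replace_variables():
    w = tf.get_variable('demo_w', shape=[3], initializer=tf.zeros_initializer())
    with replace_variables({'demo_w': w + 1.0}):
        w_updated = tf.get_variable('demo_w', shape=[3])  # -> the tensor (w + 1.0)
    return w_updated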
def canonicalize(name):
    """ canonicalize variable name: remove empty scopes (//) and colons """
    if ':' in name:
        name = name[:name.index(':')]
    while '//' in name:
        name = name.replace('//', '/')
    return name
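

# For example: canonicalize('some_scope//predictor/W:0') -> 'some_scope/predictor/W'
# (the '//' appears because replace_variables enters an empty variable scope, see above).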
def get_updated_model(initial_model, make_model, get_loss, n_steps=1, learning_rate=0.01,
                      model_variables=None, **kwargs):
    """
    Performs in-graph SGD steps on the model.
    :param initial_model: initial tfnn model (a sack of variables)
    :param make_model: a function with no inputs that creates a new model
        The new model should "live" in the same variable scope as initial_model
    :param get_loss: a function(model) -> scalar loss to be optimized
    :param n_steps: number of gradient descent steps
    :param learning_rate: sgd learning rate
    :param model_variables: a list of model variables. Defaults to all trainable variables in the model scope
    """
    model = initial_model
    assert hasattr(model, 'scope') or model_variables is not None, \
        "model has no .scope, please add it or specify model_variables=[list_of_trainable_weights]"
    model_variables = model_variables or tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                                           scope=model.scope.name)
    assert isinstance(model_variables, (list, tuple))
    variable_names = [canonicalize(var.name) for var in model_variables]

    for step in range(n_steps):
        grads = tf.gradients(get_loss(model), model_variables)
        # Perform SGD update. Note: if you use an adaptive optimizer (e.g. Adam), implement it HERE
        updated_variables = {
            name: (var - learning_rate * grad) if grad is not None else var
            for name, var, grad in zip(variable_names, model_variables, grads)
        }
        with replace_variables(updated_variables, strict=True, **kwargs):
            model = make_model()
        model_variables = [updated_variables[name] for name in variable_names]

    return model
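

# End-to-end double-gradient sketch (illustrative names; assumes B and target have static shapes):
# unroll a few SGD steps on a fresh predictor in-graph, then use the *updated* predictor's loss as a
# training signal for whatever produced `target`. This mirrors how MinFoxSolver builds
# self.generator_loss when double_grad_steps > 0.
def _example_double_gradient(B, target, n_steps=5, learning_rate=0.01):
    in_dim, out_dim = int(B.shape[1]), int(target.shape[1])
    make_model = lambda: L.Dense('demo_predictor', in_dim, out_dim, activ=lambda x: x)
    predictor = make_model()
    predictor_weights = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='demo_predictor')
    get_loss = lambda model: tf.reduce_mean(tf.squared_difference(target, model(B)))
    updated_predictor = get_updated_model(predictor, make_model, get_loss,
                                          n_steps=n_steps, learning_rate=learning_rate,
                                          model_variables=predictor_weights)
    # the result is differentiable w.r.t. the parameters that produced `target`
    return get_loss(updated_predictor)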