Created
February 12, 2018 21:07
-
-
Save asford/6f9e2c9d30245dca32f6f533d1cb1f0b to your computer and use it in GitHub Desktop.
workspace/tmol_toy/simple_2d_torch.ipynb
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "import pandas\nfrom jinja2 import Template\nimport traitlets\n\nimport numpy\nimport numpy as np", | |
"execution_count": 1, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "class SimpleSystem(traitlets.TraitType):\n def validate(self, obj, value):\n value = numpy.array(value, dtype=\"f4\")\n \n assert value.ndim == 2\n assert value.dtype == numpy.dtype(\"f4\")\n \n return value", | |
"execution_count": 2, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "# System is a n-2 array of coordinates\nclass Sys(traitlets.HasTraits):\n epsilon = 1\n r_m = 1\n sys = SimpleSystem()\n\n @classmethod\n def random(cls, N, extent = None):\n sys = numpy.empty((N, 2), \"f4\")\n if extent is None:\n extent = numpy.sqrt(N) * 4\n\n sys = numpy.random.rand(N,2) * extent\n\n return Sys(sys=sys)\n \n @classmethod\n def lj_potential(cls, r):\n epsilon = cls.epsilon\n r_m = cls.r_m\n \n return epsilon * (( r_m / r ) ** 12 - 2 *( r_m / r ) ** 6)\n\n def dist(self):\n assert self.sys.ndim == 2, self.sys.shape[-1] == 2\n return numpy.linalg.norm(\n self.sys.reshape((-1, 1, 2)) - self.sys.reshape((1, -1, 2)),\n axis=-1)\n \n def pair_potentials(self):\n d = self.dist()\n lj = self.lj_potential(d)\n numpy.fill_diagonal(lj, 0)\n return lj\n \n @property\n def records(self):\n return [\n {\n \"xy\" : list(map(float, self.sys[i][\"xy\"])),\n \"radius\" : float(self.sys[i][\"radius\"])\n }\n for i in range(len(self.sys))\n ]", | |
"execution_count": 3, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "import torch\nfrom torch.autograd import Variable", | |
"execution_count": 4, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "class SimpleModel(object):\n def __init__(self, coords):\n \n self.coords = Variable(torch.Tensor(coords), requires_grad=True)\n ind = Variable(torch.Tensor(numpy.arange(self.coords.shape[0])), requires_grad=False)\n\n ind_a = ind.view((-1, 1))\n ind_b = ind.view((1, -1))\n deltas = self.coords.view((-1, 1, 2)) - self.coords.view((1, -1 ,2))\n\n self.dist = torch.norm(deltas, 2, -1)\n\n epsilon = 1.0\n r_m = 1.0\n lj_rep = ( r_m / self.dist ) ** 6\n lj = epsilon * (lj_rep ** 2 - 3 * lj_rep)\n\n self.lj = torch.where(\n ind_a != ind_b,\n lj,\n Variable(torch.Tensor([0.0]), requires_grad=False)\n )\n\n\n self.total_score = torch.sum(self.lj)\n (self.grads,) = torch.autograd.grad(self.total_score, self.coords)", | |
"execution_count": 11, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "with torch.autograd.profiler.profile() as prof:\n m = SimpleModel(Sys.random(100).sys)\n \nprint(prof)", | |
"execution_count": 16, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": "------------------------------- --------------- --------------- --------------- --------------- ---------------\nName CPU time CUDA time Calls CPU total CUDA total\n------------------------------- --------------- --------------- --------------- --------------- ---------------\nview 10.193us 0.000us 1 10.193us 0.000us\nview 17.820us 0.000us 1 17.820us 0.000us\nview 10.250us 0.000us 1 10.250us 0.000us\nview 14.276us 0.000us 1 14.276us 0.000us\nexpand 11.610us 0.000us 1 11.610us 0.000us\nexpand 2.403us 0.000us 1 2.403us 0.000us\nsub 144.599us 0.000us 1 144.599us 0.000us\nnorm 297.579us 0.000us 1 297.579us 0.000us\nreciprocal 44.691us 0.000us 1 44.691us 0.000us\nmul 35.426us 0.000us 1 35.426us 0.000us\npow 964.058us 0.000us 1 964.058us 0.000us\npow 22.689us 0.000us 1 22.689us 0.000us\nmul 21.762us 0.000us 1 21.762us 0.000us\nsub 23.807us 0.000us 1 23.807us 0.000us\nmul 33.628us 0.000us 1 33.628us 0.000us\nexpand 3.597us 0.000us 1 3.597us 0.000us\nexpand 1.762us 0.000us 1 1.762us 0.000us\nne 26.023us 0.000us 1 26.023us 0.000us\nwhere 50.091us 0.000us 1 50.091us 0.000us\nexpand 1.878us 0.000us 1 1.878us 0.000us\nexpand 2.926us 0.000us 1 2.926us 0.000us\nexpand 5.359us 0.000us 1 5.359us 0.000us\n_s_where 32.956us 0.000us 1 32.956us 0.000us\nsum 14.403us 0.000us 1 14.403us 0.000us\nones_like 2.776us 0.000us 1 2.776us 0.000us\nN5torch8autograd9GraphRootE 1.186us 0.000us 1 1.186us 0.000us\nSumBackward0 12.248us 0.000us 1 12.248us 0.000us\nexpand 8.321us 0.000us 1 8.321us 0.000us\nSWhereBackward 109.916us 0.000us 1 109.916us 0.000us\nzeros_like 58.031us 0.000us 1 58.031us 0.000us\nwhere 45.463us 0.000us 1 45.463us 0.000us\n_s_where 41.464us 0.000us 1 41.464us 0.000us\nExpandBackward 3.028us 0.000us 1 3.028us 0.000us\nMulBackward0 22.756us 0.000us 1 22.756us 0.000us\nmul 20.839us 0.000us 1 20.839us 0.000us\nSubBackward1 72.671us 0.000us 1 72.671us 0.000us\nneg 22.319us 0.000us 1 22.319us 0.000us\nmul 47.546us 0.000us 1 47.546us 0.000us\nMulBackward0 22.403us 0.000us 1 22.403us 0.000us\nmul 20.859us 0.000us 1 20.859us 0.000us\nPowBackward0 77.689us 0.000us 1 77.689us 0.000us\npow 8.068us 0.000us 1 8.068us 0.000us\nmul 20.557us 0.000us 1 20.557us 0.000us\nmul 41.920us 0.000us 1 41.920us 0.000us\nadd 6.463us 0.000us 1 6.463us 0.000us\nPowBackward0 1046.935us 0.000us 1 1046.935us 0.000us\npow 1024.744us 0.000us 1 1024.744us 0.000us\nmul 6.972us 0.000us 1 6.972us 0.000us\nmul 5.780us 0.000us 1 5.780us 0.000us\nMulBackward0 5.154us 0.000us 1 5.154us 0.000us\nmul 3.802us 0.000us 1 3.802us 0.000us\nReciprocalBackward 27.341us 0.000us 1 27.341us 0.000us\nneg 10.980us 0.000us 1 10.980us 0.000us\nmul 8.409us 0.000us 1 8.409us 0.000us\nmul 4.052us 0.000us 1 4.052us 0.000us\nNormBackward1 146.963us 0.000us 1 146.963us 0.000us\nunsqueeze 3.591us 0.000us 1 3.591us 0.000us\nunsqueeze 1.587us 0.000us 1 1.587us 0.000us\ndiv 9.949us 0.000us 1 9.949us 0.000us\neq 18.006us 0.000us 1 18.006us 0.000us\nmasked_fill_ 11.509us 0.000us 1 11.509us 0.000us\nexpand 2.903us 0.000us 1 2.903us 0.000us\nexpand 2.541us 0.000us 1 2.541us 0.000us\nmul 83.127us 0.000us 1 83.127us 0.000us\nSubBackward1 71.961us 0.000us 1 71.961us 0.000us\nneg 32.730us 0.000us 1 32.730us 0.000us\nmul 35.876us 0.000us 1 35.876us 0.000us\nExpandBackward 38.893us 0.000us 1 38.893us 0.000us\nsum 37.252us 0.000us 1 37.252us 0.000us\nExpandBackward 35.612us 0.000us 1 35.612us 0.000us\nsum 34.519us 0.000us 1 34.519us 0.000us\nViewBackward 7.110us 0.000us 1 7.110us 0.000us\nview 4.902us 0.000us 1 4.902us 0.000us\nViewBackward 3.056us 0.000us 1 3.056us 0.000us\nview 2.001us 0.000us 1 2.001us 0.000us\nadd 2.621us 0.000us 1 2.621us 0.000us\n\n", | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "%%time\nm = SimpleModel(Sys.random(100).sys)", | |
"execution_count": 17, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": "CPU times: user 17.6 ms, sys: 24 µs, total: 17.6 ms\nWall time: 4.79 ms\n", | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "import scipy.optimize", | |
"execution_count": 19, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "def min_system(coords):\n coords = coords.reshape((-1, 2))\n \n m = SimpleModel(coords)\n \n return (m.total_score.detach().numpy(), m.grads.numpy().reshape(-1))", | |
"execution_count": 22, | |
"outputs": [] | |
}, | |
{ | |
"metadata": { | |
"trusted": true, | |
"scrolled": false | |
}, | |
"cell_type": "code", | |
"source": "%%time\nscipy.optimize.minimize(min_system, Sys.random(100).sys.reshape(-1), jac=True)", | |
"execution_count": 31, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": "CPU times: user 30 s, sys: 268 ms, total: 30.2 s\nWall time: 4.32 s\n", | |
"name": "stdout" | |
}, | |
{ | |
"data": { | |
"text/plain": " fun: -587.91357421875\n hess_inv: array([[ 0.42718959, -0.16189395, -0.03315423, ..., 2.89692303,\n -0.13496986, -0.04351769],\n [ -0.16189395, 0.31880333, 0.02395006, ..., -2.09054871,\n 0.09744265, 0.03130786],\n [ -0.03315423, 0.02395006, 0.1488253 , ..., -0.42497672,\n 0.01971232, 0.00625508],\n ..., \n [ 2.89692303, -2.09054871, -0.42497672, ..., 37.56928194,\n -1.72794027, -0.55614717],\n [ -0.13496986, 0.09744265, 0.01971232, ..., -1.72794027,\n 0.33213956, 0.02699726],\n [ -0.04351769, 0.03130786, 0.00625508, ..., -0.55614717,\n 0.02699726, 0.26116469]])\n jac: array([ 0.38523579, -0.0496814 , 2.21046257, 2.43515015, -2.34584236,\n 2.39059591, 0.28824094, -0.81552488, 3.22303438, -1.64476454,\n -0.05001306, 0.58318013, 0.60421205, -1.19142342, -0.15202872,\n -0.28348398, 0.51453841, 0.06684789, -0.01391314, 0.74795753,\n 6.53196716, 0.09004206, 1.79012144, -2.43204999, 0.13654424,\n 0.032967 , 0.2645652 , -1.03697574, -2.4285965 , 0.12757927,\n 0.05970505, 2.48159599, 0.12262318, -0.5121429 , 0.63058674,\n -1.49515843, -0.34411666, -1.743554 , -0.52773875, 1.13063431,\n 0.20909557, 0.75034791, -0.45949841, -0.18476467, 0.73190916,\n 1.09312129, 1.02936995, 1.85753334, -1.59177077, 2.11378765,\n -0.03491788, 0.11962414, 1.28033006, -0.62569731, -0.06100861,\n 0.37443298, -0.25495198, -1.42219579, 0.21304274, -0.90264946,\n -0.51596212, -0.06421641, 0.36131549, 0.28515452, 1.40852678,\n 2.28479719, 0.34235859, 0.29137325, -2.35632777, -0.08954504,\n 0.62333745, -0.11724793, 1.24503684, -0.63282949, 0.12389098,\n -0.49577096, -1.00867546, -0.88279843, 1.30143321, -0.23049027,\n 0.01500795, 0.163496 , 0.06981196, -0.63430339, -1.3899051 ,\n 0.69340312, -3.48437524, 0.95410347, 0.00000005, 0.00000006,\n -1.40927923, 0.5273782 , -2.78665018, 1.69454539, -1.12086105,\n -1.02955019, -0.38461763, -0.79912758, -0.00506348, 0.02306248,\n 0.40349364, -0.30145487, -0.40205553, 1.71672595, -2.12335396,\n -1.22424185, -0.52442694, 0.41939157, -0.20031197, -0.30864641,\n -1.81081009, 0.84658563, 1.05679774, -0.34462315, -0.00035464,\n -0.001526 , 2.28704071, 0.55913359, 0.81812394, 1.09888315,\n 0.67888618, -0.79707116, -1.32535517, -0.60258812, -0.00340491,\n 0.00011403, -2.27690697, 1.03060436, -0.64519113, 1.37708235,\n -3.1734159 , 0.56664658, -0.00010123, 0.00023709, -0.04950345,\n 0.39682764, -0.61484241, 0.74479103, 0.90746552, -1.77414584,\n -0.5030582 , -2.72488976, -0.00001431, -0.00002999, -0.0576289 ,\n 0.35122618, 0.91538554, -0.24858546, 0.13770582, 4.35999298,\n 0.1933488 , 0.6636672 , -0.00009252, 0.0001019 , -0.5088194 ,\n 0.85766089, 0.20883809, -0.0669069 , 2.05816865, 1.72243631,\n 0.43876946, -0.7398349 , -0.39252698, 0.91397893, -0.98589283,\n 0.94066775, -0.277787 , 0.56349981, 1.41959429, -1.45558798,\n -4.24375296, -3.58468914, 3.72360611, -3.45469618, 2.05636072,\n -1.5957011 , -0.31906542, -0.21978247, 0.31507331, -0.51289564,\n 0.45230404, -0.27486092, -0.13667801, -0.03350664, 0.07564032,\n -1.08167934, 0.45960537, 0.1848218 , 0.3147586 , -4.01208591,\n 0.00504809, -0.02291123, 0.04565082, 1.22845423, -0.51357824,\n 0.24923086, -0.1216586 , 0.53507048, -0.75127023, 0.05834781], dtype=float32)\n message: 'Desired error not necessarily achieved due to precision loss.'\n nfev: 1034\n nit: 331\n njev: 992\n status: 2\n success: False\n x: array([ 27.11243698, 1.4658458 , 4.36321494, 34.31644514,\n 3.56746796, 34.80533238, 15.08757529, 3.98548194,\n 33.95724908, 26.99451389, 27.5164167 , 0.62809866,\n 35.89815127, 25.47870292, 30.01163948, 37.92017874,\n 30.11902199, 9.79354823, 33.07979827, 5.72274144,\n 15.45998452, 4.83812808, 2.32754268, 20.41629146,\n 21.15580755, -1.82753174, 8.71526259, 20.69371196,\n 14.90159788, 5.58713918, 20.02174995, 23.95959778,\n 2.3389062 , 24.44064105, 32.90193827, 4.80284838,\n 17.5786659 , 5.96704985, 18.34257426, 16.51376875,\n 26.70351676, 22.09816208, 16.37684312, 37.45870321,\n 26.60727338, 21.17244662, 34.10797283, 33.8894351 ,\n 34.12442048, 26.07821185, 26.10869219, 2.7255913 ,\n 8.83679018, 10.5845847 , 29.05700981, 21.37987062,\n 34.50623694, 33.0401523 , 26.97701308, 2.38493926,\n 31.04477568, 9.91181483, 0.79902271, 14.48846992,\n 20.13557027, 16.18981186, 28.30366842, 21.92580171,\n 16.65826585, 5.80668184, 20.96569675, 16.61674573,\n 13.61368964, 4.6370774 , 37.25808095, 13.99398559,\n 20.08721575, 22.08710205, 5.15175676, 33.81609007,\n 19.26552717, 22.52214449, 33.57818676, 33.12085187,\n 2.27356183, 21.35379398, 14.53446199, 4.73641399,\n 48.63000176, 53.88400288, 27.45507576, 21.54611439,\n 20.05942585, 23.02208748, 26.50109981, 20.24561824,\n 36.37004582, 27.02234843, 20.47387257, 31.53303528,\n 2.77125181, 35.2864493 , 1.49236116, 20.84796848,\n 3.54039866, 33.86975718, 34.63971068, 34.65344299,\n 16.97680916, 6.68030504, 34.82844308, 26.68792511,\n 2.74961923, 34.35414487, 2.89076606, 9.30627712,\n 19.27423218, 16.53416792, 36.58672645, 26.11348705,\n 28.20728051, 20.99665674, 7.90364452, 10.47693058,\n 15.00277426, 22.83957277, 35.69681061, 26.38683939,\n 16.36784001, 4.93715518, 16.00944101, 4.08181052,\n 27.37062655, 29.03844568, 1.65674295, 14.85170194,\n 32.19607035, 5.41992065, 27.35859621, 20.61683071,\n 4.32805194, 33.37533603, 0.33588172, 2.13394343,\n 29.93656695, 38.8528303 , 8.49814421, 21.59963938,\n 14.16466013, 3.88628567, 37.54891725, 13.10899644,\n 10.56061471, 39.25564814, 9.39300952, 21.34081819,\n 29.16796668, 38.32116816, 17.27877465, 5.09138245,\n 1.53902134, 13.92001581, 36.63212265, 13.30253665,\n 9.60965823, 20.43661657, 27.55347813, 22.47469187,\n 35.00691551, 25.76564136, 18.82749393, 15.70942162,\n 19.24010267, 23.44955109, 16.9442085 , 4.2109619 ,\n 35.0364803 , 33.8069337 , 9.1730251 , 22.24440229,\n 17.89661794, 15.69361059, 20.24727647, -2.04858337,\n 36.9250777 , 12.41441391, 17.24529081, 37.80786381,\n 14.71937104, 3.12747966, 20.67474477, 30.62015574,\n 8.46353092, 9.73024248, 26.24296909, 1.80232025,\n 2.13012898, 25.35321359, 2.39619315, 14.28481178])" | |
}, | |
"execution_count": 31, | |
"output_type": "execute_result", | |
"metadata": {} | |
} | |
] | |
}, | |
{ | |
"metadata": { | |
"trusted": true | |
}, | |
"cell_type": "code", | |
"source": "", | |
"execution_count": null, | |
"outputs": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"name": "conda-env-tmol_toy-py", | |
"display_name": "Python [conda env:tmol_toy]", | |
"language": "python" | |
}, | |
"language_info": { | |
"nbconvert_exporter": "python", | |
"name": "python", | |
"file_extension": ".py", | |
"codemirror_mode": { | |
"version": 3, | |
"name": "ipython" | |
}, | |
"pygments_lexer": "ipython3", | |
"version": "3.5.4", | |
"mimetype": "text/x-python" | |
}, | |
"toc": { | |
"nav_menu": {}, | |
"number_sections": true, | |
"sideBar": true, | |
"skip_h1_title": false, | |
"toc_cell": false, | |
"toc_position": {}, | |
"toc_section_display": "block", | |
"toc_window_display": false | |
}, | |
"gist": { | |
"id": "", | |
"data": { | |
"description": "workspace/tmol_toy/simple_2d_torch.ipynb", | |
"public": true | |
} | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment