Skip to content

Instantly share code, notes, and snippets.

@asford
Created February 12, 2018 21:07
Show Gist options
  • Save asford/6f9e2c9d30245dca32f6f533d1cb1f0b to your computer and use it in GitHub Desktop.
Save asford/6f9e2c9d30245dca32f6f533d1cb1f0b to your computer and use it in GitHub Desktop.
workspace/tmol_toy/simple_2d_torch.ipynb
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "import pandas\nfrom jinja2 import Template\nimport traitlets\n\nimport numpy\nimport numpy as np",
"execution_count": 1,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "class SimpleSystem(traitlets.TraitType):\n    \"\"\"Trait for a 2-D float32 array; Sys stores its coordinates in one.\"\"\"\n\n    # Shown in the TraitError message that self.error() raises.\n    info_text = \"a 2-D float32 array\"\n\n    def validate(self, obj, value):\n        \"\"\"Return ``value`` copied into a float32 ndarray.\n\n        Raises traitlets.TraitError (via ``self.error``) if it is not 2-D.\n        \"\"\"\n        # numpy.array(..., dtype=\"f4\") always yields float32, so only the\n        # rank needs checking. self.error is used instead of assert because\n        # asserts are stripped under `python -O`.\n        value = numpy.array(value, dtype=\"f4\")\n        if value.ndim != 2:\n            self.error(obj, value)\n        return value",
"execution_count": 2,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "# The system is an (N, 2) array of particle coordinates.\nclass Sys(traitlets.HasTraits):\n    \"\"\"2-D particle system with Lennard-Jones pair interactions.\"\"\"\n\n    # Lennard-Jones parameters: well depth and the separation at which the\n    # pair energy reaches its minimum (-epsilon).\n    epsilon = 1\n    r_m = 1\n    sys = SimpleSystem()\n\n    @classmethod\n    def random(cls, N, extent=None):\n        \"\"\"Return a system of ``N`` particles placed uniformly in [0, extent)^2.\n\n        ``extent`` defaults to ``4 * sqrt(N)``. Draws from the global numpy\n        RNG, so call numpy.random.seed first for a reproducible system.\n        \"\"\"\n        if extent is None:\n            extent = numpy.sqrt(N) * 4\n\n        return cls(sys=numpy.random.rand(N, 2) * extent)\n\n    @classmethod\n    def lj_potential(cls, r):\n        \"\"\"Lennard-Jones energy ``epsilon * ((r_m/r)**12 - 2*(r_m/r)**6)``.\n\n        Elementwise for scalars or arrays; minimum ``-epsilon`` at ``r == r_m``.\n        \"\"\"\n        epsilon = cls.epsilon\n        r_m = cls.r_m\n\n        return epsilon * ((r_m / r) ** 12 - 2 * (r_m / r) ** 6)\n\n    def dist(self):\n        \"\"\"Return the (N, N) matrix of pairwise Euclidean distances.\"\"\"\n        # Both conditions belong in the asserted expression: in `assert a, b`\n        # `b` is only the failure message and is never checked.\n        assert self.sys.ndim == 2 and self.sys.shape[-1] == 2, self.sys.shape\n        return numpy.linalg.norm(\n            self.sys.reshape((-1, 1, 2)) - self.sys.reshape((1, -1, 2)),\n            axis=-1)\n\n    def pair_potentials(self):\n        \"\"\"Return the (N, N) matrix of pair energies with a zero diagonal.\"\"\"\n        d = self.dist()\n        # Self-distances are 0, so the diagonal evaluates to inf - inf = nan;\n        # fill_diagonal overwrites it, so silence the expected warnings.\n        with numpy.errstate(divide=\"ignore\", invalid=\"ignore\"):\n            lj = self.lj_potential(d)\n        numpy.fill_diagonal(lj, 0)\n        return lj\n\n    @property\n    def records(self):\n        \"\"\"One ``{\"xy\": [x, y], \"radius\": r}`` dict per particle.\"\"\"\n        # `sys` is a plain (N, 2) float array without named fields, so rows\n        # are read positionally (indexing a row with \"xy\" raises IndexError).\n        # NOTE(review): no per-particle radius is stored; half the LJ minimum\n        # separation (r_m / 2) is used as a placeholder -- confirm.\n        radius = float(self.r_m) / 2\n        return [\n            {\"xy\": list(map(float, xy)), \"radius\": radius}\n            for xy in self.sys\n        ]",
"execution_count": 3,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "import torch\nfrom torch.autograd import Variable",
"execution_count": 4,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "class SimpleModel(object):\n    \"\"\"Lennard-Jones score and gradient for 2-D coordinates, computed with torch.\n\n    Everything is evaluated eagerly in ``__init__`` and stored on the instance:\n    ``coords`` (leaf tensor, requires grad), ``dist`` (pairwise distances),\n    ``lj`` (pair energies, zero on the diagonal), ``total_score`` (sum of\n    ``lj``) and ``grads`` (d total_score / d coords).\n    \"\"\"\n\n    def __init__(self, coords, epsilon=1.0, r_m=1.0):\n        # coords must be array-like of shape (N, 2): the views below split it\n        # into N points of 2 coordinates each.\n        self.coords = Variable(torch.Tensor(coords), requires_grad=True)\n        n = self.coords.shape[0]\n\n        ind = torch.arange(n)\n        # True for distinct pairs (i != j); the diagonal is self-interaction.\n        pair_mask = ind.view((-1, 1)) != ind.view((1, -1))\n\n        deltas = self.coords.view((-1, 1, 2)) - self.coords.view((1, -1, 2))\n        self.dist = torch.norm(deltas, 2, -1)\n\n        # Self-distances are 0, giving inf - inf = nan on the diagonal. The\n        # where() below hides that in the forward pass, but backward would\n        # still form 0 * inf = nan there. Shifting only the diagonal to a\n        # finite distance keeps every intermediate (and its gradient) finite.\n        safe_dist = self.dist + torch.eye(n)\n\n        lj_rep = (r_m / safe_dist) ** 6\n        # Same form as Sys.lj_potential: epsilon * ((r_m/r)**12 - 2*(r_m/r)**6),\n        # whose minimum is -epsilon at r == r_m.\n        lj = epsilon * (lj_rep ** 2 - 2 * lj_rep)\n\n        self.lj = torch.where(pair_mask, lj, torch.zeros_like(lj))\n\n        # NOTE(review): the full symmetric matrix is summed, so every unordered\n        # pair is counted twice; halve this if the physical energy is wanted.\n        self.total_score = torch.sum(self.lj)\n        (self.grads,) = torch.autograd.grad(self.total_score, self.coords)",
"execution_count": 11,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "# Profile one SimpleModel construction (forward pass plus the\n# torch.autograd.grad call in __init__) on 100 random particles.\nwith torch.autograd.profiler.profile() as prof:\n m = SimpleModel(Sys.random(100).sys)\n \nprint(prof)",
"execution_count": 16,
"outputs": [
{
"output_type": "stream",
"text": "------------------------------- --------------- --------------- --------------- --------------- ---------------\nName CPU time CUDA time Calls CPU total CUDA total\n------------------------------- --------------- --------------- --------------- --------------- ---------------\nview 10.193us 0.000us 1 10.193us 0.000us\nview 17.820us 0.000us 1 17.820us 0.000us\nview 10.250us 0.000us 1 10.250us 0.000us\nview 14.276us 0.000us 1 14.276us 0.000us\nexpand 11.610us 0.000us 1 11.610us 0.000us\nexpand 2.403us 0.000us 1 2.403us 0.000us\nsub 144.599us 0.000us 1 144.599us 0.000us\nnorm 297.579us 0.000us 1 297.579us 0.000us\nreciprocal 44.691us 0.000us 1 44.691us 0.000us\nmul 35.426us 0.000us 1 35.426us 0.000us\npow 964.058us 0.000us 1 964.058us 0.000us\npow 22.689us 0.000us 1 22.689us 0.000us\nmul 21.762us 0.000us 1 21.762us 0.000us\nsub 23.807us 0.000us 1 23.807us 0.000us\nmul 33.628us 0.000us 1 33.628us 0.000us\nexpand 3.597us 0.000us 1 3.597us 0.000us\nexpand 1.762us 0.000us 1 1.762us 0.000us\nne 26.023us 0.000us 1 26.023us 0.000us\nwhere 50.091us 0.000us 1 50.091us 0.000us\nexpand 1.878us 0.000us 1 1.878us 0.000us\nexpand 2.926us 0.000us 1 2.926us 0.000us\nexpand 5.359us 0.000us 1 5.359us 0.000us\n_s_where 32.956us 0.000us 1 32.956us 0.000us\nsum 14.403us 0.000us 1 14.403us 0.000us\nones_like 2.776us 0.000us 1 2.776us 0.000us\nN5torch8autograd9GraphRootE 1.186us 0.000us 1 1.186us 0.000us\nSumBackward0 12.248us 0.000us 1 12.248us 0.000us\nexpand 8.321us 0.000us 1 8.321us 0.000us\nSWhereBackward 109.916us 0.000us 1 109.916us 0.000us\nzeros_like 58.031us 0.000us 1 58.031us 0.000us\nwhere 45.463us 0.000us 1 45.463us 0.000us\n_s_where 41.464us 0.000us 1 41.464us 0.000us\nExpandBackward 3.028us 0.000us 1 3.028us 0.000us\nMulBackward0 22.756us 0.000us 1 22.756us 0.000us\nmul 20.839us 0.000us 1 20.839us 0.000us\nSubBackward1 72.671us 0.000us 1 72.671us 0.000us\nneg 22.319us 0.000us 1 22.319us 0.000us\nmul 47.546us 0.000us 1 47.546us 0.000us\nMulBackward0 22.403us 0.000us 1 
22.403us 0.000us\nmul 20.859us 0.000us 1 20.859us 0.000us\nPowBackward0 77.689us 0.000us 1 77.689us 0.000us\npow 8.068us 0.000us 1 8.068us 0.000us\nmul 20.557us 0.000us 1 20.557us 0.000us\nmul 41.920us 0.000us 1 41.920us 0.000us\nadd 6.463us 0.000us 1 6.463us 0.000us\nPowBackward0 1046.935us 0.000us 1 1046.935us 0.000us\npow 1024.744us 0.000us 1 1024.744us 0.000us\nmul 6.972us 0.000us 1 6.972us 0.000us\nmul 5.780us 0.000us 1 5.780us 0.000us\nMulBackward0 5.154us 0.000us 1 5.154us 0.000us\nmul 3.802us 0.000us 1 3.802us 0.000us\nReciprocalBackward 27.341us 0.000us 1 27.341us 0.000us\nneg 10.980us 0.000us 1 10.980us 0.000us\nmul 8.409us 0.000us 1 8.409us 0.000us\nmul 4.052us 0.000us 1 4.052us 0.000us\nNormBackward1 146.963us 0.000us 1 146.963us 0.000us\nunsqueeze 3.591us 0.000us 1 3.591us 0.000us\nunsqueeze 1.587us 0.000us 1 1.587us 0.000us\ndiv 9.949us 0.000us 1 9.949us 0.000us\neq 18.006us 0.000us 1 18.006us 0.000us\nmasked_fill_ 11.509us 0.000us 1 11.509us 0.000us\nexpand 2.903us 0.000us 1 2.903us 0.000us\nexpand 2.541us 0.000us 1 2.541us 0.000us\nmul 83.127us 0.000us 1 83.127us 0.000us\nSubBackward1 71.961us 0.000us 1 71.961us 0.000us\nneg 32.730us 0.000us 1 32.730us 0.000us\nmul 35.876us 0.000us 1 35.876us 0.000us\nExpandBackward 38.893us 0.000us 1 38.893us 0.000us\nsum 37.252us 0.000us 1 37.252us 0.000us\nExpandBackward 35.612us 0.000us 1 35.612us 0.000us\nsum 34.519us 0.000us 1 34.519us 0.000us\nViewBackward 7.110us 0.000us 1 7.110us 0.000us\nview 4.902us 0.000us 1 4.902us 0.000us\nViewBackward 3.056us 0.000us 1 3.056us 0.000us\nview 2.001us 0.000us 1 2.001us 0.000us\nadd 2.621us 0.000us 1 2.621us 0.000us\n\n",
"name": "stdout"
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "%%time\n# Time one SimpleModel construction (score + gradient) on 100 random\n# particles, including generating them with Sys.random.\nm = SimpleModel(Sys.random(100).sys)",
"execution_count": 17,
"outputs": [
{
"output_type": "stream",
"text": "CPU times: user 17.6 ms, sys: 24 µs, total: 17.6 ms\nWall time: 4.79 ms\n",
"name": "stdout"
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "import scipy.optimize",
"execution_count": 19,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "def min_system(coords):\n    \"\"\"Objective for scipy.optimize.minimize(..., jac=True).\n\n    ``coords`` is a flat vector [x0, y0, x1, y1, ...]. Returns the total\n    SimpleModel score and its gradient, flattened the same way.\n    \"\"\"\n    # NOTE(review): SimpleModel evaluates in float32; that is a likely cause\n    # of scipy's 'precision loss' termination -- confirm before relying on it.\n    model = SimpleModel(coords.reshape((-1, 2)))\n    score = model.total_score.detach().numpy()\n    flat_grad = model.grads.numpy().reshape(-1)\n    return score, flat_grad",
"execution_count": 22,
"outputs": []
},
{
"metadata": {
"trusted": true,
"scrolled": false
},
"cell_type": "code",
"source": "%%time\n# Relax 100 random particles; jac=True tells scipy that min_system returns\n# (score, gradient). With no bounds/constraints scipy defaults to BFGS.\nresult = scipy.optimize.minimize(min_system, Sys.random(100).sys.reshape(-1), jac=True)\n# Show a compact summary rather than the whole OptimizeResult, whose\n# hess_inv alone is a dense 200 x 200 matrix.\nresult.fun, result.success, result.nit, result.message",
"execution_count": 31,
"outputs": [
{
"output_type": "stream",
"text": "CPU times: user 30 s, sys: 268 ms, total: 30.2 s\nWall time: 4.32 s\n",
"name": "stdout"
},
{
"data": {
"text/plain": " fun: -587.91357421875\n hess_inv: array([[ 0.42718959, -0.16189395, -0.03315423, ..., 2.89692303,\n -0.13496986, -0.04351769],\n [ -0.16189395, 0.31880333, 0.02395006, ..., -2.09054871,\n 0.09744265, 0.03130786],\n [ -0.03315423, 0.02395006, 0.1488253 , ..., -0.42497672,\n 0.01971232, 0.00625508],\n ..., \n [ 2.89692303, -2.09054871, -0.42497672, ..., 37.56928194,\n -1.72794027, -0.55614717],\n [ -0.13496986, 0.09744265, 0.01971232, ..., -1.72794027,\n 0.33213956, 0.02699726],\n [ -0.04351769, 0.03130786, 0.00625508, ..., -0.55614717,\n 0.02699726, 0.26116469]])\n jac: array([ 0.38523579, -0.0496814 , 2.21046257, 2.43515015, -2.34584236,\n 2.39059591, 0.28824094, -0.81552488, 3.22303438, -1.64476454,\n -0.05001306, 0.58318013, 0.60421205, -1.19142342, -0.15202872,\n -0.28348398, 0.51453841, 0.06684789, -0.01391314, 0.74795753,\n 6.53196716, 0.09004206, 1.79012144, -2.43204999, 0.13654424,\n 0.032967 , 0.2645652 , -1.03697574, -2.4285965 , 0.12757927,\n 0.05970505, 2.48159599, 0.12262318, -0.5121429 , 0.63058674,\n -1.49515843, -0.34411666, -1.743554 , -0.52773875, 1.13063431,\n 0.20909557, 0.75034791, -0.45949841, -0.18476467, 0.73190916,\n 1.09312129, 1.02936995, 1.85753334, -1.59177077, 2.11378765,\n -0.03491788, 0.11962414, 1.28033006, -0.62569731, -0.06100861,\n 0.37443298, -0.25495198, -1.42219579, 0.21304274, -0.90264946,\n -0.51596212, -0.06421641, 0.36131549, 0.28515452, 1.40852678,\n 2.28479719, 0.34235859, 0.29137325, -2.35632777, -0.08954504,\n 0.62333745, -0.11724793, 1.24503684, -0.63282949, 0.12389098,\n -0.49577096, -1.00867546, -0.88279843, 1.30143321, -0.23049027,\n 0.01500795, 0.163496 , 0.06981196, -0.63430339, -1.3899051 ,\n 0.69340312, -3.48437524, 0.95410347, 0.00000005, 0.00000006,\n -1.40927923, 0.5273782 , -2.78665018, 1.69454539, -1.12086105,\n -1.02955019, -0.38461763, -0.79912758, -0.00506348, 0.02306248,\n 0.40349364, -0.30145487, -0.40205553, 1.71672595, -2.12335396,\n -1.22424185, -0.52442694, 0.41939157, -0.20031197, 
-0.30864641,\n -1.81081009, 0.84658563, 1.05679774, -0.34462315, -0.00035464,\n -0.001526 , 2.28704071, 0.55913359, 0.81812394, 1.09888315,\n 0.67888618, -0.79707116, -1.32535517, -0.60258812, -0.00340491,\n 0.00011403, -2.27690697, 1.03060436, -0.64519113, 1.37708235,\n -3.1734159 , 0.56664658, -0.00010123, 0.00023709, -0.04950345,\n 0.39682764, -0.61484241, 0.74479103, 0.90746552, -1.77414584,\n -0.5030582 , -2.72488976, -0.00001431, -0.00002999, -0.0576289 ,\n 0.35122618, 0.91538554, -0.24858546, 0.13770582, 4.35999298,\n 0.1933488 , 0.6636672 , -0.00009252, 0.0001019 , -0.5088194 ,\n 0.85766089, 0.20883809, -0.0669069 , 2.05816865, 1.72243631,\n 0.43876946, -0.7398349 , -0.39252698, 0.91397893, -0.98589283,\n 0.94066775, -0.277787 , 0.56349981, 1.41959429, -1.45558798,\n -4.24375296, -3.58468914, 3.72360611, -3.45469618, 2.05636072,\n -1.5957011 , -0.31906542, -0.21978247, 0.31507331, -0.51289564,\n 0.45230404, -0.27486092, -0.13667801, -0.03350664, 0.07564032,\n -1.08167934, 0.45960537, 0.1848218 , 0.3147586 , -4.01208591,\n 0.00504809, -0.02291123, 0.04565082, 1.22845423, -0.51357824,\n 0.24923086, -0.1216586 , 0.53507048, -0.75127023, 0.05834781], dtype=float32)\n message: 'Desired error not necessarily achieved due to precision loss.'\n nfev: 1034\n nit: 331\n njev: 992\n status: 2\n success: False\n x: array([ 27.11243698, 1.4658458 , 4.36321494, 34.31644514,\n 3.56746796, 34.80533238, 15.08757529, 3.98548194,\n 33.95724908, 26.99451389, 27.5164167 , 0.62809866,\n 35.89815127, 25.47870292, 30.01163948, 37.92017874,\n 30.11902199, 9.79354823, 33.07979827, 5.72274144,\n 15.45998452, 4.83812808, 2.32754268, 20.41629146,\n 21.15580755, -1.82753174, 8.71526259, 20.69371196,\n 14.90159788, 5.58713918, 20.02174995, 23.95959778,\n 2.3389062 , 24.44064105, 32.90193827, 4.80284838,\n 17.5786659 , 5.96704985, 18.34257426, 16.51376875,\n 26.70351676, 22.09816208, 16.37684312, 37.45870321,\n 26.60727338, 21.17244662, 34.10797283, 33.8894351 ,\n 34.12442048, 
26.07821185, 26.10869219, 2.7255913 ,\n 8.83679018, 10.5845847 , 29.05700981, 21.37987062,\n 34.50623694, 33.0401523 , 26.97701308, 2.38493926,\n 31.04477568, 9.91181483, 0.79902271, 14.48846992,\n 20.13557027, 16.18981186, 28.30366842, 21.92580171,\n 16.65826585, 5.80668184, 20.96569675, 16.61674573,\n 13.61368964, 4.6370774 , 37.25808095, 13.99398559,\n 20.08721575, 22.08710205, 5.15175676, 33.81609007,\n 19.26552717, 22.52214449, 33.57818676, 33.12085187,\n 2.27356183, 21.35379398, 14.53446199, 4.73641399,\n 48.63000176, 53.88400288, 27.45507576, 21.54611439,\n 20.05942585, 23.02208748, 26.50109981, 20.24561824,\n 36.37004582, 27.02234843, 20.47387257, 31.53303528,\n 2.77125181, 35.2864493 , 1.49236116, 20.84796848,\n 3.54039866, 33.86975718, 34.63971068, 34.65344299,\n 16.97680916, 6.68030504, 34.82844308, 26.68792511,\n 2.74961923, 34.35414487, 2.89076606, 9.30627712,\n 19.27423218, 16.53416792, 36.58672645, 26.11348705,\n 28.20728051, 20.99665674, 7.90364452, 10.47693058,\n 15.00277426, 22.83957277, 35.69681061, 26.38683939,\n 16.36784001, 4.93715518, 16.00944101, 4.08181052,\n 27.37062655, 29.03844568, 1.65674295, 14.85170194,\n 32.19607035, 5.41992065, 27.35859621, 20.61683071,\n 4.32805194, 33.37533603, 0.33588172, 2.13394343,\n 29.93656695, 38.8528303 , 8.49814421, 21.59963938,\n 14.16466013, 3.88628567, 37.54891725, 13.10899644,\n 10.56061471, 39.25564814, 9.39300952, 21.34081819,\n 29.16796668, 38.32116816, 17.27877465, 5.09138245,\n 1.53902134, 13.92001581, 36.63212265, 13.30253665,\n 9.60965823, 20.43661657, 27.55347813, 22.47469187,\n 35.00691551, 25.76564136, 18.82749393, 15.70942162,\n 19.24010267, 23.44955109, 16.9442085 , 4.2109619 ,\n 35.0364803 , 33.8069337 , 9.1730251 , 22.24440229,\n 17.89661794, 15.69361059, 20.24727647, -2.04858337,\n 36.9250777 , 12.41441391, 17.24529081, 37.80786381,\n 14.71937104, 3.12747966, 20.67474477, 30.62015574,\n 8.46353092, 9.73024248, 26.24296909, 1.80232025,\n 2.13012898, 25.35321359, 2.39619315, 14.28481178])"
},
"execution_count": 31,
"output_type": "execute_result",
"metadata": {}
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "",
"execution_count": null,
"outputs": []
}
],
"metadata": {
"kernelspec": {
"name": "conda-env-tmol_toy-py",
"display_name": "Python [conda env:tmol_toy]",
"language": "python"
},
"language_info": {
"nbconvert_exporter": "python",
"name": "python",
"file_extension": ".py",
"codemirror_mode": {
"version": 3,
"name": "ipython"
},
"pygments_lexer": "ipython3",
"version": "3.5.4",
"mimetype": "text/x-python"
},
"toc": {
"nav_menu": {},
"number_sections": true,
"sideBar": true,
"skip_h1_title": false,
"toc_cell": false,
"toc_position": {},
"toc_section_display": "block",
"toc_window_display": false
},
"gist": {
"id": "",
"data": {
"description": "workspace/tmol_toy/simple_2d_torch.ipynb",
"public": true
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment