Skip to content

Instantly share code, notes, and snippets.

@asford
Created February 13, 2018 00:11
Show Gist options
  • Save asford/95b1363471dc14b57f50c57635ceb80d to your computer and use it in GitHub Desktop.
Save asford/95b1363471dc14b57f50c57635ceb80d to your computer and use it in GitHub Desktop.
workspace/tmol_toy/simple_2d_torch.ipynb
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "import pandas\nfrom jinja2 import Template\nimport traitlets\n\nimport numpy\nimport numpy as np",
"execution_count": 1,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "class SimpleSystem(traitlets.TraitType):\n    \"\"\"Trait holding a 2-d float32 coordinate array.\n\n    Assigned values are coerced with numpy.array(value, dtype=f4); values\n    that do not come out 2-dimensional are rejected with traitlets.TraitError.\n    \"\"\"\n\n    def validate(self, obj, value):\n        value = numpy.array(value, dtype=\"f4\")\n\n        # Raise instead of assert so the check survives python -O. The float32\n        # dtype is guaranteed by the conversion above, so only the rank is checked.\n        if value.ndim != 2:\n            raise traitlets.TraitError(\n                \"expected a 2-d coordinate array, got shape %r\" % (value.shape,))\n\n        return value",
"execution_count": 2,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "# System is an (N, 2) array of particle coordinates.\nclass Sys(traitlets.HasTraits):\n    \"\"\"Toy 2-d particle system with a Lennard-Jones pair potential.\"\"\"\n\n    # Lennard-Jones well depth and equilibrium (minimum-energy) separation.\n    epsilon = 1\n    r_m = 1\n    sys = SimpleSystem()\n\n    @classmethod\n    def random(cls, N, extent=None):\n        \"\"\"Return a system of N particles placed uniformly in [0, extent)^2.\n\n        extent defaults to 4 * sqrt(N), so the particle density does not\n        depend on N.\n        \"\"\"\n        if extent is None:\n            extent = numpy.sqrt(N) * 4\n\n        coords = numpy.random.rand(N, 2) * extent\n\n        # cls rather than Sys so that subclasses construct their own type.\n        return cls(sys=coords)\n\n    @classmethod\n    def lj_potential(cls, r):\n        \"\"\"Return epsilon * ((r_m / r) ** 12 - 2 * (r_m / r) ** 6) elementwise.\"\"\"\n        epsilon = cls.epsilon\n        r_m = cls.r_m\n\n        return epsilon * ((r_m / r) ** 12 - 2 * (r_m / r) ** 6)\n\n    def dist(self):\n        \"\"\"Return the (N, N) matrix of pairwise Euclidean distances.\"\"\"\n        # Both conditions must be part of the test; the reshapes below assume\n        # exactly two coordinate columns.\n        if self.sys.ndim != 2 or self.sys.shape[-1] != 2:\n            raise ValueError(\n                \"expected an (N, 2) coordinate array, got shape %r\" % (self.sys.shape,))\n        return numpy.linalg.norm(\n            self.sys.reshape((-1, 1, 2)) - self.sys.reshape((1, -1, 2)),\n            axis=-1)\n\n    def pair_potentials(self):\n        \"\"\"Return the (N, N) matrix of pair LJ energies with a zero diagonal.\"\"\"\n        d = self.dist()\n        # Self-distances are 0, giving a divide by zero and inf - inf = nan on\n        # the diagonal before it is zeroed; silence those expected warnings.\n        with numpy.errstate(divide=\"ignore\", invalid=\"ignore\"):\n            lj = self.lj_potential(d)\n        numpy.fill_diagonal(lj, 0)\n        return lj\n\n    @property\n    def records(self):\n        \"\"\"Return one dict per particle with its xy position and a drawing radius.\"\"\"\n        # self.sys is a plain (N, 2) float array with no named fields, so each\n        # row is unpacked directly into x, y.\n        # NOTE(review): radius = r_m / 2 is an assumption (particles at the LJ\n        # minimum separation just touch); confirm against the consumer.\n        radius = float(self.r_m) / 2\n        return [\n            {\n                \"xy\": [float(x), float(y)],\n                \"radius\": radius,\n            }\n            for x, y in self.sys\n        ]",
"execution_count": 3,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "import torch\nfrom torch.autograd import Variable",
"execution_count": 4,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "class SimpleModel(object):\n    \"\"\"Lennard-Jones energy and coordinate gradient, powers computed with **.\n\n    For every ordered pair i != j the pair energy is\n    epsilon * ((r_m / r_ij) ** 12 - 2 * (r_m / r_ij) ** 6), the same form as\n    Sys.lj_potential. total_score sums all ordered pairs (each unordered pair\n    is counted twice) and grads holds d(total_score) / d(coords).\n    \"\"\"\n\n    def __init__(self, coords):\n        # coords holds 2-d positions; it is viewed below as (N, 2).\n        self.coords = Variable(torch.Tensor(coords), requires_grad=True)\n        ind = Variable(torch.Tensor(numpy.arange(self.coords.shape[0])), requires_grad=False)\n\n        ind_a = ind.view((-1, 1))\n        ind_b = ind.view((1, -1))\n        # (N, N, 2) displacement vectors between every pair of particles.\n        deltas = self.coords.view((-1, 1, 2)) - self.coords.view((1, -1, 2))\n\n        self.dist = torch.norm(deltas, 2, -1)\n\n        epsilon = 1.0\n        r_m = 1.0\n        fd6 = (r_m / self.dist) ** 6\n        # Coefficient 2 on the attractive term puts the minimum at r == r_m\n        # with depth -epsilon.\n        lj = epsilon * (fd6 ** 2 - 2 * fd6)\n\n        # Self-pairs have r == 0, where the energy is not finite; use 0 instead.\n        self.lj = torch.where(\n            ind_a != ind_b,\n            lj,\n            Variable(torch.Tensor([0.0]), requires_grad=False),\n        )\n\n        self.total_score = torch.sum(self.lj)\n        (self.grads,) = torch.autograd.grad(self.total_score, self.coords)\n\n\nclass SimpleMultModel(object):\n    \"\"\"Same energy and gradient as SimpleModel, without the ** operator.\n\n    (r_m / r) ** 6 and ** 12 are built from repeated multiplications instead.\n    \"\"\"\n\n    def __init__(self, coords):\n        # coords holds 2-d positions; it is viewed below as (N, 2).\n        self.coords = Variable(torch.Tensor(coords), requires_grad=True)\n        ind = Variable(torch.Tensor(numpy.arange(self.coords.shape[0])), requires_grad=False)\n\n        ind_a = ind.view((-1, 1))\n        ind_b = ind.view((1, -1))\n        # (N, N, 2) displacement vectors between every pair of particles.\n        deltas = self.coords.view((-1, 1, 2)) - self.coords.view((1, -1, 2))\n\n        self.dist = torch.norm(deltas, 2, -1)\n\n        epsilon = 1.0\n        r_m = 1.0\n        fd = r_m / self.dist\n        fd2 = fd * fd\n        fd6 = fd2 * fd2 * fd2\n        fd12 = fd6 * fd6\n        # Coefficient 2 on the attractive term puts the minimum at r == r_m\n        # with depth -epsilon.\n        lj = epsilon * (fd12 - 2 * fd6)\n\n        # Self-pairs have r == 0, where the energy is not finite; use 0 instead.\n        self.lj = torch.where(\n            ind_a != ind_b,\n            lj,\n            Variable(torch.Tensor([0.0]), requires_grad=False),\n        )\n\n        self.total_score = torch.sum(self.lj)\n        (self.grads,) = torch.autograd.grad(self.total_score, self.coords)",
"execution_count": 32,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "with torch.autograd.profiler.profile() as prof:\n    # Record the autograd ops for one SimpleModel score + gradient evaluation.\n    coords = Sys.random(100).sys\n    model = SimpleModel(coords)\n\nprint(prof)",
"execution_count": 33,
"outputs": [
{
"output_type": "stream",
"text": "------------------------------- --------------- --------------- --------------- --------------- ---------------\nName CPU time CUDA time Calls CPU total CUDA total\n------------------------------- --------------- --------------- --------------- --------------- ---------------\nview 12.286us 0.000us 1 12.286us 0.000us\nview 4.773us 0.000us 1 4.773us 0.000us\nview 12.572us 0.000us 1 12.572us 0.000us\nview 4.169us 0.000us 1 4.169us 0.000us\nexpand 9.630us 0.000us 1 9.630us 0.000us\nexpand 3.072us 0.000us 1 3.072us 0.000us\nsub 1218.475us 0.000us 1 1218.475us 0.000us\nnorm 278.484us 0.000us 1 278.484us 0.000us\nreciprocal 21.029us 0.000us 1 21.029us 0.000us\nmul 19.757us 0.000us 1 19.757us 0.000us\npow 956.707us 0.000us 1 956.707us 0.000us\npow 12.356us 0.000us 1 12.356us 0.000us\nmul 11.034us 0.000us 1 11.034us 0.000us\nsub 26.950us 0.000us 1 26.950us 0.000us\nmul 9.967us 0.000us 1 9.967us 0.000us\nexpand 5.112us 0.000us 1 5.112us 0.000us\nexpand 2.036us 0.000us 1 2.036us 0.000us\nne 29.089us 0.000us 1 29.089us 0.000us\nwhere 55.909us 0.000us 1 55.909us 0.000us\nexpand 2.186us 0.000us 1 2.186us 0.000us\nexpand 3.807us 0.000us 1 3.807us 0.000us\nexpand 4.793us 0.000us 1 4.793us 0.000us\n_s_where 37.061us 0.000us 1 37.061us 0.000us\nsum 17.036us 0.000us 1 17.036us 0.000us\nones_like 3.913us 0.000us 1 3.913us 0.000us\nN5torch8autograd9GraphRootE 1.393us 0.000us 1 1.393us 0.000us\nSumBackward0 12.268us 0.000us 1 12.268us 0.000us\nexpand 7.357us 0.000us 1 7.357us 0.000us\nSWhereBackward 58.375us 0.000us 1 58.375us 0.000us\nzeros_like 18.953us 0.000us 1 18.953us 0.000us\nwhere 32.890us 0.000us 1 32.890us 0.000us\n_s_where 28.629us 0.000us 1 28.629us 0.000us\nExpandBackward 1.432us 0.000us 1 1.432us 0.000us\nMulBackward0 13.034us 0.000us 1 13.034us 0.000us\nmul 11.340us 0.000us 1 11.340us 0.000us\nSubBackward1 61.750us 0.000us 1 61.750us 0.000us\nneg 32.470us 0.000us 1 32.470us 0.000us\nmul 26.080us 0.000us 1 26.080us 0.000us\nMulBackward0 24.816us 0.000us 1 
24.816us 0.000us\nmul 23.874us 0.000us 1 23.874us 0.000us\nPowBackward0 65.693us 0.000us 1 65.693us 0.000us\npow 9.375us 0.000us 1 9.375us 0.000us\nmul 24.109us 0.000us 1 24.109us 0.000us\nmul 24.190us 0.000us 1 24.190us 0.000us\nadd 5.282us 0.000us 1 5.282us 0.000us\nPowBackward0 966.924us 0.000us 1 966.924us 0.000us\npow 955.551us 0.000us 1 955.551us 0.000us\nmul 3.644us 0.000us 1 3.644us 0.000us\nmul 4.235us 0.000us 1 4.235us 0.000us\nMulBackward0 7.645us 0.000us 1 7.645us 0.000us\nmul 6.297us 0.000us 1 6.297us 0.000us\nReciprocalBackward 26.247us 0.000us 1 26.247us 0.000us\nneg 9.818us 0.000us 1 9.818us 0.000us\nmul 8.210us 0.000us 1 8.210us 0.000us\nmul 3.741us 0.000us 1 3.741us 0.000us\nNormBackward1 219.304us 0.000us 1 219.304us 0.000us\nunsqueeze 3.477us 0.000us 1 3.477us 0.000us\nunsqueeze 1.550us 0.000us 1 1.550us 0.000us\ndiv 10.215us 0.000us 1 10.215us 0.000us\neq 15.441us 0.000us 1 15.441us 0.000us\nmasked_fill_ 11.944us 0.000us 1 11.944us 0.000us\nexpand 2.356us 0.000us 1 2.356us 0.000us\nexpand 2.102us 0.000us 1 2.102us 0.000us\nmul 158.224us 0.000us 1 158.224us 0.000us\nSubBackward1 32.642us 0.000us 1 32.642us 0.000us\nneg 22.584us 0.000us 1 22.584us 0.000us\nmul 7.541us 0.000us 1 7.541us 0.000us\nExpandBackward 39.707us 0.000us 1 39.707us 0.000us\nsum 37.537us 0.000us 1 37.537us 0.000us\nExpandBackward 38.935us 0.000us 1 38.935us 0.000us\nsum 33.724us 0.000us 1 33.724us 0.000us\nViewBackward 9.774us 0.000us 1 9.774us 0.000us\nview 7.715us 0.000us 1 7.715us 0.000us\nViewBackward 3.357us 0.000us 1 3.357us 0.000us\nview 2.254us 0.000us 1 2.254us 0.000us\nadd 2.669us 0.000us 1 2.669us 0.000us\n\n",
"name": "stdout"
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "with torch.autograd.profiler.profile() as prof:\n    # Record the autograd ops for one SimpleMultModel score + gradient evaluation.\n    coords = Sys.random(100).sys\n    model = SimpleMultModel(coords)\n\nprint(prof)",
"execution_count": 44,
"outputs": [
{
"output_type": "stream",
"text": "------------------------------- --------------- --------------- --------------- --------------- ---------------\nName CPU time CUDA time Calls CPU total CUDA total\n------------------------------- --------------- --------------- --------------- --------------- ---------------\nview 11.198us 0.000us 1 11.198us 0.000us\nview 3.589us 0.000us 1 3.589us 0.000us\nview 10.994us 0.000us 1 10.994us 0.000us\nview 4.824us 0.000us 1 4.824us 0.000us\nexpand 6.815us 0.000us 1 6.815us 0.000us\nexpand 2.636us 0.000us 1 2.636us 0.000us\nsub 113.581us 0.000us 1 113.581us 0.000us\nnorm 398.628us 0.000us 1 398.628us 0.000us\nreciprocal 18.486us 0.000us 1 18.486us 0.000us\nmul 19.383us 0.000us 1 19.383us 0.000us\nmul 11.460us 0.000us 1 11.460us 0.000us\nmul 10.160us 0.000us 1 10.160us 0.000us\nmul 10.297us 0.000us 1 10.297us 0.000us\nmul 8.619us 0.000us 1 8.619us 0.000us\nmul 9.609us 0.000us 1 9.609us 0.000us\nsub 9.108us 0.000us 1 9.108us 0.000us\nmul 9.063us 0.000us 1 9.063us 0.000us\nexpand 13.856us 0.000us 1 13.856us 0.000us\nexpand 1.873us 0.000us 1 1.873us 0.000us\nne 25.929us 0.000us 1 25.929us 0.000us\nwhere 48.671us 0.000us 1 48.671us 0.000us\nexpand 2.218us 0.000us 1 2.218us 0.000us\nexpand 2.908us 0.000us 1 2.908us 0.000us\nexpand 3.861us 0.000us 1 3.861us 0.000us\n_s_where 33.477us 0.000us 1 33.477us 0.000us\nsum 21.082us 0.000us 1 21.082us 0.000us\nones_like 2.581us 0.000us 1 2.581us 0.000us\nN5torch8autograd9GraphRootE 1.528us 0.000us 1 1.528us 0.000us\nSumBackward0 15.085us 0.000us 1 15.085us 0.000us\nexpand 10.550us 0.000us 1 10.550us 0.000us\nSWhereBackward 59.815us 0.000us 1 59.815us 0.000us\nzeros_like 16.724us 0.000us 1 16.724us 0.000us\nwhere 35.826us 0.000us 1 35.826us 0.000us\n_s_where 31.745us 0.000us 1 31.745us 0.000us\nExpandBackward 1.306us 0.000us 1 1.306us 0.000us\nMulBackward0 8.539us 0.000us 1 8.539us 0.000us\nmul 6.430us 0.000us 1 6.430us 0.000us\nSubBackward1 28.042us 0.000us 1 28.042us 0.000us\nneg 15.416us 0.000us 1 15.416us 0.000us\nmul 
9.550us 0.000us 1 9.550us 0.000us\nMulBackward0 10.554us 0.000us 1 10.554us 0.000us\nmul 9.407us 0.000us 1 9.407us 0.000us\nMulBackward1 42.761us 0.000us 1 42.761us 0.000us\nmul 10.584us 0.000us 1 10.584us 0.000us\nmul 27.012us 0.000us 1 27.012us 0.000us\nadd 7.413us 0.000us 1 7.413us 0.000us\nadd 5.830us 0.000us 1 5.830us 0.000us\nMulBackward1 21.530us 0.000us 1 21.530us 0.000us\nmul 8.937us 0.000us 1 8.937us 0.000us\nmul 9.430us 0.000us 1 9.430us 0.000us\nMulBackward1 11.803us 0.000us 1 11.803us 0.000us\nmul 4.568us 0.000us 1 4.568us 0.000us\nmul 4.662us 0.000us 1 4.662us 0.000us\nadd 3.726us 0.000us 1 3.726us 0.000us\nadd 3.783us 0.000us 1 3.783us 0.000us\nMulBackward1 15.379us 0.000us 1 15.379us 0.000us\nmul 8.232us 0.000us 1 8.232us 0.000us\nmul 4.420us 0.000us 1 4.420us 0.000us\nadd 3.752us 0.000us 1 3.752us 0.000us\nMulBackward0 5.666us 0.000us 1 5.666us 0.000us\nmul 4.462us 0.000us 1 4.462us 0.000us\nReciprocalBackward 29.980us 0.000us 1 29.980us 0.000us\nneg 13.608us 0.000us 1 13.608us 0.000us\nmul 7.665us 0.000us 1 7.665us 0.000us\nmul 4.270us 0.000us 1 4.270us 0.000us\nNormBackward1 198.634us 0.000us 1 198.634us 0.000us\nunsqueeze 3.480us 0.000us 1 3.480us 0.000us\nunsqueeze 2.011us 0.000us 1 2.011us 0.000us\ndiv 11.834us 0.000us 1 11.834us 0.000us\neq 17.624us 0.000us 1 17.624us 0.000us\nmasked_fill_ 15.309us 0.000us 1 15.309us 0.000us\nexpand 2.989us 0.000us 1 2.989us 0.000us\nexpand 1.652us 0.000us 1 1.652us 0.000us\nmul 128.557us 0.000us 1 128.557us 0.000us\nSubBackward1 153.423us 0.000us 1 153.423us 0.000us\nneg 55.336us 0.000us 1 55.336us 0.000us\nmul 95.074us 0.000us 1 95.074us 0.000us\nExpandBackward 43.312us 0.000us 1 43.312us 0.000us\nsum 41.179us 0.000us 1 41.179us 0.000us\nExpandBackward 44.301us 0.000us 1 44.301us 0.000us\nsum 42.576us 0.000us 1 42.576us 0.000us\nViewBackward 8.418us 0.000us 1 8.418us 0.000us\nview 5.920us 0.000us 1 5.920us 0.000us\nViewBackward 3.909us 0.000us 1 3.909us 0.000us\nview 2.501us 0.000us 1 2.501us 0.000us\nadd 
3.646us 0.000us 1 3.646us 0.000us\n\n",
"name": "stdout"
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "%%timeit\n# Time building a random 100-particle system and evaluating SimpleModel on it.\ncoords = Sys.random(100).sys\nmodel = SimpleModel(coords)",
"execution_count": 42,
"outputs": [
{
"output_type": "stream",
"text": "3.51 ms ± 305 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n",
"name": "stdout"
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "%%timeit\n# Time building a random 100-particle system and evaluating SimpleMultModel on it.\ncoords = Sys.random(100).sys\nmodel = SimpleMultModel(coords)",
"execution_count": 43,
"outputs": [
{
"output_type": "stream",
"text": "1.4 ms ± 25.7 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n",
"name": "stdout"
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "import scipy.optimize",
"execution_count": 45,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "def min_system(coords):\n    \"\"\"Energy/gradient callback for scipy.optimize.minimize(..., jac=True).\n\n    coords is the flat coordinate vector handled by the optimizer; it is\n    viewed as (N, 2) positions for SimpleMultModel. Returns a tuple of\n    (energy, flattened gradient).\n    \"\"\"\n    model = SimpleMultModel(coords.reshape((-1, 2)))\n    energy = model.total_score.detach().numpy()\n    flat_grad = model.grads.numpy().reshape(-1)\n    return energy, flat_grad",
"execution_count": 48,
"outputs": []
},
{
"metadata": {
"trusted": true,
"scrolled": false
},
"cell_type": "code",
"source": "%%timeit\n# Minimize a random 200-particle system, with min_system supplying energy and gradient.\nx0 = Sys.random(200).sys.reshape(-1)\nscipy.optimize.minimize(min_system, x0, jac=True)",
"execution_count": 53,
"outputs": [
{
"output_type": "stream",
"text": "The slowest run took 30.34 times longer than the fastest. This could mean that an intermediate result is being cached.\n7.44 s ± 3.12 s per loop (mean ± std. dev. of 7 runs, 1 loop each)\n",
"name": "stdout"
}
]
}
],
"metadata": {
"kernelspec": {
"name": "conda-env-tmol_toy-py",
"display_name": "Python [conda env:tmol_toy]",
"language": "python"
},
"language_info": {
"nbconvert_exporter": "python",
"name": "python",
"file_extension": ".py",
"codemirror_mode": {
"version": 3,
"name": "ipython"
},
"pygments_lexer": "ipython3",
"version": "3.5.4",
"mimetype": "text/x-python"
},
"toc": {
"nav_menu": {},
"number_sections": true,
"sideBar": true,
"skip_h1_title": false,
"toc_cell": false,
"toc_position": {},
"toc_section_display": "block",
"toc_window_display": false
},
"gist": {
"id": "95b1363471dc14b57f50c57635ceb80d",
"data": {
"description": "workspace/tmol_toy/simple_2d_torch.ipynb",
"public": true
}
},
"_draft": {
"nbviewer_url": "https://gist.github.com/95b1363471dc14b57f50c57635ceb80d"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
@aleaverfay
Copy link

Silly comment, Alex, but you have a "3" that should be a "2" in the LJ equation: (r_m/r)^12 - 2(r_m/r)^6. If it's OK with you, I'm going to screenshot your code and show what it looks like to the folks at WRC next week.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment