Created
February 14, 2022 19:41
-
-
Save agoose77/fd6b6a0cdc41fb361e4b33b950cb8e80 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"id": "60788a7a-f955-40f2-8456-ffc33b262f0b", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import awkward as ak\n", | |
"import numpy as np\n", | |
"import numba as nb" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "c10a775e-2567-4c3c-9a15-b886f74d0115", | |
"metadata": {}, | |
"source": [ | |
"Create some test data: three events, each a list of particles with an `id` and a `charge` of +1 or -1." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"id": "1c9f8f1b-f407-417d-922b-845ce6a90c8b", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Three events; each particle's `id` is its position within its event.\n", | |
"x = ak.Array(\n", | |
"    [\n", | |
"        [{\"id\": i, \"charge\": c} for i, c in enumerate(event_charges)]\n", | |
"        for event_charges in (\n", | |
"            [-1, 1, 1, -1, -1],\n", | |
"            [-1, 1, 1, -1],\n", | |
"            [-1, 1, 1, 1, -1, -1],\n", | |
"        )\n", | |
"    ]\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "52fae68d-083b-4773-9f7c-151dd5fdf59e", | |
"metadata": {}, | |
"source": [ | |
"Split each event's particles into positively and negatively charged collections" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"id": "d8ec4b04-2e1b-4159-987b-e7025b4c21d2", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Strictly positive / strictly negative charges (a zero charge would land in neither).\n", | |
"pos, neg = x[x[\"charge\"] > 0], x[x[\"charge\"] < 0]" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "e892532e-cb31-4056-a7fd-7bc466f72616", | |
"metadata": {}, | |
"source": [ | |
"Implement `itertools.permutations` as a Numba generator that permutes a mutable index array in place" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"id": "3892c99e-f480-4377-b611-15bceb73b9cd", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"@nb.njit\n", | |
"def _permutations(n, r):\n", | |
"    \"\"\"Yield the length-r permutations of range(n), like itertools.permutations.\n", | |
"\n", | |
"    Port of the pure-Python reference implementation shown in the\n", | |
"    itertools.permutations documentation. Each permutation is yielded as a\n", | |
"    new integer array of length r, so callers may keep the results.\n", | |
"    Yields nothing when r > n.\n", | |
"    \"\"\"\n", | |
"    if r > n:\n", | |
"        return\n", | |
"\n", | |
"    # The first r entries of `index` form the current permutation;\n", | |
"    # cycle[i] counts the swaps left for position i before it rotates back.\n", | |
"    index = np.arange(n)\n", | |
"    cycle = np.arange(n, n - r, -1)\n", | |
"    # `index` is mutated in place below, so yield copies rather than views:\n", | |
"    # a stored view would silently change as iteration continues.\n", | |
"    yield index[:r].copy()\n", | |
"    while n:\n", | |
"        for i in range(r - 1, -1, -1):\n", | |
"            cycle[i] -= 1\n", | |
"            if cycle[i] == 0:\n", | |
"                # Position i is exhausted: rotate index[i:] left by one and reset it.\n", | |
"                index_i = index[i]\n", | |
"                index[i:-1] = index[i + 1:]\n", | |
"                index[-1] = index_i\n", | |
"                cycle[i] = n - i\n", | |
"            else:\n", | |
"                j = cycle[i]\n", | |
"                index[i], index[-j] = index[-j], index[i]\n", | |
"                yield index[:r].copy()\n", | |
"                break\n", | |
"        else:\n", | |
"            # No position could advance: every permutation has been produced.\n", | |
"            return" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "b2633d05-a90b-4792-9df4-937a662d78f7", | |
"metadata": {}, | |
"source": [ | |
"Define an `argpermutations` Numba implementation" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"id": "5f97a841-6425-4548-a1c8-da4fc1128249", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"@nb.njit\n", | |
"def argpermutations(left, right, builder):\n", | |
"    \"\"\"Append to `builder` every one-to-one pairing of left[k] with right[k].\n", | |
"\n", | |
"    For each row k, the shorter of left[k] and right[k] is matched element by\n", | |
"    element against an ordered selection (permutation) of the longer one.\n", | |
"    Per row, the builder receives a list of configurations, and each\n", | |
"    configuration is a list of (left_index, right_index) tuples.\n", | |
"\n", | |
"    Returns `builder` itself, so the call can be chained with `.snapshot()`.\n", | |
"    \"\"\"\n", | |
"    for row in range(len(left)):\n", | |
"        builder.begin_list()\n", | |
"\n", | |
"        n_left = len(left[row])\n", | |
"        n_right = len(right[row])\n", | |
"        # Permute indices of the longer side; the position within each\n", | |
"        # permutation indexes the shorter side (on a tie, `left` is the shorter).\n", | |
"        left_is_longer = n_left > n_right\n", | |
"        if left_is_longer:\n", | |
"            n, r = n_left, n_right\n", | |
"        else:\n", | |
"            n, r = n_right, n_left\n", | |
"\n", | |
"        for perm in _permutations(n, r):\n", | |
"            builder.begin_list()\n", | |
"            for short_index, long_index in enumerate(perm):\n", | |
"                if left_is_longer:\n", | |
"                    left_index, right_index = long_index, short_index\n", | |
"                else:\n", | |
"                    left_index, right_index = short_index, long_index\n", | |
"                builder.begin_tuple(2)\n", | |
"                builder.index(0).integer(left_index)\n", | |
"                builder.index(1).integer(right_index)\n", | |
"                builder.end_tuple()\n", | |
"            builder.end_list()\n", | |
"\n", | |
"        builder.end_list()\n", | |
"    return builder" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "e7e66db5-0d92-41e2-ba64-81728b5e7aae", | |
"metadata": {}, | |
"source": [ | |
"Compute the indices for the positive and negative arrays" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"id": "4b1e6aae-4607-4ba4-9657-18e8744fd36c", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/home/angus/.mambaforge/envs/texat/lib/python3.9/site-packages/awkward/_connect/_numba/__init__.py:30: DeprecationWarning: distutils Version classes are deprecated. Use packaging.version instead.\n", | |
" if not checked_version and distutils.version.LooseVersion(\n", | |
"/home/angus/.mambaforge/envs/texat/lib/python3.9/site-packages/awkward/_connect/_numba/__init__.py:32: DeprecationWarning: distutils Version classes are deprecated. Use packaging.version instead.\n", | |
" ) < distutils.version.LooseVersion(\"0.50\"):\n" | |
] | |
} | |
], | |
"source": [ | |
"# argpermutations fills the builder and hands it back; snapshot() materialises it.\n", | |
"builder = argpermutations(pos, neg, ak.ArrayBuilder())\n", | |
"ix = builder.snapshot()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"id": "09b72778-1f7c-4f72-b45d-72e333343dea", | |
"metadata": {}, | |
"source": [ | |
"`ix` has datashape `N * var * var * (int64, int64)`: the first `var` runs over the permutations (one per pairing configuration), and the second `var` runs over the `(positive index, negative index)` pairs within each configuration. Therefore, we need to add a dimension to our positive and negative arrays, and grow this new dimension to match the number of configurations.\n", | |
"\n", | |
"We add the new dimension at the position of the first `var` above, so that it broadcasts against the configuration dimension, and broadcast the result with an array of shape `N * var * 1`, where `var` is the number of configurations. We build that second array by counting the pairs in each configuration with `ak.num(ix, axis=2)` and appending a length-1 axis with `np.newaxis`. Unlike indexing out the first pair with `[..., [0]]`, this also works when a configuration is empty, which happens for events with no positive or no negative particles." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"id": "87a4a32c-a253-464d-9691-4f28d43be3f6", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# One entry per configuration with a trailing length-1 axis (shape N * var * 1).\n", | |
"# Counting pairs instead of indexing with [..., [0]] keeps this valid for\n", | |
"# empty configurations (events lacking positive or negative particles).\n", | |
"n_pairs = ak.num(ix, axis=2)[..., np.newaxis]\n", | |
"pos_, _ = ak.broadcast_arrays(pos[..., np.newaxis, :], n_pairs)\n", | |
"neg_, _ = ak.broadcast_arrays(neg[..., np.newaxis, :], n_pairs)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"id": "d2b31a05-ce4d-45db-8a5c-3c32c946a449", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Gather the positive (slot \"0\") and negative (slot \"1\") member of every pair.\n", | |
"pair = ak.zip(\n", | |
"    (\n", | |
"        pos_[ix[\"0\"]],\n", | |
"        neg_[ix[\"1\"]],\n", | |
"    )\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"id": "08d1037f-606e-451a-9d6f-e3de31cdbd66", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"[[[({'id': 1, 'charge': 1}, {'id': 0, 'charge': -1}),\n", | |
" ({'id': 2, 'charge': 1}, {'id': 3, 'charge': -1})],\n", | |
" [({'id': 1, 'charge': 1}, {'id': 0, 'charge': -1}),\n", | |
" ({'id': 2, 'charge': 1}, {'id': 4, 'charge': -1})],\n", | |
" [({'id': 1, 'charge': 1}, {'id': 3, 'charge': -1}),\n", | |
" ({'id': 2, 'charge': 1}, {'id': 0, 'charge': -1})],\n", | |
" [({'id': 1, 'charge': 1}, {'id': 3, 'charge': -1}),\n", | |
" ({'id': 2, 'charge': 1}, {'id': 4, 'charge': -1})],\n", | |
" [({'id': 1, 'charge': 1}, {'id': 4, 'charge': -1}),\n", | |
" ({'id': 2, 'charge': 1}, {'id': 0, 'charge': -1})],\n", | |
" [({'id': 1, 'charge': 1}, {'id': 4, 'charge': -1}),\n", | |
" ({'id': 2, 'charge': 1}, {'id': 3, 'charge': -1})]],\n", | |
" [[({'id': 1, 'charge': 1}, {'id': 0, 'charge': -1}),\n", | |
" ({'id': 2, 'charge': 1}, {'id': 3, 'charge': -1})],\n", | |
" [({'id': 1, 'charge': 1}, {'id': 3, 'charge': -1}),\n", | |
" ({'id': 2, 'charge': 1}, {'id': 0, 'charge': -1})]],\n", | |
" [[({'id': 1, 'charge': 1}, {'id': 0, 'charge': -1}),\n", | |
" ({'id': 2, 'charge': 1}, {'id': 4, 'charge': -1}),\n", | |
" ({'id': 3, 'charge': 1}, {'id': 5, 'charge': -1})],\n", | |
" [({'id': 1, 'charge': 1}, {'id': 0, 'charge': -1}),\n", | |
" ({'id': 2, 'charge': 1}, {'id': 5, 'charge': -1}),\n", | |
" ({'id': 3, 'charge': 1}, {'id': 4, 'charge': -1})],\n", | |
" [({'id': 1, 'charge': 1}, {'id': 4, 'charge': -1}),\n", | |
" ({'id': 2, 'charge': 1}, {'id': 0, 'charge': -1}),\n", | |
" ({'id': 3, 'charge': 1}, {'id': 5, 'charge': -1})],\n", | |
" [({'id': 1, 'charge': 1}, {'id': 4, 'charge': -1}),\n", | |
" ({'id': 2, 'charge': 1}, {'id': 5, 'charge': -1}),\n", | |
" ({'id': 3, 'charge': 1}, {'id': 0, 'charge': -1})],\n", | |
" [({'id': 1, 'charge': 1}, {'id': 5, 'charge': -1}),\n", | |
" ({'id': 2, 'charge': 1}, {'id': 0, 'charge': -1}),\n", | |
" ({'id': 3, 'charge': 1}, {'id': 4, 'charge': -1})],\n", | |
" [({'id': 1, 'charge': 1}, {'id': 5, 'charge': -1}),\n", | |
" ({'id': 2, 'charge': 1}, {'id': 4, 'charge': -1}),\n", | |
" ({'id': 3, 'charge': 1}, {'id': 0, 'charge': -1})]]]" | |
] | |
}, | |
"execution_count": 9, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"ak.to_list(pair)" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3 (ipykernel)", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.9.10" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment