@charmoniumQ
Created June 16, 2020 21:15
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"import numpy as np\n",
"from sympy import init_session; init_session(quiet=True)\n",
"from sympy.stats import *\n",
"p1, p2, n = symbols('p_1 p_2 n', integer=True, positive=True)\n",
"t = symbols('t', real=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Data parallelism\n",
"\n",
"In data parallelism, the set of inferences is parallelized $p_1$ ways, and each inference is paralellized $p_2$ ways. The total number cores is $p = p_1 p_2$.\n",
"\n",
"| $p_1$ | $p_2$ | John's experiment name |\n",
"| ----- | ----- | ---------------------- |\n",
"| 12 | 1 | dataparallel-1 |\n",
"| 2 | 6 | dataparallel-1.5 |\n",
"| 1 | 12 | dataparallel-2 |\n",
"\n",
"I propose referring to John's data parallel experiments as \"dataparallel-($p_1$, $p_2$)\".\n",
"\n",
"Assume each inference can be divided into $p$ parts whose runtime is normally distributed with a mean of $\\mu$ and a standard deviation of $\\sigma$ (denoted $N(\\mu, \\sigma^2)$).\n",
"\n",
"**Per-inference latency** (what John calls \"end-to-end\"): There are $p$ pieces to each inference, handled by $p_2$ processing units. Each processor gets $\\frac{p}{p_2} = p_1$ units which each take $N(\\mu, \\sigma)$. Using the formula for the sum of normally distributed variables (see [[1][1]]), we find that each processing unit spends time $N \\left(p_1\\mu, p_1\\sigma^2 \\right)$.\n",
"\n",
"[1]: https://en.wikipedia.org/wiki/Sum_of_normally_distributed_random_variables"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"115 ms ± 3.62 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n",
"197 µs ± 45.8 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n"
]
}
],
"source": [
"mu, sigma = symbols('\\mu \\sigma', real=True, positive=True)\n",
"inference = Normal(\"X\", p1 * mu, p2 * sigma**2)\n",
"\n",
"# some sample parameters\n",
"# The difference depends a lot on n\n",
"params = {p1: 4, p2: 3, n: 96, mu: 1, sigma: 1}\n",
"\n",
"# It is much faster (1000x) to define the expected value by hand.\n",
"%timeit E(inference).subs(params)\n",
"%timeit (p1 * mu).subs(params)\n",
"\n",
"E_inference = p1 * mu\n",
"std_inference = sqrt(p2) * sigma"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Total latency:** There are $n$ samples to run inference on. The next sample cannot be processed until the slowest processing unit is done. The expected value of this time is the maximum of $n$ random variables each drawn from $N \\left(p_1\\mu, p_1\\sigma^2 \\right)$ (denoted $\\mathbb E \\left[ \\max_n \\left( N \\left(p_1\\mu, p_1\\sigma^2 \\right) \\right) \\right]$). There are a couple of ways of calculating this:\n",
"\n",
"- Numerically evaluating the following integral (slowest but most accurate), where $\\Phi(t; N(\\mu, \\sigma))$ is the CDF of $t$, given the distribution $N(\\mu, \\sigma)$. See [[2][2]] for a derivation.\n",
"\n",
" $$ \\mathbb E \\left[ \\max_n \\left( N \\left(p_1\\mu, p_1\\sigma^2 \\right) \\right) \\right] = \\int_{-\\infty}^{\\infty} t \\frac{d}{dt}\\Phi(t; N(p_1 \\mu, p_1 \\sigma^2))^n dt$$\n",
"\n",
"- Randomly sampling the distribution (10x faster but some accuracy loss).\n",
"\n",
"- Using an analytic approximation given in [[3][3]] (50000x faster, but accuracy loss and only works for normal distributions (which we might change later)):\n",
"\n",
" $$ \\mathbb E \\left[ \\max_n \\left( N \\left(p_1\\mu, p_1\\sigma^2 \\right) \\right) \\right] = \\int_{-\\infty}^{\\infty} t \\frac{d}{dt}\\Phi(t; N(p_1 \\mu, p_1 \\sigma^2))^n dt \\approx p_1 \\mu + p_1 \\sigma^2 \\sqrt(2 \\log n)$$\n",
"\n",
"\n",
"[2]: https://math.stackexchange.com/a/473237/24891\n",
"[3]: https://math.stackexchange.com/a/510580/24891"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"21.1 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)\n"
]
}
],
"source": [
"def E_max_rv_exact(n, RV):\n",
" \"\"\"Returns the expected value of the max of a sample of each RV in RVs\"\"\"\n",
" t = symbols('t', real=True)\n",
" pdf = diff(cdf(RV)(t)**n, t)\n",
" return Integral(t * pdf, (t, -oo, oo))\n",
"\n",
"%timeit -n 1 -r 1 E_max_rv_exact(n, inference).subs(params).evalf(2)\n",
"true = float(E_max_rv_exact(n, inference).subs(params).evalf(2))"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1.32 s ± 150 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n",
"0.03 % error\n"
]
}
],
"source": [
"def E_max_rv_estimate(n, RV):\n",
" m = 20\n",
" return sum(max(float(sample(RV)) for _ in range(n)) for _ in range(m)) / m\n",
"\n",
"results = []\n",
"%timeit results.append(E_max_rv_estimate(n.subs(params), inference.subs(params)))\n",
"print('{:.2f} % error'.format(np.sqrt(np.mean((np.array(results) - true)**2)) / true))"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"456 µs ± 43.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n",
"0.20 % error\n"
]
}
],
"source": [
"def E_max_rv_norm_approx():\n",
" return E_inference + std_inference * sqrt(2 * log(n))\n",
"\n",
"results = []\n",
"%timeit results.append(E_max_rv_norm_approx().subs(params).evalf(2))\n",
"print('{:.2f} % error'.format(np.sqrt(np.mean((np.array(results, dtype=float) - true)**2)) / true))"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# The winner was E_max_rv_norm_approx\n",
"E_total = E_max_rv_norm_approx()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Tuning the model\n",
"\n",
"We have empirical data for (1, 12) and (12, 1). I will show what my model produces for these values and tweak the parameters to match the empirical data."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAgoAAAAmCAYAAAClMLMxAAAACXBIWXMAAA7EAAAOxAGVKw4bAAASEUlEQVR4Ae2d7bXdtBKGT7JSwEmogNABhAqS20EgFSR0ACu/wj8WdBCogEAHQAUEOgAqANIB9328NY4sy7bkj22fndFaPpL1MTN6pRmNZG+fWy9evHh8dXX1g644fPbll19+G2cMpVXvpcp+UvzjUB3PdwQcgT4C0plr5f6i66HSb/o1jpHjOn6McXApHIFaBGptTND1Zwmfu7eCo/BEFT5JCidvA9F7c9pOEvcKjsA7gIB050N1E0f9I6UP5yy4jr8Dk9C7eNEILLExavuvwHn/9lyERACv44HiagdjLk9v5whcGgLSn9/Vp691fXe0vrmOH21EXB5HoB6BNWzMLEdBjNkFYdwe1ovtLRwBRyBGQPrEY75rxTwGPERwHT/EMLgQjsAqCCy1MXdmSsHu51sxP9xR6cz+eLMzIKD5whH7YRbDM3R5koUwuRUqcTL3l+5/PoheuY5Pjp5XOBoCbmP6I7KGjal2FMQUQ8+Jgp8m9MfkrDkaC8bht1Km0YQpbbJavSDr1Z4yrNaZDQgJlze6Xok0C/Suj/Mkh+v4BmN8E0lqLriNuYkDl5F5iY2pdhTEn0cO38A0I4tnnRcBxoKX4HjOffSArF8cXcid5QOjPzSe93X9uaMsruM7gn8w1m5jDjYgC8WZZWOq3lGQ8cK7vK/rq4XCevOFCISxuFJ8eCfhJsm6cFgWNRdOOAc/6/psEaEFjcNYuY4vwPBSmt4kvb1Jsu45P+bamCpHQR18rusoz1D3xPsIvBkLvMObEJD1kKcJUhwWxU3CTNq8x5H+jnkT+QaIuo4PAPMOZruNWWHQZ9qBIs4zaVfbmFpHgWeXMPGwIwJhcnA8ze7z0CGS9XAnH5Ltc4HHKdlWgTGCR01gTPf8BYTreM1oXWjdSG/dxiwY40uxMcXvKKjDZlAPP3EWjOuqTYOy2U76gYj/o+sL5S9dNKFpdDsyi3Z6yvD9Cvw6PCpvkKeVVbIwj3hhj508p1O7vLgnviyI7yn+RnEnKO9llHFP6afKy76To/xBvFVG/z7U9UxX6ZdO/1RdHkH8T9dZv3aKrKHfxToe2rB5yH4wSuWMs43/mjoQRF0eXUIflqPQo9CzMQceS7cxG9uYYkdB0wglv9Jk2fMlq95sPmpGUKqXijH4TVCaCf0bebqKjXFobjSuleBDV53n2LrHIGOwcUQa2oq55/pA19lDkKlz8qE8nCQWleJfa6wtuHiD4XPFH8W0Qz5yMW6NA6GYxZOfLSJzO/eVLsJb9XjxlzF/pSvrbMQyhDTj1+hbpmzLrCIdVz/AD2cPx5c2YNELqkf+6jrQYzQj4xL6MKPbRU0CNh0bc9SxDHK5jdnYxtwumjmnSix4S3fCFez2r6pJ+EjX3OfFOAXpYo6XzmLB4j03QDf3Mik0OT2IHRAMeru4zWW4oB39zckKSRaZvQIYxqcGJgeLH58kb08ZlGbOv87Ur8EbXvAsDX+oou3uS9usUa9Ix4UJP+X8RBfz+/sRxvR5Cx0YYVlWdNP6IHmX2KIyUN7WYtxSvT3qWLqNOY3bpjamxlF4JHn2XHTeTuPzpVhoueYE8OKnbml7FnKeQd+vJRpoYTA6R9K6fyxaLCyd423lc3LRnmjU8ltSX3zpN7uSjqxLaK7Y9lPJ1cEq0AbH3BzHWQD3ZiwVV+EdeMEznQtDXWpkUP1zOwtr6/jqOjAE2Ib5R+kDc6d0/syGI8zRno0RwaPg0PYtyOo2RogIC+zZZjamxlFgkuaMaDtwnugggEPA8+Y3ndy3N3OU/rma49mngV0bu7whXmn9c9wja7or2ZyvMBh1wFSedQaUb+ORO+n4Owj+IMRz8EZ3Pi0EwPRstC+FtGqqgYHxrmk3VHcLHRjitVX+JfShBpshG3NEHNzGdEd2Mxtzp8tn8o4j0c
kQjDXPZe/GlXWPV8rLXe3Rbly+JL0HzzF5Jc/QS3rNLlHl7FLboHsWMBZ7FHIoPFZ57n0DFjCcEmg/0cXCRr0fhugpn0UBp4PFyxZJJdvAJ7pZEKtDoI2sHAtWhUgum2vZfqgecwn5UA4W1J90NQudyvi35zmcVKV5SbCHseqDPeX3+JOE98K9LdzVeKs9PDndyZ1kBPJtZIt1Tpa20kYJw30xeeFZqwPge1a7MdXJS+iD9VF9mW1jZuCwmX2hP5IH+m5jbHBP8WY25naXT/5Og2IGMrfbyjXCgOfqsjANGfAcnZq8PXjWyMfkZiEHy84Cqnx+Qsfz8dxpgbIbxaDOy+am/8cWeo7heJmRF+jAA0cB49AJykOOv0ImxpzrjS4mGi/48eIe7TtBeTwyMV6dsuSG9zoG+5LUbW+DXLxM+LXS9MH6QZ/ofxNCPRwDfo1AGfLDj8WeRx092ZVngUV+aDGkrc11q08MXgTru8VFeJ+aNjxztEPx20h9YCwIxud0t+Ff8TTZcnq7GmfxyepAYHB4HUbOm9gHybzUxvTmwBAOIb/avgRs3cackDbdP4SNud0b/XyGCZ0v7eey2+vs2jR5oIGRwMBvEfbgWdsPXoD7UVi0JypBqdhlcnHaYotSSvuzuJ0VKs/GhrbpbvWV6n0X1bkK6V+UzyILTX7CxwKJ8/JI6d+5lO4E5SEXyl/iAEA3laVDb+DG8LEdtVVDNpwHw4Yjx/SxDvOtccBUrzP3jEiIwWtoMXxKHbVnLjUh8HwTbuFZhXdoRwRPW4yj7NGknWSMVlqp0Pq1ErlBMjbGrQ5ENW+CDiPujepDmMOzbUw0Pmmyh0PQj2r7AuEgp9uYA9qYO+nID9zfC/lmMAeqtdkY9K/au1PCjO+YEU+aVN3O5qkJyk7d5IuZNv1WeW6HyoI6dLQa02jSgQcLTaeN7ptFWTF44fWzCKZ12KEPnSaoqAnp4komu3Pasos23FnokSN9yTDXXlUbBbZTide6hV4OD6peie7YyUdTJ/dH7Rg/FtJf03KVgTXZT3T1nBgKKgJjmp3H4sGJxPsqN6eEhRp5uMAgxihOq6gJObytjPq1i3FtfeM1J27muhpmsZlDMG0jbJnDPR2I6jEHzm03IvbTya37EOivaotEcy0b0wI0gkO1fYGo6LmNObCNudOO/HiC3RBh0nBpwG2S28J0anl6PouRWN0QLeWp9tmFT/lMXn6jm9v9WL8mY7VnceVnd4O/QFAZixRe/zPF8IwXInbKPBLohdCO/DFc453s0EKPkY55trzEo3Eq4KXMR4qRcejEgNOEOY+XTMaxfiAjgQWH9xA4prT6lJWcdtB+MAR6nfmgPKPbzF/d09745mhZX+KyexNt4rp7pIt1fI5wwmxUB1R+drtR249z9EE8OnPPZFT+YlskGrNsjMlg8QQO1fYFuqLpNuYE8CFtzG0b/BVjFsOcQ4AhsMlwvSI/SO3Bs6gLUgAU/APF7SmB0jgCucXEFqTmCB0Gqkd7HleMLUw4ZWOYNg5AxDN14mDFbj2XT1kT1J5yaLXynUpOf1WOkWjGOM4vTJuTMtkP0XutC0eFxyrs/jkCJR5yXlTcBhbEMR5txSiBExL/j5MivKP2JOFpi3FSNHg79C7FYIMjFmhcSnTgsDoMppfQhzA35tqYpvkYDiozm5azI5P2JeDsNuY0UIeyMaWOwpswydgVTQUcgt/jSmECMYns/QSO19ug8lrD3bYNiWqeKYEt7tUvFpiPFacLK4azt2ioHoslE4Qdu2ECVulxrLI6gR22KWlcwCkEuwhTXONpcVNX5chJ+1TOpjz5g6HB0bEdYFxM+ylZ4/ptWvSYM8yz3qlLxAuHgABvXtTkoz/Ny4yKS5wE2oJxDqsr0eAt6n91GfbkkYZfjE0p3mrWBnTHnKE2cyIBHpMhlney8nAF41Wi48NUkhLJVqoD1Tq8Ur8Tifu3W/ahz23bHPVlro25KsDB7IrFTWdCu1L7Qh
u3MadT05ydSm16PGE2szF3Yi4jaRv41oDm6mpCUI5hSOuZkeWkgc63O6XQBuPMc+js8XqOl+WF9lU8re2WseSinyxs7ERZWOLA8f3Q4wyUBKP5THVQ6teKzYjHNNq0yjlxgA+76gZrxYzBp7qeWkXlmdMAfTvdMTlZdEf5QEd1+Nkk/YGPOSDkl5x8QGIsPFThL9DXheNgAUz4BYTxQ076St9M5n90b2lrl4uh+3GuQHlgYXPdqjCGPE5p5VG6CG8jEGLmdksjKevcij5yEFJZTrnRX9W91u1s/YlIGS/o1YT3QmWMVAf/0I9JHQh9qNLhFfuN+Lv0IeC2RzTHxpidGLVnGhd0dLZ9AQzRcBtzIBvDmBQ5Cho4FhiMwNSz5wcQVWAyYSB+5UZpDC3OAc/f/la6XSSVhjYL4rViro6xUf5UqOY5RXClck5PUC6O49MwuGCo/2BH+XNd4NI+skiJxPdqw1cYWTzNKcFwPwy04qrQo54tRhjJ7D/0iRslaXbvODLxuxTIy0I/GtSGBYG6zbjpnnnCqQBOpDmL3Ns8QE7uzUm4Iq1L2c3LmsRNUB54tf+nIWSn0ffKgGcvqD3OCHg8V3wdKoBVy9saKa8Ub2uC8cw+f7YKUWy8DYOoqJuUHEv1pyEY6MBvSsetvmFIvwic8IA/747Y6U6pDlTr8Br9Fo1d+9CgtsMf9XuOjSkdyzXsC6i4jTmIjWEwbr148eKx4ieaPKMLksp5o5tdW+9oGEIElbFI8bvPj5qMij9qgxxM4EnjGJNdwjOmk0sHmVgMW8cmV2/tvMAXI8bOdXRc1uZdQk8ysXjj+DUfZdI9i4V9+7+ExKI64ofRwiGwUxEWVhwjnBB2S6POgtohO/IOOmwqXy2ID3ixgJYuwugC439XbYr0QfVm6U/cSdGY1PG4/lpp8T273VhLdqOzpA9GYywO47uaLQr03MYMgC583MYIG+Hwr6L3bw/glMvmBbIHuYIojwWjt/uKyseSH0uoIqOYEFnCMyHVu0WeOTL1CNVkCAcWQBYxe2RT03zzupKPnSPjbKclLM5cmwfxbnawAaOGn9LNrjrkIcegMxsEpE7p7j40WRQxjjX4fKz6uReCx4SYqz8xzRIdj+uvlV6iw2v0e41+LOlDCf9VbVHQFbcxGeSFjduYBJcaR4HdRvN4IKHR3ArcayWanVOufCwvtP17rE6ubAnPHL00T/Q54bBj1LR403vx7fxb402ZzSPeLHySk10JixrOwzkCixk7K04PcgEHgB3qYFBbxhQazNdNQ+ABr5p5RN+KTztEG92r1p9Mx0d1PFN/cVaQ/ax2Y7HQCYElfUhIDd6Kx+q2SDTdxuQRdxuT4HInuR+7faVCDDDeFjveNNhpAyDXhmeatHOO95fwrJXR60cIBMOFc8CRd9GRetR8dlJ8OT3g0RbvETxRbAsk7xawYPLFyZJFlkc6ODlTpw+qsiigM7WnF+hYTZu5+pN2bErH0/pr3C/R4bX6vbQfS/qwlPfFtpce4xy5jZke4c1tTLGjoAHDQHPcjGHNOQo4CLzgxRFZVVCbOU4CPGbzrBLQKw8h0Bzza/zOdZrQyBHm2KLHMtDQxUu2n+uaO/+GcDE5+UolPIrxUV2cBELxI7y15BedKR0/Sbbu39k6vFa/V+jO7D6swPvSSbiNGRlh6cBZbEzxy4zIKqF4Js1b4J3/CjnSDy9yBByBCgSkWxjGR4qrXwiuYDNY1XV8EBovcAQuAoEaG6O61S8zAhJHk9dqbLueiwDOO+EIHAgBnPGvdpTHdXxH8J21I3AGBKptzO0aoeQg8FiBY9rmRbaatl7XEXAExhGQfnGMeKX4x/Ga25W6jm+HrVN2BPZGYK6NqXIU6KQY8Wx47K3zvbFw/o7ATUXguQR/urfwruN7j4DzdwQ2Q2CWjal2FIL4OAvfbdYVJ+wIvGMIBE+fn5nudpqQQO46ngDit47ATUZgiY2Z5SiIYfObcMX2wZ2bjJ/L7g
jsioD06L4EwNM/zFc4Xcd3nRLO3BFYFYGlNmaWoxB6wHf9+QY/Rs6DI+AIzEBA+sO3H/hcLN9/KP4Z5QxWc5q4js9Bzds4AgdCYA0bYz+P5MMzceC335NfklMdnATa8s+Hqr+fEDP0tCPwLiIgvcFJ4B8qTerbHvi4ju+BuvN0BNZDoMbGqC4fb0qfFNy99d9//y2SKBgSfjJZ8jW8Rby8sSNwSQhIZzhN4J+oFX9caY/+u47vgbrzdASWI7CWjfk//ex36x8DUukAAAAASUVORK5CYII=\n",
"text/latex": [
"$\\displaystyle \\left[ \\left( \\mu, \\ \\mu + 2 \\sqrt{6} \\sigma \\sqrt{\\log{\\left(96 \\right)}}\\right), \\ \\left( 12 \\mu, \\ 12 \\mu + \\sqrt{2} \\sigma \\sqrt{\\log{\\left(96 \\right)}}\\right)\\right]$"
],
"text/plain": [
"⎡⎛ _________⎞ ⎛ _________\n",
"⎣⎝\\mu, \\mu + 2⋅√6⋅\\sigma⋅╲╱ log(96) ⎠, ⎝12⋅\\mu, 12⋅\\mu + √2⋅\\sigma⋅╲╱ log(96) \n",
"\n",
"⎞⎤\n",
"⎠⎦"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"[\n",
" (\n",
" E_inference.subs({p1: 1, p2: 12, n: 96}),\n",
" E_total .subs({p1: 1, p2: 12, n: 96}),\n",
" ),\n",
" (\n",
" E_inference.subs({p1: 12, p2: 1, n: 96}),\n",
" E_total .subs({p1: 12, p2: 1, n: 96}),\n",
" ),\n",
"]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"First, let $\\mu$ be the average of dataparallel(12, 1) / 12 and dataparallel(1, 12)\n",
"\n",
"Then, we can similarly compute $\\sigma$."
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"dp2 = (15e3, 15e6)\n",
"dp1 = (16e4, 4e6)\n",
"mu_v = np.mean([dp2[0], dp1[0] / 12])\n",
"sigma_v = np.mean(np.array([\n",
" (dp2[1]-mu_v) / np.sqrt(24), (dp1[1]-12*mu_v) / np.sqrt(2)],\n",
" dtype=float) / np.sqrt(np.log(96)))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Produce a graph\n",
"\n",
"For the parameter sweep, I will go through all possible factors of 12."
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Text(0, 0.5, 'Total latency')"
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 900x600 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"factors = [1, 2, 3, 4, 6, 12]\n",
"n_v = 96\n",
"params_v = {n: n_v, mu: mu_v, sigma: sigma_v}\n",
"\n",
"xs = []\n",
"ys = []\n",
"for p1_v, p2_v in zip(factors, factors[::-1]):\n",
" xs.append(float(E_inference.subs({**params_v, p1: p1_v, p2: p2_v})))\n",
" ys.append(float(E_total.subs({**params_v, p1: p1_v, p2: p2_v})))\n",
"\n",
"%matplotlib inline\n",
"import matplotlib\n",
"matplotlib.rcParams['figure.dpi'] = 150\n",
"import matplotlib.pyplot as plt\n",
"\n",
"fig = plt.figure()\n",
"ax = fig.gca()\n",
"ax.set_title(\"Analytical model of NNs (128 threads/core)\")\n",
"ax.plot(xs, ys, label=\"Data parallel (analytical)\", marker=\"\")\n",
"ax.plot([75000], [6e6], label=\"Pipelined (empirical)\", marker=\"o\", linestyle=\"\")\n",
"ax.plot([dp1[0], dp2[0]], [dp1[1], dp2[1]], label=\"Data parallel (empirical)\", marker=\"o\", linestyle=\"\")\n",
"ax.legend()\n",
"ax.set_xlabel(\"Per-inference latency (AKA end-to-end latency)\")\n",
"ax.set_ylabel(\"Total latency\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.3"
}
},
"nbformat": 4,
"nbformat_minor": 4
}