Skip to content

Instantly share code, notes, and snippets.

@terasakisatoshi
Last active January 25, 2018 14:56
Show Gist options
  • Save terasakisatoshi/4acf5752da26d8386d3257027793c263 to your computer and use it in GitHub Desktop.
Save terasakisatoshi/4acf5752da26d8386d3257027793c263 to your computer and use it in GitHub Desktop.
binary GCD (aka Stein's) algorithm (Python implementation)
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Stein's algorithm implementation"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Julia"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# binary GCD (aka Stein's) algorithm\n",
"# about 1.7x (2.1x) faster for random Int64s (Int128s)\n",
"function gcd(a::T, b::T) where T<:Union{Int64,UInt64,Int128,UInt128}\n",
" a == 0 && return abs(b)\n",
" b == 0 && return abs(a)\n",
" za = trailing_zeros(a)\n",
" zb = trailing_zeros(b)\n",
" k = min(za, zb)\n",
" u = unsigned(abs(a >> za))\n",
" v = unsigned(abs(b >> zb))\n",
" while u != v\n",
" if u > v\n",
" u, v = v, u\n",
" end\n",
" v -= u\n",
" v >>= trailing_zeros(v)\n",
" end\n",
" r = u << k\n",
" # T(r) would throw InexactError; we want OverflowError instead\n",
" r > typemax(T) && throw(OverflowError())\n",
" r % T\n",
"end"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Cython"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"%load_ext Cython"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"%%cython\n",
"from libc.math cimport sqrt\n",
"\n",
"cdef int trailing_zeros(int n):\n",
" cdef int order = 0\n",
" while (n&1)==0:\n",
" order += 1\n",
" n //= 2\n",
" return order\n",
"\n",
"cdef int gcd(int a, int b):\n",
" if a == 0:\n",
" return abs(b)\n",
" if b == 0:\n",
" return abs(a)\n",
" cdef int za = trailing_zeros(a)\n",
" cdef int zb = trailing_zeros(b)\n",
" cdef unsigned int k = min(za, zb)\n",
" cdef unsigned int u,v\n",
" u = abs(a >> za)\n",
" v = abs(b >> zb)\n",
"\n",
" while u != v:\n",
" if u > v:\n",
" u, v = v, u\n",
" v -= u\n",
" v >> trailing_zeros(v)\n",
"\n",
" r = u << k\n",
" return r\n",
"\n",
"cpdef double calc_pi_by_gcd(int N):\n",
" cdef unsigned int s = 0\n",
" for a in range(1, N+1):\n",
" for b in range(1, N+1):\n",
" if gcd(a, b) == 1:\n",
" s += 1\n",
" cdef double pi = sqrt(6.0*N**2/s)\n",
" return pi"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Wall time: 12.9 s\n"
]
},
{
"data": {
"text/plain": [
"3.141534239016629"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%time calc_pi_by_gcd(10000)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Numba"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import math\n",
"from numba import jit\n",
"import time\n",
"\n",
"@jit('i8(i8)',nopython=True)\n",
"def trailing_zeros(n):\n",
" order = 0\n",
" while (n & 1) == 0:\n",
" order += 1\n",
" n //= 2\n",
" return order\n",
"\n",
"@jit('f8(i8,i8)',nopython=True)\n",
"def gcd(a, b):\n",
" if a == 0:\n",
" return abs(b)\n",
" if b == 0:\n",
" return abs(a)\n",
" za = trailing_zeros(a)\n",
" zb = trailing_zeros(b)\n",
" k = min(za, zb)\n",
"\n",
" u = abs(a >> za)\n",
" v = abs(b >> zb)\n",
"\n",
" while u != v:\n",
" if u > v:\n",
" u, v = v, u\n",
" v -= u\n",
" v >> trailing_zeros(v)\n",
"\n",
" r = u << k\n",
" return r\n",
"\n",
"\n",
"@jit('f8(i8)',nopython=True)\n",
"def calc_pi_by_gcd(N):\n",
" s = 0\n",
" for a in range(1, N+1):\n",
" for b in range(1, N+1):\n",
" if gcd(a, b) == 1:\n",
" s += 1\n",
" pi = math.sqrt(6*N**2/s)\n",
" return pi"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Wall time: 26.8 s\n"
]
},
{
"data": {
"text/plain": [
"3.141534239016629"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%time calc_pi_by_gcd(10000)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Numba (with Naive gcd)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import math\n",
"from numba import jit\n",
"import numpy as np\n",
"\n",
"\n",
"@jit('i8(i8,i8)', nopython=True)\n",
"def gcd(a, b):\n",
" \"\"\"\n",
" Reference:\n",
" https://qiita.com/wrist/items/889696a507fbc8b392e4\n",
" http://swdrsker.hatenablog.com/entry/2017/12/20/235201\n",
" https://docs.scipy.org/doc/numpy-1.13.0/reference/generated/numpy.remainder.html\n",
" \"\"\"\n",
" while b != 0:\n",
" #a, b = b, a % b\n",
" a, b = b, np.remainder(a, b)\n",
" return a\n",
"\n",
"\n",
"@jit('f8(i8)', nopython=True)\n",
"def calc_pi_by_gcd(N):\n",
" s = 0\n",
" for a in range(1, N+1):\n",
" for b in range(1, N+1):\n",
" if gcd(a, b) == 1:\n",
" s += 1\n",
" pi = math.sqrt(6*N**2/s)\n",
" return pi"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Wall time: 10 s\n"
]
},
{
"data": {
"text/plain": [
"3.141534239016629"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%time calc_pi_by_gcd(10000)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment