# Source: GitHub Gist aristus/656c0172b9671ae53d5c by @aristus,
# created May 28, 2015 20:56.
#!/usr/bin/env python2.7
# Note: this script is for manual benchmarking!
# ./benchmark init-spatial
# time ./benchmark write-spatial
# ./benchmark read-spatial
from memsql.common import connection_pool
import random, math, sys
from time import time, sleep
from multiprocessing import Pool
# Benchmark configuration.

# Hosts the workers connect to; write workers pick one round-robin by index,
# read workers are spread evenly across the list.
AGGREGATORS = ['127.0.0.1']
WRITE_THREADS = 8  # worker processes for the write / generate-files commands
READ_THREADS = 16  # worker processes for the read commands
NUM_RECORDS = int(200 * (10**6))  # total rows split across the write workers
CHUNK_SIZE = 10000  # rows per INSERT when doing inserts "live" into memsql
DISK_CHUNK_SIZE = 100000  # rows per write when writing files to disk
LOAD_DATA_BATCH_SIZE = 1000000  # number of records per file on disk
# Floor division: this value is used as a list repeat count, so it must be an
# int on every Python version.
NUM_LOAD_DATA_FILES = NUM_RECORDS // LOAD_DATA_BATCH_SIZE
NUM_LOCATIONS_READ = 1  # values per IN (...) list in the non-spatial reads test
POLYGON_MIN = 9000  # 6500 ## for spatial reads test
# random.randrange() needs integer bounds; a bare `POLYGON_MIN * 1.2` is a
# float and only works while the product happens to be exactly integral.
POLYGON_MAX = int(round(POLYGON_MIN * 1.2))
POLY_POINTS = 12  # vertices per random polygon
READ_ITERATIONS = 1  # queries issued per measure() call by the read workers
def cross(x, y):
    """Return the cross product of ``x`` and ``y``.

    When ``x`` has two components the scalar ``x[0]*y[1] - y[0]*x[1]`` is
    returned (``y`` is asserted to have two components as well).  Otherwise
    the first three components of each input are combined into a 3-tuple.
    """
    if len(x) == 2:
        assert len(y) == 2
        return x[0] * y[1] - y[0] * x[1]

    ax, ay, az = x[0], x[1], x[2]
    bx, by, bz = y[0], y[1], y[2]
    return (
        ay * bz - by * az,
        az * bx - bz * ax,
        ax * by - bx * ay,
    )
def dot(x, y):
    """Return the dot product of ``x`` and ``y``.

    Components are paired positionally; if the inputs differ in length the
    extra components of the longer one are ignored.
    """
    left = list(x)
    right = list(y)
    total = 0
    for xi, yi in zip(left, right):
        total = total + xi * yi
    return total
def norm2(x):
    """Return the squared Euclidean length of ``x`` (the sum of its squared components)."""
    components = list(x)
    return sum([c * c for c in components])
def norm(x):
    """Return the Euclidean length of ``x`` as a float."""
    components = list(x)
    squared_length = sum([c * c for c in components])
    return math.sqrt(squared_length)
def normalize(x):
    """Return ``x`` scaled to unit Euclidean length, as a tuple.

    The length is computed once up front instead of once per component.

    Raises:
        ZeroDivisionError: if ``x`` has zero length.
    """
    components = list(x)
    length = math.sqrt(sum([c * c for c in components]))
    return tuple([c / length for c in components])
def random_vector():
    """Return a random unit 3-vector as an ``(x, y, z)`` tuple.

    Candidate points are drawn with each coordinate in [-1, 1) and rejected
    until the squared length falls within [0.1, 1.0]; the accepted point is
    then scaled to length 1.
    """
    while True:
        x = random.random() * 2 - 1
        y = random.random() * 2 - 1
        z = random.random() * 2 - 1
        squared = x * x + y * y + z * z
        if not (squared < 0.1 or squared > 1.0):
            break
    length = math.sqrt(squared)
    return (x / length, y / length, z / length)
def mul(v, a):
    """Return the 3-vector ``v`` with every component multiplied by the scalar ``a``."""
    vx, vy, vz = v
    scaled = (vx * a, vy * a, vz * a)
    return scaled
def lonlat(v, fmt="%.8f %.8f"):
    """Format the direction of the 3-vector ``v`` as longitude and latitude.

    Both angles are in degrees.  ``fmt`` is a %-style template that receives
    ``(longitude, latitude)`` in that order.
    """
    x, y, z = v
    degrees_per_radian = 180.0 / math.pi
    horizontal = math.sqrt(x * x + y * y)
    latitude = math.atan2(z, horizontal) * degrees_per_radian
    longitude = math.atan2(y, x) * degrees_per_radian
    return fmt % (longitude, latitude)
def ortho(v):
    """Return a vector perpendicular to the 3-vector ``v``.

    The component of ``v`` with the smallest magnitude is zeroed and the
    other two are swapped with one negated.  Ties resolve exactly as the
    comparisons below dictate.  The result is not normalised.
    """
    x, y, z = v
    mag_x, mag_y, mag_z = abs(x), abs(y), abs(z)

    if mag_x < mag_y:
        # x beats y; the answer depends on whether it also beats z.
        return (0, -z, y) if mag_x < mag_z else (-y, x, 0)
    # y is no larger than x; compare it against z.
    return (-z, 0, x) if mag_y < mag_z else (-y, x, 0)
def random_point(delim=',', quote="'"):
    """Return one random row for the ``terrain_points`` table as a string.

    The row has four fields joined by ``delim``: a ``point(lon lat)`` value
    wrapped in ``quote``, a random integer in [-10000, 100000], a random
    integer in [1, 100000000], and the current Unix time in whole seconds.
    """
    point_template = quote + "point(%.8f %.8f)" + quote
    location = lonlat(random_vector(), point_template)
    elevation = str(random.randint(-10000, 100000))
    ent_id = str(random.randint(1, 100000000))
    time_sec = str(int(time()))
    return delim.join((location, elevation, ent_id, time_sec))
def random_row(delim=','):
    """Return one random row for the ``terrain_points_int`` table as a string.

    The row has four fields joined by ``delim``: a random integer in
    [1, 10000000], a random integer in [-10000, 100000], a random integer in
    [1, 100000000], and the current Unix time in whole seconds.
    """
    location = str(random.randint(1, 10000000))
    elevation = str(random.randint(-10000, 100000))
    ent_id = str(random.randint(1, 100000000))
    time_sec = str(int(time()))
    return delim.join((location, elevation, ent_id, time_sec))
def combine(zero, i, j, a):
    """Return the unit vector along ``zero + i*sin(a) + j*cos(a)``.

    Args:
        zero: 3-vector used as the offset.
        i: 3-vector weighted by ``sin(a)``.
        j: 3-vector weighted by ``cos(a)``.
        a: angle in radians.

    Returns:
        The combined vector normalised to unit length, as a tuple.
    """
    s = math.sin(a)
    c = math.cos(a)
    combined = tuple([zero[k] + i[k] * s + j[k] * c for k in range(3)])
    return normalize(combined)
def random_polygon(size):
    """Return a random closed polygon as a ``POLYGON((...))`` string.

    A random centre direction is chosen and ``POLY_POINTS`` vertices are
    placed around it at sorted random angles, all at the same angular
    distance ``size / 6378137`` radians from the centre.  The first vertex
    is repeated at the end to close the ring.

    Args:
        size: polygon radius, divided by 6378137 to obtain an angle
            (presumably metres over the Earth's equatorial radius --
            TODO confirm units with the callers).

    Returns:
        A string of ``lon lat`` pairs in the form ``POLYGON((p0, p1, ..., p0))``.
    """
    angles = sorted((random.random() * 2 - 1) * math.pi
                    for _ in range(POLY_POINTS))
    # Float divisor: with an integer ``size`` Python 2 would otherwise
    # truncate the angle to 0 and collapse the polygon to a point.
    angular_radius = size / 6378137.0
    # 1 - 2*sin^2(t/2) == cos(t): offset of the vertex circle along the centre.
    height = 1 - 2 * (math.sin(0.5 * angular_radius)) ** 2
    radius = math.sqrt(1 - height ** 2)
    center = random_vector()
    # Two perpendicular directions, both perpendicular to the centre,
    # scaled to the circle's radius.
    o = normalize(ortho(center))
    i = mul(o, radius)
    j = mul(normalize(cross(o, center)), radius)
    zero = mul(center, height)
    points = [lonlat(combine(zero, i, j, a)) for a in angles]
    return "POLYGON((%s, %s))" % (', '.join(points), points[0])
def measure(db, name, query, args):
    """Run ``query`` once per entry of ``args`` and occasionally print timings.

    Each entry of ``args`` is substituted into ``query`` with the ``%``
    operator.  With a single entry the query is run and timed once.  With
    several, the first query is run once up front and then the whole list
    (including the first) is run and timed as a batch.

    About 5% of successful calls print a line with the label ``name``, the
    mean time per query in milliseconds, the mean row count per query and
    the number of queries.

    Args:
        db: connection whose ``query(sql)`` returns a sized collection of rows.
        name: label printed in front of the timing line.
        query: %-style SQL template.
        args: values to substitute into ``query``, one query per value.

    Returns:
        False if ``args`` is empty or the timed queries returned no rows,
        True otherwise.
    """
    queries = [query % arg for arg in args]
    if not queries:
        # Nothing to run; previously this raised IndexError on queries[0].
        return False

    started = time()
    rows = len(db.query(queries[0]))
    finished = time()

    if len(queries) > 1:
        rows = 0
        started = time()
        for sql in queries:
            rows += len(db.query(sql))
        finished = time()

    if rows == 0:
        return False

    # Sample the output so many concurrent workers do not flood stdout.
    if random.random() <= 0.05:
        count = len(queries)
        print("%s: %.2f ms, rows %.0f, count %d" % (
            name,
            float((finished - started) * 1000 / count),
            float(rows) / count,
            count))
        sys.stdout.flush()
    return True
def worker(a):
    """Insert random spatial rows into ``db_select_perf.terrain_points``.

    Rows are written in batches of ``CHUNK_SIZE`` with ``INSERT IGNORE``.
    Progress is printed every 100 chunks and once more at the end.  Printed
    totals and rates are estimates for the whole benchmark: this worker's
    own counts are multiplied by ``WRITE_THREADS``.

    Args:
        a: ``(index, rows)`` pair; ``index`` selects the aggregator
            round-robin and ``rows`` is the number of rows to insert (any
            remainder below ``CHUNK_SIZE`` is dropped).
    """
    sleep(random.random())  # stagger the workers' start
    worker_index, rows = a
    agg = AGGREGATORS[worker_index % len(AGGREGATORS)]
    db = pool.connect(agg, '3306', 'root', '', 'db_select_perf')

    num_chunks = rows // CHUNK_SIZE
    start = last_report = time()
    chunks_at_last_report = 0
    for i in range(num_chunks):
        values = "),(".join(random_point() for _ in range(CHUNK_SIZE))
        db.query("INSERT IGNORE INTO db_select_perf.terrain_points VALUES (%s)" % values)
        if (i % 100) == 0:
            now = time()
            chunks_done = i + 1
            # Rate over the chunks actually written since the last report
            # (the first report covers one chunk, not one hundred).
            interval_rows = (chunks_done - chunks_at_last_report) * CHUNK_SIZE * WRITE_THREADS
            print("%d total,  %d per sec" % (
                chunks_done * WRITE_THREADS * CHUNK_SIZE,
                int(interval_rows / (now - last_report))))
            chunks_at_last_report = chunks_done
            last_report = now

    if num_chunks:
        # Overall rate, measured up to now rather than the last report time.
        total_rows = num_chunks * WRITE_THREADS * CHUNK_SIZE
        print("%d total,  %d per sec" % (
            total_rows, int(total_rows / (time() - start))))
def worker_nonspatial(a):
    """Insert random integer-keyed rows into ``db_select_perf.terrain_points_int``.

    Rows are written in batches of ``CHUNK_SIZE``.  Progress is printed
    every 100 chunks.  Printed totals and rates are estimates for the whole
    benchmark: this worker's own counts are multiplied by ``WRITE_THREADS``.

    Args:
        a: ``(index, rows)`` pair; ``index`` selects the aggregator
            round-robin and ``rows`` is the number of rows to insert (any
            remainder below ``CHUNK_SIZE`` is dropped).
    """
    sleep(random.random())  # stagger the workers' start
    worker_index, rows = a
    agg = AGGREGATORS[worker_index % len(AGGREGATORS)]
    db = pool.connect(agg, '3306', 'root', '', 'db_select_perf')

    last_report = time()
    chunks_at_last_report = 0
    for i in range(rows // CHUNK_SIZE):
        values = "),(".join(random_row() for _ in range(CHUNK_SIZE))
        db.query("INSERT INTO db_select_perf.terrain_points_int VALUES (%s)" % values)
        if (i % 100) == 0:
            now = time()
            chunks_done = i + 1
            # Rate over the chunks actually written since the last report
            # (the first report covers one chunk, not one hundred).
            interval_rows = (chunks_done - chunks_at_last_report) * CHUNK_SIZE * WRITE_THREADS
            print("%d total,  %d per sec" % (
                chunks_done * WRITE_THREADS * CHUNK_SIZE,
                int(interval_rows / (now - last_report))))
            chunks_at_last_report = chunks_done
            last_report = now
def worker_print(a):
    """Write random tab-separated ``terrain_points_int`` rows to a file.

    The file is ``FILE_PATH/terrain_points_int-NNNN.tsv``, where ``NNNN`` is
    the zero-padded index from ``a``.  Rows are written ``DISK_CHUNK_SIZE``
    at a time, one row per line.

    Args:
        a: ``(index, num_rows)`` pair; any remainder of ``num_rows`` below
            ``DISK_CHUNK_SIZE`` is dropped.
    """
    thread_id, num_rows = a
    path = FILE_PATH + ('/terrain_points_int-%04d.tsv' % thread_id)
    # Plain 'w': the original 'wc' mode is not a portable mode string.
    # The with-block guarantees the file is flushed and closed.
    with open(path, 'w') as out:
        for _ in range(num_rows // DISK_CHUNK_SIZE):
            out.write('\n'.join(random_row('\t') for _ in range(DISK_CHUNK_SIZE)))
            out.write('\n')
def worker_print_spatial(a):
    """Write random tab-separated ``terrain_points`` rows to a file.

    The file is ``FILE_PATH/terrain_points-NNNN.tsv``, where ``NNNN`` is the
    zero-padded index from ``a``.  Rows are written ``DISK_CHUNK_SIZE`` at a
    time, one row per line, with the point value left unquoted.

    Args:
        a: ``(index, num_rows)`` pair; any remainder of ``num_rows`` below
            ``DISK_CHUNK_SIZE`` is dropped.
    """
    thread_id, num_rows = a
    path = FILE_PATH + ('/terrain_points-%04d.tsv' % thread_id)
    # Plain 'w': the original 'wc' mode is not a portable mode string.
    # The with-block guarantees the file is flushed and closed.
    with open(path, 'w') as out:
        for _ in range(num_rows // DISK_CHUNK_SIZE):
            out.write('\n'.join(random_point('\t', '') for _ in range(DISK_CHUNK_SIZE)))
            out.write('\n')
def select_worker(a):
    """Query ``terrain_points`` with random polygons, forever.

    Each iteration builds ``READ_ITERATIONS`` random polygons and passes
    them to ``measure()`` with a ``geography_intersects`` query.  The loop
    never returns; stop the benchmark by interrupting the process.

    Args:
        a: host of the aggregator to connect to.
    """
    sleep(random.random())  # stagger the workers' start
    db = pool.connect(a, '3306', 'root', '', 'db_select_perf')
    sql = "SELECT * FROM db_select_perf.terrain_points with (index=location, resolution=6) WHERE geography_intersects(location, '%s')"
    # randrange() needs integer bounds; POLYGON_MAX may be a float product.
    low = int(POLYGON_MIN)
    high = int(round(POLYGON_MAX))
    while True:
        size = random.randrange(low, high)
        # NOTE(review): scale factor kept exactly as written; its units are
        # not evident from this file.
        rsize = size * 1.0e-7 * math.pi / 3 * 6378137
        args = [random_polygon(rsize) for _ in range(READ_ITERATIONS)]
        measure(db, "", sql, args)
def select_worker_approx(a):
    """Query ``terrain_points`` with random polygons using the approximate predicate, forever.

    Each iteration builds ``READ_ITERATIONS`` random polygons and passes
    them to ``measure()`` with an ``approx_geography_intersects`` query.
    The loop never returns; stop the benchmark by interrupting the process.

    Args:
        a: host of the aggregator to connect to.
    """
    sleep(random.random())  # stagger the workers' start
    db = pool.connect(a, '3306', 'root', '', 'db_select_perf')
    sql = "SELECT * FROM db_select_perf.terrain_points with (index=location, resolution=6) WHERE approx_geography_intersects(location, '%s')"
    # randrange() needs integer bounds; POLYGON_MAX may be a float product.
    low = int(POLYGON_MIN)
    high = int(round(POLYGON_MAX))
    while True:
        size = random.randrange(low, high)
        # NOTE(review): scale factor kept exactly as written; its units are
        # not evident from this file.
        rsize = size * 1.0e-7 * math.pi / 3 * 6378137
        args = [random_polygon(rsize) for _ in range(READ_ITERATIONS)]
        measure(db, "", sql, args)
def select_worker2(a):
    """Query ``terrain_points_int`` by random location values, forever.

    Each iteration builds ``READ_ITERATIONS`` comma-separated lists of
    ``NUM_LOCATIONS_READ`` random integers in [1, 10000000] and passes them
    to ``measure()`` as the contents of an ``IN (...)`` clause.  The loop
    never returns.

    Args:
        a: host of the aggregator to connect to.
    """
    sleep(random.random())
    db = pool.connect(a, '3306', 'root', '', 'db_select_perf')
    sql = "SELECT * FROM db_select_perf.terrain_points_int WHERE location in (%s)"
    while True:
        args = []
        for _iteration in range(READ_ITERATIONS):
            locations = [str(random.randint(1, 10000000))
                         for _location in range(NUM_LOCATIONS_READ)]
            args.append(','.join(locations))
        measure(db, "", sql, args)
# Module-wide connection pool and setup connection; both are set by init().
pool = None
db = None


def init():
    """Create the connection pool, open the setup connection and seed ``random``.

    Sets the module globals ``pool`` and ``db``.  The connection goes to
    127.0.0.1:3306 as ``root`` with an empty password and no default
    database.

    Returns:
        The newly opened connection (also stored in ``db``).
    """
    global db, pool
    pool = connection_pool.ConnectionPool()
    # Variant with an extra options argument, kept from the original script:
    # db = pool.connect('127.0.0.1', '3306', 'root', '', '', {'ssl': '../certs/ca-cert.pem'})
    db = pool.connect('127.0.0.1', '3306', 'root', '', '')
    random.seed(time())
    return db
def _recreate_table(table, create_sql):
    """Connect, ensure ``db_select_perf`` exists, then drop and recreate ``table``."""
    conn = init()
    conn.query("create database if not exists db_select_perf")
    conn.query("use db_select_perf")
    conn.query("flush connection pools")
    conn.query("drop table if exists " + table)
    conn.query(create_sql)


def _write_jobs():
    """Return one ``(index, rows)`` job per write worker, splitting NUM_RECORDS evenly."""
    return enumerate([NUM_RECORDS // WRITE_THREADS] * WRITE_THREADS)


def _file_jobs():
    """Return one ``(index, rows)`` job per data file to generate."""
    return enumerate([LOAD_DATA_BATCH_SIZE] * int(NUM_LOAD_DATA_FILES))


def _read_targets():
    """Return the aggregator host for each read worker, spread evenly over AGGREGATORS."""
    return AGGREGATORS * (READ_THREADS // len(AGGREGATORS))


# Command-line entry point.  Commands: init-spatial, init, write-spatial,
# write, generate-files DIR, generate-files-spatial DIR, read, read-spatial,
# read-spatial-approx.
if __name__ == "__main__":
    # Running with no arguments used to fail with IndexError before the
    # "No command given." message could be printed.
    cmd = sys.argv[1] if len(sys.argv) > 1 else None
    if cmd == 'init-spatial':
        _recreate_table("terrain_points", "\n".join([
            "CREATE TABLE terrain_points (",
            "location geographypoint DEFAULT 'Point(0 0)',",
            "elevation int unsigned NOT NULL,",
            "ent_id int unsigned NOT NULL,",
            "time_sec int unsigned NOT NULL,",
            "shard key (location, ent_id, time_sec)",
            ");",
        ]))
    elif cmd == 'init':
        _recreate_table("terrain_points_int", "\n".join([
            "CREATE TABLE terrain_points_int (",
            "location bigint unsigned NOT NULL,",
            "elevation int unsigned NOT NULL,",
            "ent_id int unsigned NOT NULL,",
            "time_sec int unsigned NOT NULL,",
            "shard key (location, ent_id, time_sec)",
            ");",
        ]))
    elif cmd == 'write-spatial':
        db = init()
        Pool(processes=WRITE_THREADS).map(worker, _write_jobs())
    elif cmd == 'write':
        init()
        Pool(processes=WRITE_THREADS).map(worker_nonspatial, _write_jobs())
    ## dump data files to disk
    elif cmd in ('generate-files', 'generate-files-spatial'):
        if len(sys.argv) < 3:
            sys.exit("usage: %s %s OUTPUT_DIR" % (sys.argv[0], cmd))
        # Module global read by the file-writing workers; it must be set
        # before the Pool is created so the worker processes can see it.
        FILE_PATH = sys.argv[2]
        if cmd == 'generate-files':
            Pool(processes=WRITE_THREADS).map(worker_print, _file_jobs())
        else:
            Pool(processes=WRITE_THREADS).map(worker_print_spatial, _file_jobs())
    elif cmd == 'read':
        db = init()
        Pool(processes=READ_THREADS).map(select_worker2, _read_targets())
    elif cmd == 'read-spatial':
        db = init()
        Pool(processes=READ_THREADS).map(select_worker, _read_targets())
    elif cmd == 'read-spatial-approx':
        db = init()
        Pool(processes=READ_THREADS).map(select_worker_approx, _read_targets())
    elif cmd is None:
        print("No command given.")
    else:
        print("Unknown command: %s" % cmd)
# Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment