|
import random |
|
import time |
|
import statistics |
|
import math |
|
import cffi |
|
import profile |
|
# Module-level cffi FFI instance.
# NOTE(review): `ffi` is not referenced anywhere else in the visible file;
# confirm it is needed before keeping the third-party cffi dependency.
ffi = cffi.FFI()

# Fixed seed so every random value drawn later in this script (unit positions
# and ids) is identical across runs, keeping benchmark inputs comparable.
random.seed(0)
|
|
|
class Unit:
    """A point in the 2-D world identified by ``id``.

    Attributes:
        x: horizontal position.
        y: vertical position.
        id: identifier; the nearest-neighbour searches use it to recognise
            (and skip) the query unit itself.
    """

    def __init__(self, x, y, id):
        self.x = x
        self.y = y
        # `id` shadows the builtin only inside this method; the parameter name
        # is kept because callers pass it by keyword.
        self.id = id

    def __repr__(self):
        return '%s(x=%r, y=%r, id=%r)' % (
            type(self).__name__, self.x, self.y, self.id)

    def dist(self, other):
        """Return the squared Euclidean distance between ``self`` and ``other``.

        The square root is omitted: squared distance orders points exactly
        like true distance, which is all a nearest-neighbour comparison needs.
        """
        dx = self.x - other.x
        dy = self.y - other.y
        return dx * dx + dy * dy
|
|
|
def get_closest_id(unit, units):
    """Return the id of the unit in ``units`` nearest to ``unit`` (linear scan).

    Units with the same ``id`` as ``unit`` are skipped, so ``unit`` may be a
    member of ``units``. Distance is whatever ``unit.dist(other)`` returns;
    smaller means closer. On ties the earliest unit in ``units`` wins.

    Args:
        unit: the query unit (needs ``id`` and ``dist``).
        units: iterable of candidate units (each needs ``id``).

    Returns:
        The id of the nearest other unit, or -1 if there is none.
    """
    closest_id = -1
    # No artificial search radius: any other unit beats "nothing found".
    closest_dist = math.inf
    for other in units:
        if other.id == unit.id:
            continue
        d = unit.dist(other)
        if d < closest_dist:
            closest_dist = d
            closest_id = other.id
    return closest_id
|
|
|
def spiral(X, Y):
    """Yield (x, y) offsets of an X-by-Y window in outward square-spiral order.

    The walk starts at (0, 0), steps to (1, 0), and turns counter-clockwise
    (with y pointing up) at each corner of the growing square. It runs for
    ``max(X, Y) ** 2`` steps. Only positions with ``-X/2 < x <= X/2`` and
    ``-Y/2 < y <= Y/2`` are yielded, in the order the walk visits them.
    """
    px, py = 0, 0
    step_x, step_y = 0, -1
    half_w = X / 2
    half_h = Y / 2
    for _ in range(max(X, Y) ** 2):
        inside = -half_w < px <= half_w and -half_h < py <= half_h
        if inside:
            yield (px, py)
        # Corners of the spiral: the diagonal, the anti-diagonal on the left,
        # and the shifted anti-diagonal on the right that widens each ring.
        at_corner = (px == py
                     or (px < 0 and px == -py)
                     or (px > 0 and px == 1 - py))
        if at_corner:
            step_x, step_y = -step_y, step_x
        px += step_x
        py += step_y
|
|
|
# Precomputed offsets of a 5x5 window (x and y each in -2..2), in the
# outward-spiral order produced by spiral().
SPIRAL = list(spiral(5, 5))
|
|
|
class SpatialHash:
    """Uniform-grid spatial index for nearest-neighbour queries on 2-D units.

    The world ``[0, world_width) x [0, world_height)`` is divided into square
    buckets of side ``bucket_size``; each unit is stored in the bucket that
    contains its position. A query searches buckets in square rings of
    growing radius around the query's bucket. It stops once no unsearched
    bucket can hold a strictly closer unit, so the answer is exact and not
    limited to a fixed window.

    Distances are squared Euclidean distances computed here, not via
    ``unit.dist``, because the bucket pruning below is only valid for that
    metric. Units need ``x``, ``y`` and ``id`` attributes.
    """

    def __init__(self, world_width, world_height, bucket_size=4):
        """Create an empty index.

        Args:
            world_width: extent of the world along x.
            world_height: extent of the world along y.
            bucket_size: side length of one bucket; must be positive.

        Raises:
            ValueError: if ``bucket_size`` is not positive.
        """
        if bucket_size <= 0:
            raise ValueError('bucket_size must be positive')
        self.world_width = world_width
        self.world_height = world_height
        self.bucket_size = bucket_size

        # The grid is measured in buckets, not world cells: ceil(extent / size)
        # buckets per axis cover the whole world.
        self.bucket_cols = math.ceil(world_width / bucket_size)
        self.bucket_rows = math.ceil(world_height / bucket_size)

        # buckets[bx][by] lists the units whose position falls in that bucket.
        self.buckets = [[[] for _ in range(self.bucket_rows)]
                        for _ in range(self.bucket_cols)]

    def _bucket_of(self, unit):
        """Return the (bx, by) bucket indices containing ``unit``'s position."""
        return (int(unit.x // self.bucket_size),
                int(unit.y // self.bucket_size))

    def add(self, unit):
        """Insert ``unit`` into the bucket covering its position.

        Raises:
            ValueError: if the unit lies outside the world rectangle. Without
                this check a negative coordinate would index from the end of
                the bucket lists and land in the wrong bucket.
        """
        if not (0 <= unit.x < self.world_width
                and 0 <= unit.y < self.world_height):
            raise ValueError(
                'unit at (%r, %r) is outside the %r x %r world'
                % (unit.x, unit.y, self.world_width, self.world_height))
        bx, by = self._bucket_of(unit)
        self.buckets[bx][by].append(unit)

    @staticmethod
    def _ring(cx, cy, radius):
        """Yield bucket coordinates at Chebyshev distance ``radius`` from (cx, cy)."""
        if radius == 0:
            yield cx, cy
            return
        for dx in range(-radius, radius + 1):
            yield cx + dx, cy - radius
            yield cx + dx, cy + radius
        for dy in range(-radius + 1, radius):
            yield cx - radius, cy + dy
            yield cx + radius, cy + dy

    def _min_sq_dist_to_bucket(self, unit, bx, by):
        """Return a lower bound on the squared distance from ``unit`` to bucket (bx, by)."""
        size = self.bucket_size
        left = bx * size
        bottom = by * size
        dx = max(left - unit.x, 0, unit.x - (left + size))
        dy = max(bottom - unit.y, 0, unit.y - (bottom + size))
        return dx * dx + dy * dy

    def get_closest_id(self, unit):
        """Return the id of the stored unit nearest to ``unit``.

        Stored units with the same ``id`` as ``unit`` are skipped, so ``unit``
        may itself be in the index.

        Returns:
            The id of the nearest other stored unit, or -1 if there is none.
        """
        size = self.bucket_size
        cx, cy = self._bucket_of(unit)
        # The largest ring radius that can still reach some bucket in the grid.
        max_radius = max(abs(cx), abs(self.bucket_cols - 1 - cx),
                         abs(cy), abs(self.bucket_rows - 1 - cy))

        closest_id = -1
        closest_dist = math.inf
        for radius in range(max_radius + 1):
            # Every bucket in ring `radius` is more than (radius - 1) * size
            # away along at least one axis. Once that bound reaches the best
            # distance found, no remaining ring can hold a strictly closer unit.
            reach = max(radius - 1, 0) * size
            if reach * reach >= closest_dist:
                break
            for bx, by in self._ring(cx, cy, radius):
                if not (0 <= bx < self.bucket_cols
                        and 0 <= by < self.bucket_rows):
                    continue
                if self._min_sq_dist_to_bucket(unit, bx, by) >= closest_dist:
                    continue
                for other in self.buckets[bx][by]:
                    if other.id == unit.id:
                        continue
                    dx = other.x - unit.x
                    dy = other.y - unit.y
                    d = dx * dx + dy * dy
                    if d < closest_dist:
                        closest_dist = d
                        closest_id = other.id
        return closest_id
|
|
|
# World extent in grid cells; units get integer coordinates in
# [0, WORLD_WIDTH) x [0, WORLD_HEIGHT).
WORLD_WIDTH = 50
WORLD_HEIGHT = 60
# Number of units generated for the benchmark.
COUNT = 1000

# Ids are drawn without replacement. The nearest-neighbour searches compare
# `id` to recognise the query unit itself, so a duplicated id would make two
# distinct units invisible to each other.
UNITS = [
    Unit(
        x=random.randrange(WORLD_WIDTH),
        y=random.randrange(WORLD_HEIGHT),
        id=unit_id,
    )
    for unit_id in random.sample(range(100000000), COUNT)
]
|
|
|
def _benchmark(label, find_closest, runs=10, warmup=6):
    """Time ``find_closest(unit)`` over every unit in UNITS and print stats.

    The full sweep runs ``runs`` times, printing '.' after each. The first
    ``warmup`` sweeps are discarded. The mean and sample standard deviation
    of the rest are printed; statistics.stdev needs at least two of them.
    """
    times = []
    for i in range(runs):
        # perf_counter is monotonic and high-resolution, unlike time.time.
        start = time.perf_counter()
        for unit in UNITS:
            find_closest(unit)
        end = time.perf_counter()
        print('.')
        if i >= warmup:
            times.append(end - start)
    print(label, 'time mean:', statistics.mean(times),
          'time std:', statistics.stdev(times))


_benchmark('slow', lambda unit: get_closest_id(unit, UNITS))

# The hash must be populated before it is queried; otherwise every lookup
# searches empty buckets and the timing measures nothing useful.
h = SpatialHash(WORLD_WIDTH, WORLD_HEIGHT)
for unit in UNITS:
    h.add(unit)
_benchmark('hash', h.get_closest_id)