Last active
December 17, 2018 21:12
-
-
Save tweakimp/ac33581b1cca9c1da213ea52e88b00f0 to your computer and use it in GitHub Desktop.
A* pathfinding in starcraft
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
A* pathfinding starcraft in python | |
How can I optimize create_path? | |
""" | |
import heapq | |
import random | |
import time | |
import sc2 | |
from sc2 import Race, maps, run_game | |
from sc2.ids.ability_id import LIFT | |
from sc2.ids.unit_typeid import COMMANDCENTER | |
from sc2.player import Bot | |
from sc2.position import Point2, Point3 | |
def terrain_to_z_height(h):
    """Rescale a raw terrain-height value onto -100..100 and round to an int.

    A value of 0 maps to -100 and 255 maps to 100, linearly in between.
    """
    scaled = 200 * h / 255
    return round(scaled - 100)
class PriorityQueue:
    """Min-priority queue built on ``heapq``.

    Entries are stored as ``(priority, sequence, item)``. The increasing
    sequence number breaks ties between equal priorities, so the items
    themselves are never compared (they need not be orderable). Items with
    equal priority come out in insertion (FIFO) order.
    """

    def __init__(self):
        self.elements = []
        # Monotonic tie-breaker for entries that share a priority.
        self._counter = 0

    def empty(self):
        """Return True when no items are queued."""
        return not self.elements

    def put(self, item, priority):
        """Queue ``item`` with ``priority``; lower priorities pop first."""
        heapq.heappush(self.elements, (priority, self._counter, item))
        self._counter += 1

    def get(self):
        """Remove and return the item with the lowest priority.

        Raises IndexError if the queue is empty.
        """
        return heapq.heappop(self.elements)[-1]
def heuristic(a, b):
    """Octile-distance estimate of the path cost between grid points a and b.

    Orthogonal steps cost 1 and diagonal steps cost 1.414 (matching ``cost``).
    On an open grid, the cheapest route therefore takes min(dx, dy) diagonal
    steps and max(dx, dy) - min(dx, dy) straight steps.

    This estimate never overestimates the true cost: it is admissible and
    consistent. That keeps A* results optimal. A plain Manhattan distance
    charges 2 for a diagonal move and would break this guarantee.
    """
    (x1, y1) = a
    (x2, y2) = b
    dx = abs(x1 - x2)
    dy = abs(y1 - y2)
    # Each diagonal step replaces one straight step at an extra 0.414.
    return max(dx, dy) + (1.414 - 1) * min(dx, dy)
def cost(a, b):
    """Return the cost of stepping between grid points a and b.

    The step costs 1 when the points share an x or a y coordinate (a
    straight move). Otherwise it costs 1.414, roughly sqrt(2), for a
    diagonal move.
    """
    delta_x, delta_y = a.x - b.x, a.y - b.y
    return 1.414 if delta_x and delta_y else 1
class astar(sc2.BotAI):
    """Bot that runs A* across the map and draws the resulting path.

    Two instances play each other:

    - The ``dummy`` instance only lifts its command center(s) so they stop
      blocking paths.
    - The non-dummy instance also lifts its command center(s). Once none
      are grounded (and at least 30 steps have passed), it collects the
      pathable cells, builds an adjacency map, runs A* between two fixed
      points and draws the path with debug boxes.
    """

    # Offsets of the eight cells surrounding a grid cell (straight + diagonal).
    NEIGHBOR_OFFSETS = ((-1, -1), (-1, 0), (-1, 1), (0, -1), (0, 1), (1, -1), (1, 0), (1, 1))

    def __init__(self, dummy=True):
        super().__init__()
        self.actions = []
        self.iter = None
        # Set of (point, neighbour) pairs between adjacent pathable cells.
        self.directions = set()
        # Adjacency map: pathable point -> list of adjacent pathable points.
        self.neighbors = {}
        self.dummy = dummy
        self.random_points = []
        self.path = []
        # None until the pathable cells have been collected.
        self.all_pathable_points = None

    async def on_step(self, iteration):
        """Lift command centers, then (non-dummy only) compute and draw a path."""
        for base in self.units(COMMANDCENTER):
            self.actions.append(base(LIFT))
        await self.do_actions(self.actions)
        self.actions = []
        # The dummy client exists only to lift its command center and clear paths.
        if self.dummy:
            return
        self.iter = iteration
        # No grounded command center remains (they are lifted).
        if not self.units(COMMANDCENTER) and self.iter >= 30:
            if self.all_pathable_points is None:
                await self._load_pathable_points()
                start_time = time.time()
                self.create_directions()
                print("direction calculation time", time.time() - start_time)
            start_time = time.time()
            # Fixed, far-apart endpoints so profiling runs are comparable.
            # For random endpoints: random.sample(list(self.all_pathable_points), 2)
            self.random_points = [Point2((123, 19)), Point2((21, 143))]
            self.create_path(self.random_points[0], self.random_points[1])
            await self.draw()
            path_time = time.time() - start_time
            if self.path:
                print("time per node", round(path_time / len(self.path), 10))
            await self._client.send_debug()
        # Leave at the same iteration every run so each profile covers the same work.
        if self.iter == 40:
            await self._client.leave()

    async def _load_pathable_points(self):
        """Re-fetch game info and store every pathable grid cell as a Point2."""
        # Queried fresh (rather than reusing the cached copy) so that lifted
        # command centers presumably no longer block cells — verify.
        self._game_info = await self._client.get_game_info()
        pathing_grid = self._game_info.pathing_grid
        # NOTE(review): a grid value of 0 is treated as pathable; this depends
        # on the python-sc2 version's pathing_grid encoding — confirm.
        self.all_pathable_points = {
            Point2((x, y))
            for x in range(pathing_grid.width)
            for y in range(pathing_grid.height)
            if pathing_grid[Point2((x, y))] == 0
        }

    def _adjacent_pathable(self, point):
        """Return the pathable cells among the eight cells surrounding ``point``."""
        pathable = self.all_pathable_points
        x, y = point.x, point.y
        return [
            neighbor
            for neighbor in (Point2((x + dx, y + dy)) for dx, dy in self.NEIGHBOR_OFFSETS)
            if neighbor in pathable
        ]

    def create_directions(self):
        """Precompute adjacency between pathable cells.

        Fills two attributes:

        - ``self.neighbors``: pathable point -> list of its pathable
          neighbours. create_path reuses it instead of rebuilding the
          neighbours of every expanded node.
        - ``self.directions``: the same adjacency as a set of
          ``(point, neighbour)`` pairs.
        """
        self.neighbors = {point: self._adjacent_pathable(point) for point in self.all_pathable_points}
        self.directions = {
            (point, neighbor)
            for point, adjacent in self.neighbors.items()
            for neighbor in adjacent
        }

    def create_path(self, start, end):
        """Find a path from ``start`` to ``end`` with A* and store it in ``self.path``.

        ``self.path`` becomes the list of points from ``start`` to ``end``
        inclusive, or ``[]`` when ``end`` cannot be reached.

        Neighbours come from ``self.neighbors`` when create_directions has
        filled it. Otherwise they are computed from
        ``self.all_pathable_points``.
        """
        frontier = PriorityQueue()
        frontier.put(start, 0)
        came_from = {start: None}
        cost_so_far = {start: 0}
        # Expanded nodes. Each node is expanded at most once, which yields
        # optimal paths as long as heuristic() is consistent.
        closed = set()
        cached_neighbors = self.neighbors
        while not frontier.empty():
            current = frontier.get()
            if current == end:
                break
            # The queue has no decrease-key, so a node may be queued several
            # times. Only its first (cheapest) pop is expanded.
            if current in closed:
                continue
            closed.add(current)
            neighbors = cached_neighbors.get(current)
            if neighbors is None:
                neighbors = self._adjacent_pathable(current)
            current_cost = cost_so_far[current]
            for next_node in neighbors:
                if next_node in closed:
                    continue
                new_cost = current_cost + cost(current, next_node)
                if next_node not in cost_so_far or new_cost < cost_so_far[next_node]:
                    cost_so_far[next_node] = new_cost
                    came_from[next_node] = current
                    frontier.put(next_node, new_cost + heuristic(end, next_node))

        if end not in came_from:
            # Frontier exhausted without ever reaching end: no path exists.
            self.path = []
            return

        # Walk the parent links back to start (whose parent is None).
        path = []
        node = end
        while node is not None:
            path.append(node)
            node = came_from[node]
        path.reverse()
        self.path = path

    async def draw(self):
        """Queue a small white debug box at each point of ``self.path``.

        on_step displays the queued boxes by awaiting ``self._client.send_debug()``.
        """
        for point in self.path:
            height = terrain_to_z_height(self.get_terrain_height(point))
            # Box spans 0.2 x 0.2 around the point, raised 0.1 above the computed height.
            lower = Point3((point.x - 0.1, point.y - 0.1, height + 0.1))
            upper = Point3((point.x + 0.1, point.y + 0.1, height + 0.1))
            self._client.debug_box_out(lower, upper, Point3((255, 255, 255)))
# Launch the match. The non-dummy bot runs the pathfinding demo; the dummy
# opponent only lifts its command center.
run_game(
    maps.get("KairosJunctionLE"),
    [
        Bot(Race.Terran, astar(dummy=False)),
        Bot(Race.Terran, astar(dummy=True)),
    ],
    realtime=False,  # step the game without real-time pacing
)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment