Code for the Dask cluster on demand blog post
import functools
import inspect
import logging
import time
from dataclasses import dataclass
from typing import Optional, Tuple

import requests
import yaml
from distributed import Client
from distributed.security import Security
from googleapiclient import discovery

LOGGER = logging.getLogger(__name__)
# Define some constants here, directly in the class,
# or import from another module
GCP_PROJECT_ID: str = "default project id"
GCP_PROJECT_NUMBER: str = "default project number"
GCP_CLUSTER_ZONE: str = "default cluster zone"
GCP_INSTANCE_NAME: str = "default instance name"
GCP_DOCKER_IMAGE: str = "default docker image"
DASK_CERT_FILEPATH: str = "path to default dask certificate you want to use"
DASK_KEY_FILEPATH: str = "path to default dask key you want to use"
MACHINE_TYPE: str = "e2-standard-16"

# Defaults for single node Dask cluster docker image
# see https://gist.github.com/ian-whitestone/d3b876e77743923b112d7d004d86480c
# or https://ianwhitestone.work/single-node-dask-cluster-on-gcp/ for more details
NUM_WORKERS: int = 16
THREADS_PER_WORKER: int = 1
MEMORY_PER_WORKER_GB: float = 4
@dataclass
class Cluster:
    gcp_project_id: str = GCP_PROJECT_ID
    gcp_project_number: str = GCP_PROJECT_NUMBER
    gcp_cluster_zone: str = GCP_CLUSTER_ZONE
    gcp_instance_name: str = GCP_INSTANCE_NAME
    gcp_docker_image: str = GCP_DOCKER_IMAGE

    # Only needed if you're using a Dask cluster with SSL security
    # See https://ianwhitestone.work/dask-cluster-security/ for more details
    dask_cert_filepath: str = DASK_CERT_FILEPATH
    dask_key_filepath: str = DASK_KEY_FILEPATH

    machine_type: str = MACHINE_TYPE
    num_workers: int = NUM_WORKERS
    threads_per_worker: int = THREADS_PER_WORKER
    memory_per_worker_gb: float = MEMORY_PER_WORKER_GB

    def __post_init__(self):
        self._validate_machine_type()
        self.compute = discovery.build("compute", "v1", cache_discovery=False)
        self.disk_image_name, self.disk_image_link = self._get_latest_image()
        self.create()
        self.cluster_host_ip_address = self._get_cluster_ip_address()
        self._wait_until_cluster_is_ready()
        self.client = self.create_client()
    def _validate_machine_type(self):
        gcp_machine_types = {
            # shared core
            "e2-micro": {"vCPU": 2, "memory_gb": 1},
            "e2-small": {"vCPU": 2, "memory_gb": 2},
            "e2-medium": {"vCPU": 2, "memory_gb": 4},
            # standard
            "e2-standard-2": {"vCPU": 2, "memory_gb": 8},
            "e2-standard-4": {"vCPU": 4, "memory_gb": 16},
            "e2-standard-8": {"vCPU": 8, "memory_gb": 32},
            "e2-standard-16": {"vCPU": 16, "memory_gb": 64},
            "e2-standard-32": {"vCPU": 32, "memory_gb": 128},
            # high memory
            "e2-highmem-2": {"vCPU": 2, "memory_gb": 16},
            "e2-highmem-4": {"vCPU": 4, "memory_gb": 32},
            "e2-highmem-8": {"vCPU": 8, "memory_gb": 64},
            "e2-highmem-16": {"vCPU": 16, "memory_gb": 128},
            # high compute
            "e2-highcpu-2": {"vCPU": 2, "memory_gb": 2},
            "e2-highcpu-4": {"vCPU": 4, "memory_gb": 4},
            "e2-highcpu-8": {"vCPU": 8, "memory_gb": 8},
            "e2-highcpu-16": {"vCPU": 16, "memory_gb": 16},
            "e2-highcpu-32": {"vCPU": 32, "memory_gb": 32},
        }

        # Example custom machine spec: e2-custom-32-49152
        if "custom" in self.machine_type:
            parts = self.machine_type.split("-")
            if len(parts) != 4:  # TODO: replace with regex validation
                raise ValueError(
                    "Custom machine type must be formatted like 'e2-custom-32-49152'"
                )
            num_cpus = int(parts[2])
            memory = int(parts[3])
            if memory % 256:
                raise ValueError("Memory must be a multiple of 256")
            if num_cpus < 2 or (num_cpus % 2):
                raise ValueError("# of CPUs must be at least 2 and a multiple of 2")
            return

        if self.machine_type not in gcp_machine_types:
            raise ValueError(
                f"'{self.machine_type}' is not a valid machine type. "
                f"Expecting one of {list(gcp_machine_types.keys())}"
            )

        num_cores_available = gcp_machine_types[self.machine_type]["vCPU"]
        if self.num_workers > num_cores_available:
            raise ValueError(
                f"{self.machine_type} has {num_cores_available} cores available and "
                f"you requested {self.num_workers}. Try specifying a machine_type with "
                "more vCPUs or reduce num_workers."
            )
    def _get_latest_image(self):
        """
        https://googleapis.github.io/google-api-python-client/docs/dyn/compute_v1.images.html#getFromFamily

        Returns the latest image that is part of an image family and is not deprecated.
        """
        image_response = (
            self.compute.images()
            .getFromFamily(project="cos-cloud", family="cos-stable")
            .execute()
        )
        return image_response["name"], image_response["selfLink"]

    def _get_cluster_ip_address(self):
        instances_list = (
            self.compute.instances()
            .list(
                project=self.gcp_project_id,
                zone=self.gcp_cluster_zone,
                filter=f"name = {self.gcp_instance_name}",
            )
            .execute()
        )
        if not instances_list.get("items"):
            raise Exception("Instance not found")
        if len(instances_list.get("items", [])) > 1:
            raise Exception("More than 1 instance returned with search criteria")
        return instances_list["items"][0]["networkInterfaces"][0]["accessConfigs"][0][
            "natIP"
        ]
    def _wait_until_cluster_is_ready(self):
        cluster_url = f"http://{self.cluster_host_ip_address}:8787/"
        LOGGER.info(f"Waiting until cluster {cluster_url} is ready")
        while True:
            try:
                r = requests.get(cluster_url)
                if r.ok and "dask" in r.text.lower():
                    LOGGER.info("Cluster is ready 🟢")
                    break
            except requests.exceptions.ConnectionError:
                # Dashboard isn't accepting connections yet
                pass
            time.sleep(30)
    def _wait_for_operation(self, operation_name: str):
        """Poll a GCP zone operation until it reports DONE."""
        while True:
            result = (
                self.compute.zoneOperations()
                .get(
                    project=self.gcp_project_id,
                    zone=self.gcp_cluster_zone,
                    operation=operation_name,
                )
                .execute()
            )
            if result["status"] == "DONE":
                if "error" in result:
                    raise Exception(result["error"])
                return
            time.sleep(1)
    @property
    def gce_container_spec(self):
        container_spec = {
            "spec": {
                "containers": [
                    {
                        "name": self.gcp_instance_name,
                        "image": self.gcp_docker_image,
                        "env": [
                            {
                                "name": "MEMORY_PER_WORKER",
                                "value": f"{self.memory_per_worker_gb}",
                            },
                            {
                                "name": "THREADS_PER_WORKER",
                                "value": f"{self.threads_per_worker}",
                            },
                            {"name": "NUM_WORKERS", "value": f"{self.num_workers}"},
                        ],
                        "stdin": False,
                        "tty": False,
                    }
                ],
                "restartPolicy": "Always",
            }
        }
        return yaml.dump(container_spec)
    @property
    def machine_type_full_name(self):
        return (
            f"projects/{self.gcp_project_id}/zones/"
            f"{self.gcp_cluster_zone}/machineTypes/{self.machine_type}"
        )

    @property
    def instance_config(self):
        return {
            "kind": "compute#instance",
            "name": self.gcp_instance_name,
            "zone": self.gcp_cluster_zone,
            "machineType": self.machine_type_full_name,
            "metadata": {
                "kind": "compute#metadata",
                "items": [
                    {
                        "key": "gce-container-declaration",
                        "value": self.gce_container_spec,
                    },
                    {"key": "google-logging-enabled", "value": "true"},
                ],
            },
            "tags": {"items": ["http-server"]},
            "disks": [
                {
                    "boot": True,
                    "autoDelete": True,
                    "initializeParams": {"sourceImage": self.disk_image_link},
                }
            ],
            # Specify a network interface with NAT to access the public
            # internet.
            "networkInterfaces": [
                {
                    "network": "global/networks/default",
                    "accessConfigs": [
                        {"type": "ONE_TO_ONE_NAT", "name": "External NAT"}
                    ],
                }
            ],
            "labels": {"container-vm": self.disk_image_name},
            "serviceAccounts": [
                {
                    "email": f"{self.gcp_project_number}-compute@developer.gserviceaccount.com",  # noqa
                    "scopes": [
                        "https://www.googleapis.com/auth/devstorage.read_only",
                        "https://www.googleapis.com/auth/logging.write",
                        "https://www.googleapis.com/auth/monitoring.write",
                        "https://www.googleapis.com/auth/servicecontrol",
                        "https://www.googleapis.com/auth/service.management.readonly",
                        "https://www.googleapis.com/auth/trace.append",
                    ],
                }
            ],
        }
    def create(self):
        LOGGER.info("Creating new compute instance")
        operation = (
            self.compute.instances()
            .insert(
                project=self.gcp_project_id,
                zone=self.gcp_cluster_zone,
                body=self.instance_config,
            )
            .execute()
        )
        self._wait_for_operation(operation["name"])

    def create_client(self):
        cluster_host_url = f"tls://{self.cluster_host_ip_address}:8786"
        LOGGER.info(f"Connecting new client to {cluster_host_url}")
        sec = Security(
            tls_ca_file=self.dask_cert_filepath,
            tls_client_cert=self.dask_cert_filepath,
            tls_client_key=self.dask_key_filepath,
            require_encryption=True,
        )
        return Client(cluster_host_url, security=sec)
    def teardown(self):
        LOGGER.info("Shutting down client & tearing down cluster")
        self.client.close()
        operation = (
            self.compute.instances()
            .delete(
                project=self.gcp_project_id,
                zone=self.gcp_cluster_zone,
                instance=self.gcp_instance_name,
            )
            .execute()
        )
        LOGGER.info("Waiting for teardown to finish...")
        self._wait_for_operation(operation["name"])
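
# A minimal sketch of manual Cluster usage, assuming the constants above point
# at a real GCP project and docker image; the computation is hypothetical:
#
#     cluster = Cluster(num_workers=8, machine_type="e2-standard-8")
#     try:
#         ...  # run dask computations against cluster.client
#     finally:
#         cluster.teardown()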
def inspect_requires_cluster_function(func) -> Tuple[int, Optional[int]]:
    """
    Validate that a function decorated with @requires_cluster
    has client specified as an arg or kwarg.

    Return the positions of the client and teardown_cluster arguments in func.
    """
    func_sig = inspect.signature(func)
    client_arg_position = None
    teardown_cluster_arg_position = None
    for param_pos, param in enumerate(func_sig.parameters.values()):
        if param.name == "client":
            client_arg_position = param_pos
        if param.name == "teardown_cluster":
            teardown_cluster_arg_position = param_pos
    if client_arg_position is None:
        raise ValueError(
            "Functions using @requires_cluster must accept client as an arg or kwarg. "
            "For example: \n"
            """
            @requires_cluster()
            def do_stuff(a, b, client=None):
                # do stuff with a, b and client
            """
        )
    return client_arg_position, teardown_cluster_arg_position
def requires_cluster(
    num_workers: int = NUM_WORKERS,
    threads_per_worker: int = THREADS_PER_WORKER,
    memory_per_worker_gb: float = MEMORY_PER_WORKER_GB,
    machine_type: str = MACHINE_TYPE,
    gcp_instance_name: str = GCP_INSTANCE_NAME,
    gcp_cluster_zone: str = GCP_CLUSTER_ZONE,
    teardown_cluster: bool = True,
):
    """
    A decorator to automatically provide a function with a ready-to-use
    Dask cluster client.
    """
    def decorator(func):
        (
            client_arg_position,
            teardown_cluster_arg_position,
        ) = inspect_requires_cluster_function(func)

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            cluster = None
            client_provided_in_args = False
            for arg in args:
                if isinstance(arg, Client):
                    client_provided_in_args = True
                    break

            # When a client is provided, just run the function as is
            if isinstance(kwargs.get("client"), Client) or client_provided_in_args:
                return func(*args, **kwargs)

            if kwargs.get("teardown_cluster") is not None:
                if not isinstance(kwargs["teardown_cluster"], bool):
                    raise ValueError("Value of teardown_cluster must be a boolean")
                _teardown_cluster = kwargs["teardown_cluster"]
            elif (
                teardown_cluster_arg_position is not None
                and len(args) > teardown_cluster_arg_position
            ):
                if not isinstance(args[teardown_cluster_arg_position], bool):
                    raise ValueError("Value of teardown_cluster must be a boolean")
                _teardown_cluster = args[teardown_cluster_arg_position]
            else:
                _teardown_cluster = teardown_cluster

            try:
                cluster = Cluster(
                    num_workers=num_workers,
                    threads_per_worker=threads_per_worker,
                    memory_per_worker_gb=memory_per_worker_gb,
                    machine_type=machine_type,
                    gcp_instance_name=gcp_instance_name,
                    gcp_cluster_zone=gcp_cluster_zone,
                )
                # Update the args/kwargs with the newly created client
                new_args = []
                for i, arg in enumerate(args):
                    if i == client_arg_position:
                        new_args.append(cluster.client)
                    else:
                        new_args.append(arg)
                if len(args) <= client_arg_position:
                    # client was not passed in as an arg, so update the kwargs
                    kwargs["client"] = cluster.client
                return func(*new_args, **kwargs)
            finally:
                if cluster is not None:
                    if _teardown_cluster:
                        cluster.teardown()
                    else:
                        LOGGER.info(
                            f"Leaving cluster running at {cluster.cluster_host_ip_address}"
                        )

        return wrapper

    return decorator
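
A minimal usage sketch of the decorator, assuming the constants above point at a real GCP project and Dask docker image; `count_rows` and the bucket path are hypothetical:

@requires_cluster(num_workers=8, machine_type="e2-standard-8")
def count_rows(path, client=None):
    # the injected client registers itself as dask's default scheduler
    import dask.dataframe as dd

    df = dd.read_parquet(path)
    return df.shape[0].compute()

# Spins up the VM, waits for the scheduler, runs the function, tears down
count_rows("gs://my-bucket/data/*.parquet")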