Created
August 27, 2021 11:56
-
-
Save pablosjv/aaf49bd6246e8651ef2d7559244d628f to your computer and use it in GitHub Desktop.
Large-Scale PyTorch Inference Pipeline: Spark vs Dask - Code Examples
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from enum import Enum

from dask.distributed import Client, LocalCluster, SpecCluster
from dask_yarn import YarnCluster
class ClusterType(Enum):
    """Kinds of Dask cluster that ``get_cluster`` can build."""

    # Selects the branch that calls ``YarnCluster.from_current()``.
    YARN = 'yarn'
    # Selects the branch that calls ``LocalCluster()``.
    LOCAL = 'local'
def get_cluster(cluster_type: ClusterType = ClusterType.YARN) -> SpecCluster:
    """Create the Dask cluster that corresponds to ``cluster_type``.

    Args:
        cluster_type: A ``ClusterType`` member, or the raw value of one
            (``'yarn'`` / ``'local'``). Defaults to ``ClusterType.YARN``.

    Returns:
        The result of ``LocalCluster()`` for ``ClusterType.LOCAL``, or of
        ``YarnCluster.from_current()`` for ``ClusterType.YARN``.

    Raises:
        ValueError: If ``cluster_type`` does not identify a ``ClusterType``
            member.
    """
    try:
        # Enum lookup by value returns a member unchanged and maps a raw
        # value such as 'local' onto its member, so both spellings work.
        kind = ClusterType(cluster_type)
    except ValueError as err:
        raise ValueError('Invalid CLUSTER_TYPE defined') from err

    if kind is ClusterType.LOCAL:
        return LocalCluster()
    if kind is ClusterType.YARN:
        return YarnCluster.from_current()

    # Only reachable if a member is added to ClusterType without a branch here.
    raise ValueError('Invalid CLUSTER_TYPE defined')
# NOTE: initialize the dask cluster
# A local cluster is requested explicitly here; get_cluster() would otherwise
# default to ClusterType.YARN.
cluster = get_cluster(ClusterType.LOCAL)
# Attach a distributed client to the cluster created above.
client = Client(cluster)
print(f"Dashboard link: {client.dashboard_link}")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment