Created
November 29, 2023 02:06
-
-
Save pryce-turner/a84f4cbeea1cf3923f625e4407cf46f0 to your computer and use it in GitHub Desktop.
Node caching client-side implementation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import shutil | |
import hashlib | |
from time import sleep | |
from typing import List | |
from random import randint | |
from pathlib import Path | |
from flytekit import task, workflow, dynamic | |
from flytekit.types.file import FlyteFile | |
from flytekitplugins.pod import Pod | |
from kubernetes.client.models import ( | |
V1PodSpec, | |
V1Volume, | |
V1Container, | |
V1VolumeMount, | |
V1PersistentVolumeClaimVolumeSource, | |
) | |
# Directory inside the task container at which the cache volume is mounted.
pod_mount_path = "/nodecache"
# Shared name that ties the pod-level volume to the container-level mount.
vol_name = "task-cache-vol"

# Volume backed by the "task-cache-pvc" PersistentVolumeClaim.
_cache_volume = V1Volume(
    name=vol_name,
    persistent_volume_claim=V1PersistentVolumeClaimVolumeSource(
        claim_name="task-cache-pvc"
    ),
)

# Mounts only the "task_cache" sub-directory of the volume at pod_mount_path.
_cache_mount = V1VolumeMount(
    name=vol_name,
    sub_path="task_cache",
    mount_path=pod_mount_path,
)

# NOTE(review): the no-op image is presumably a placeholder that Flyte
# replaces with the real task image for the "primary" container - confirm.
_primary_container = V1Container(
    name="primary",
    image="docker.io/rwgrim/docker-noop",
    image_pull_policy="IfNotPresent",
    volume_mounts=[_cache_mount],
)

# Pod spec handed to the Pod task config by the tasks in this file.
persist_local_ps = V1PodSpec(
    volumes=[_cache_volume],
    containers=[_primary_container],
)
class CacheFile:
    """Cache entry for one ``FlyteFile`` in the directory at ``pod_mount_path``.

    Entries are keyed by the file's base name only, so two remote files with
    the same base name map to the same cache entry.

    Attributes:
        ff: The wrapped ``FlyteFile``.
        fname: Base name of ``ff.path``; used as the cache file name.
        path: Location of the cached copy; ``None`` until ``check_cache`` runs.
    """

    def __init__(self, ff: FlyteFile):
        """Wrap *ff* and verify that the cache directory is available.

        Raises:
            FileNotFoundError: If ``pod_mount_path`` does not exist.
            AssertionError: If the FlyteFile's downloader looks like a no-op.
        """
        self.ff = ff
        self.fname = Path(ff.path).name
        self.path = None

        try:
            os.listdir(pod_mount_path)
        except FileNotFoundError as e:
            raise FileNotFoundError(
                f"The default mount path ({pod_mount_path}) does not exist. "
                f"Did you use the appropriate pod spec in your task config?"
            ) from e

        # More informative error when a FlyteFile is initialized without a downloader
        assert 'noop' not in str(self.ff._downloader), (
            "FlyteFile initialized with no downloader. "
            "Was it not created at the task boundary?"
        )

    def check_cache(self) -> str:
        """Make sure the file is present in the cache directory.

        A ``<name>.caching`` lock file next to the cache entry decides which
        task performs the download.

        Returns:
            ``'HIT'`` if the file was already cached, ``'CACHED'`` if this
            call downloaded and cached it, or ``'CACHED_OTHER'`` if this call
            waited for another lock holder. In the last case the file can
            still be missing if that holder failed before finishing.
        """
        self.path = Path(pod_mount_path).joinpath(self.fname)
        if self.path.exists():
            return 'HIT'

        lock = self.path.with_name(self.path.name + ".caching")
        try:
            # exist_ok=False makes "check and create" a single step, so two
            # tasks cannot both believe they own the lock.
            lock.touch(exist_ok=False)
        except FileExistsError:
            # Someone else is caching: poll until their lock goes away.
            while lock.exists():
                sleep(5)
            return 'CACHED_OTHER'

        try:
            # Another task may have finished caching between the exists()
            # check above and acquiring the lock; do not download again.
            if self.path.exists():
                return 'HIT'

            self.ff.download()
            # Stage under a temporary name in the cache directory so a
            # partially copied file is never visible under the final name,
            # then publish it with a same-directory rename.
            staging = self.path.with_name(self.path.name + ".partial")
            shutil.move(self.ff.path, staging)
            os.replace(staging, self.path)
            self.ff.path = self.path
        finally:
            # Always release the lock, even on failure, so waiters do not
            # spin forever.
            lock.unlink()
        return 'CACHED'
@task(task_config=Pod(pod_spec=persist_local_ps))
def scratch(ff: FlyteFile) -> str:
    """Ensure *ff* is cached, print its first 50 lines, and return the cache status.

    Each printed line has leading and trailing whitespace stripped. The
    return value is whatever ``CacheFile.check_cache`` reported.
    """
    cached = CacheFile(ff)
    status = cached.check_cache()
    with open(cached.path, "r") as handle:
        for seen, text in enumerate(handle, start=1):
            print(text.strip())
            # Stop once 50 lines have been printed.
            if seen >= 50:
                break
    return status
@task(task_config=Pod(pod_spec=persist_local_ps))
def drain_cache(ff: FlyteFile, s: List[str]):
    """Delete the cached copy of *ff* from the cache directory.

    ``s`` is not read by the body.
    NOTE(review): presumably it exists only so this task is ordered after the
    tasks that produced those values - confirm against the caller.
    """
    entry = CacheFile(ff)
    # check_cache() sets entry.path (and caches the file if it is missing).
    entry.check_cache()
    os.remove(entry.path)
@dynamic
def wf():
    """Run ``scratch`` ten times on the same reference file, then drain the cache.

    The statuses returned by the ``scratch`` calls are passed to
    ``drain_cache`` as ``s``.
    """
    reference = "s3://my-s3-bucket/my-data/refs/GRCh38.fasta"
    stats = [scratch(ff=reference) for _ in range(10)]
    drain_cache(ff=reference, s=stats)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment