Skip to content

Instantly share code, notes, and snippets.

@cybersamx
Last active October 15, 2024 17:40
Show Gist options
  • Save cybersamx/7220b9522beb9454676672ff1102832f to your computer and use it in GitHub Desktop.
Save cybersamx/7220b9522beb9454676672ff1102832f to your computer and use it in GitHub Desktop.
Execute Flyte Tasks Concurrently using map_task and eager
# This script performs the same thing as flyte-consume-api-maptask.py but it runs
# the "tasks" concurrently within a pod using the eager task.
import asyncio
from dataclasses import dataclass

import requests
from dataclasses_json import dataclass_json
from flytekit import task
from flytekit.experimental import eager
@dataclass_json
@dataclass
class Name:
    """The ``name`` object of a randomuser.me user record.

    ``dataclass_json`` adds JSON (de)serialization helpers such as
    ``from_json`` on top of the plain dataclass.
    """

    # Field order defines the generated __init__ signature.
    title: str
    first: str
    last: str
@dataclass_json
@dataclass
class User:
    """A single entry of the API's ``results`` list.

    The nested ``name`` field is decoded into a ``Name`` instance.
    """

    # Field order defines the generated __init__ signature.
    gender: str
    email: str
    name: Name
@dataclass_json
@dataclass
class Results:
    """Top-level randomuser.me response; ``results`` holds the decoded users."""

    results: list[User]
def is_error(status_code):
    """Return True when *status_code* falls outside the 2xx success range.

    Anything below 200 or at/above 300 is reported as an error.
    """
    succeeded = 200 <= status_code < 300
    return not succeeded
@task
def fetch_randomusers(page: int, size: int, seed: str) -> list[User]:
    """Fetch one page of users from the randomuser.me REST API.

    Args:
        page: Page number to request (sent as the ``page`` query parameter).
        size: Number of users per page (sent as the ``results`` parameter).
        seed: Seed string forwarded as the ``seed`` query parameter.

    Returns:
        The response's ``results`` list decoded into ``User`` objects.

    Raises:
        RuntimeError: If the API answers with a non-2xx status code.
    """
    print(f'Executing page {page}')
    # Let requests build and URL-encode the query string rather than
    # interpolating raw values into the URL.
    params = {'page': page, 'results': size, 'seed': seed}
    # The context manager closes the session and releases its connections;
    # the timeout stops a stalled server from hanging the task indefinitely.
    # The body is read eagerly (no stream=True), so it stays available after
    # the session is closed.
    with requests.Session() as session:
        response = session.get(
            'https://randomuser.me/api/', params=params, timeout=30
        )
    if is_error(response.status_code):
        print(f'error: {response.status_code}')
        # Raise an explicit RuntimeError carrying the status and page so the
        # failure is diagnosable from the task logs.
        raise RuntimeError(
            f'randomuser.me returned HTTP {response.status_code} '
            f'for page {page}'
        )
    res: Results = Results.from_json(response.content)
    print(f'Results from page {page}: {res}')
    return res.results
@eager
async def eager_workflow() -> list[User]:
    """Fetch pages 1-3 of users concurrently and return them as one flat list.

    Each page is an independent ``fetch_randomusers`` call, so all calls are
    started together and awaited with ``asyncio.gather``. Awaiting them one
    at a time inside a loop would run them sequentially. ``gather`` returns
    results in the order its arguments were given, so users stay grouped by
    page.

    Returns:
        Every user from every page, flattened into a single list.
    """
    # Same pages as the map_task version (page=[1, 2, 3]).
    pages = range(1, 4)
    per_page = await asyncio.gather(
        *(fetch_randomusers(page=page, size=5, seed='abc') for page in pages)
    )
    # Flatten the per-page lists into one list of users.
    return [user for page_users in per_page for user in page_users]
if __name__ == "__main__":
    # eager_workflow is declared ``async def``: calling it only creates an
    # awaitable. Drive it to completion on an event loop before printing
    # the result.
    print(f'Aggregated results: {asyncio.run(eager_workflow())}')
# Break a fetch request to randomuser.me into tasks that can run
# concurrently on different pods using map_task.
import functools
from flytekit import map_task, task, workflow
from dataclasses import dataclass
from dataclasses_json import dataclass_json
import requests
# JSON schema
# (root)
# └─ results
# └─ user
# ├─ gender
# ├─ email
# └─ name
# ├─ title
# ├─ first
# └─ last
@dataclass_json
@dataclass
class Name:
    """The ``name`` object of a randomuser.me user record.

    ``dataclass_json`` adds JSON (de)serialization helpers such as
    ``from_json`` on top of the plain dataclass.
    """

    # Field order defines the generated __init__ signature.
    title: str
    first: str
    last: str
@dataclass_json
@dataclass
class User:
    """A single entry of the API's ``results`` list.

    The nested ``name`` field is decoded into a ``Name`` instance.
    """

    # Field order defines the generated __init__ signature.
    gender: str
    email: str
    name: Name
@dataclass_json
@dataclass
class Results:
    """Top-level randomuser.me response; ``results`` holds the decoded users."""

    results: list[User]
def is_error(status_code):
    """Return True when *status_code* falls outside the 2xx success range.

    Anything below 200 or at/above 300 is reported as an error.
    """
    succeeded = 200 <= status_code < 300
    return not succeeded
@task
def fetch_randomusers(page: int, size: int, seed: str) -> list[User]:
    """Fetch one page of users from the randomuser.me REST API.

    Args:
        page: Page number to request (sent as the ``page`` query parameter).
        size: Number of users per page (sent as the ``results`` parameter).
        seed: Seed string forwarded as the ``seed`` query parameter.

    Returns:
        The response's ``results`` list decoded into ``User`` objects.

    Raises:
        RuntimeError: If the API answers with a non-2xx status code.
    """
    print(f'Executing page {page}')
    # Let requests build and URL-encode the query string rather than
    # interpolating raw values into the URL.
    params = {'page': page, 'results': size, 'seed': seed}
    # The context manager closes the session and releases its connections;
    # the timeout stops a stalled server from hanging the task indefinitely.
    # The body is read eagerly (no stream=True), so it stays available after
    # the session is closed.
    with requests.Session() as session:
        response = session.get(
            'https://randomuser.me/api/', params=params, timeout=30
        )
    if is_error(response.status_code):
        print(f'error: {response.status_code}')
        # Raise an explicit RuntimeError carrying the status and page so the
        # failure is diagnosable from the task logs.
        raise RuntimeError(
            f'randomuser.me returned HTTP {response.status_code} '
            f'for page {page}'
        )
    res: Results = Results.from_json(response.content)
    print(f'Results from page {page}: {res}')
    return res.results
@workflow
def map_workflow() -> list[list[User]]:
    """Fan ``fetch_randomusers`` out over pages 1-3 with ``map_task``.

    ``size`` and ``seed`` are bound once via ``functools.partial``, so only
    ``page`` varies between the mapped task instances. The result holds
    one list of users per page.
    """
    pages = [1, 2, 3]
    fetch_pages = map_task(
        functools.partial(fetch_randomusers, size=5, seed='abc')
    )
    return fetch_pages(page=pages)
if __name__ == "__main__":
    # Run the workflow locally and print the per-page user lists.
    aggregated = map_workflow()
    print(f'Aggregated results: {aggregated}')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment