Skip to content

Instantly share code, notes, and snippets.

@helton
Last active October 7, 2024 17:34
Show Gist options
  • Save helton/edded58fb723214891416ae259b037f0 to your computer and use it in GitHub Desktop.
Save helton/edded58fb723214891416ae259b037f0 to your computer and use it in GitHub Desktop.
Auto Unwrap decorator for Celery tasks
def auto_unwrap(func):
    """Let a Celery task accept its arguments packed in one list or dict.

    The decorated function can be called in three ways:

    - with ordinary positional and keyword arguments;
    - with a single ``list`` and no keyword arguments, in which case the
      list is unpacked as positional arguments;
    - with a single ``dict`` and no keyword arguments, in which case the
      dict is unpacked as keyword arguments.

    A dict whose keys are not all strings cannot be expanded into keyword
    arguments, so it is forwarded unchanged as one positional argument
    instead of raising ``TypeError``.

    Args:
        func: The callable to wrap.

    Returns:
        The wrapping callable, carrying ``func``'s metadata via ``wraps``.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        # Only a lone positional argument with no kwargs is a candidate for
        # unpacking; every other call shape is forwarded untouched.
        if len(args) == 1 and not kwargs:
            payload = args[0]
            if isinstance(payload, list):
                return func(*payload)
            if isinstance(payload, dict) and all(
                isinstance(key, str) for key in payload
            ):
                return func(**payload)
        return func(*args, **kwargs)

    return wrapper
@app.task(name="download")
@auto_unwrap
def download(uid: str, source: str, source_type: str):
    """Build the descriptor of the S3 object a source is downloaded to.

    Args:
        uid: Identifier used as the first path segment under the bucket.
        source: Location of the input. For ``"url"`` and ``"s3.object"``
            sources the file name is the text after the last ``/``; for any
            other ``source_type`` the whole string is used as the file name.
        source_type: Kind of ``source``.

    Returns:
        A dict with ``uid``, the target ``source`` path under
        ``s3://mybucket/<uid>/downloads/`` and ``source_type`` set to
        ``"s3.object"``.
    """
    if source_type == "url":
        # A URL may carry a query string or a fragment (e.g. a presigned
        # link); neither belongs in the derived file name. Strip the
        # fragment first because it may itself contain a "?".
        url_path = source.partition("#")[0].partition("?")[0]
        file_name = url_path.split("/")[-1]
    elif source_type == "s3.object":
        file_name = source.split("/")[-1]
    else:
        file_name = source
    base_name, _ = os.path.splitext(file_name)
    return {
        "uid": uid,
        "source": f"s3://mybucket/{uid}/downloads/{base_name}/{file_name}",
        "source_type": "s3.object",
    }
@app.task(name="extract")
@auto_unwrap
def extract(uid: str, source: str, source_type: str):
    """Build the descriptor of the extracted-text object for ``source``.

    Args:
        uid: Identifier used as the first path segment under the bucket.
        source: Path of the input; only the text after its last ``/`` is
            used, with the extension removed.
        source_type: Accepted for a uniform task signature; not read here.

    Returns:
        A dict with ``uid``, the ``.txt`` path under
        ``s3://mybucket/<uid>/extractions/`` and ``source_type`` set to
        ``"s3.object"``.
    """
    # Last path segment without its extension names both folder and file.
    last_segment = source.rpartition("/")[2]
    base_name = os.path.splitext(last_segment)[0]
    target = f"s3://mybucket/{uid}/extractions/{base_name}/{base_name}.txt"
    return {"uid": uid, "source": target, "source_type": "s3.object"}
@app.task(name="chunkenize")
@auto_unwrap
def chunkenize(uid: str, source: str, source_type: str):
    """Build the descriptors of the five chunk files derived from ``source``.

    Args:
        uid: Identifier used as the first path segment under the bucket.
        source: Path of the input; only the text after its last ``/`` is
            used, with the extension removed.
        source_type: Accepted for a uniform task signature; not read here.

    Returns:
        A list of five dicts, one per ``chunk_<i>.txt`` (``i`` from 0 to 4)
        under ``s3://mybucket/<uid>/chunks/``, each with ``source_type`` set
        to ``"s3.folder"``.
    """
    last_segment = source.rpartition("/")[2]
    base_name = os.path.splitext(last_segment)[0]
    return [
        {
            "uid": uid,
            "source": f"s3://mybucket/{uid}/chunks/{base_name}/chunk_{i}.txt",
            "source_type": "s3.folder",
        }
        for i in range(5)
    ]
@app.task(name="embedding")
@auto_unwrap
def embedding(uid: str, source: str, source_type: str):
    """Build the descriptor of the embedding output derived from ``source``.

    Args:
        uid: Identifier used as the first path segment under the bucket.
        source: Path of the input; only the text after its last ``/`` is
            used, with the extension removed.
        source_type: Accepted for a uniform task signature; not read here.

    Returns:
        A one-element list holding a dict with ``uid``, the
        ``embedding_n.json`` path under ``s3://mybucket/<uid>/embeddings/``
        and ``source_type`` set to ``"s3.folder"``.
    """
    last_segment = source.rpartition("/")[2]
    base_name = os.path.splitext(last_segment)[0]
    descriptor = {
        "uid": uid,
        "source": f"s3://mybucket/{uid}/embeddings/{base_name}/embedding_n.json",
        "source_type": "s3.folder",
    }
    return [descriptor]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment