Last active
June 17, 2022 15:24
-
-
Save dmigo/12e31a6d12e140a90a21ee715cf5f994 to your computer and use it in GitHub Desktop.
Runs Haystack indexing on Ray
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import ray | |
import s3fs | |
import yaml | |
from ray import serve | |
from haystack.pipelines.base import Pipeline | |
# Settings for the script. Values written as '<...>' are placeholders that
# must be replaced with real values before running.

# Host of the Ray head node; combined with RAY_PORT into the "ray://" address.
RAY_ADDRESS = '<HEAD_NODE_ADDRESS>'
# Port used in the "ray://" client address.
RAY_PORT = '10001'
# Not referenced by the code in this file.
RAY_SERVE_PORT = '8000'
# Namespace passed to ray.init.
RAY_NAMESPACE = 'default'
# Not referenced by the code in this file.
DATA_PATH = '<S3_PATH>'
# S3 prefix under which the pipeline YAML files are looked up.
PIPELINES_PATH = '<S3_PATH>'
class RayUnavailableException(Exception):
    """Raised when the connection to the Ray cluster cannot be set up."""


class RayConnection:
    """Context manager around a Ray client connection with Ray Serve started.

    The connection is opened when the object is constructed (not in
    ``__enter__``); leaving the ``with`` block shuts Serve and Ray down.
    """

    def __init__(self):
        """Connect to the Ray head node and start a detached Serve instance.

        Raises:
            RayUnavailableException: if ``ray.init`` or ``serve.start`` fails;
                the underlying error is attached as ``__cause__``.
        """
        address = f"ray://{RAY_ADDRESS}:{RAY_PORT}"
        namespace = RAY_NAMESPACE
        try:
            ray.init(address=address, namespace=namespace)
            serve.start(detached=True)
        except Exception as error:
            # Chain the original error so its traceback is not lost.
            raise RayUnavailableException(
                f'Connection to ray failed due to "{error}".'
            ) from error

    def __enter__(self):
        """Return this connection object; the connection is already open."""
        return self

    def __exit__(self, typ, value, traceback):
        """Shut down Serve and then the Ray connection.

        Returns ``None``, so an exception raised inside the ``with`` block
        keeps propagating.
        """
        # Serve must be shut down while the Ray connection is still alive;
        # the ``finally`` guarantees Ray is disconnected even if that fails.
        try:
            serve.shutdown()
        finally:
            ray.shutdown()
def read_config(s3prefix: str, s3file: str):
    """Load the YAML document stored at ``{s3prefix}/{s3file}``.

    The object is opened through a newly created ``s3fs.S3FileSystem`` and
    parsed with ``yaml.safe_load``; the parsed result is returned.
    """
    location = f"{s3prefix}/{s3file}"
    with s3fs.S3FileSystem().open(location) as handle:
        parsed = yaml.safe_load(handle)
    return parsed
@ray.remote
def index_docs(pipeline: str, data: str):
    """Ray remote task that loads the 'indexing' pipeline from a YAML config.

    Args:
        pipeline: Name of the pipeline YAML file under ``PIPELINES_PATH``.
        data: Returned unchanged in the result; not otherwise used here.

    Returns:
        A dict with the keys ``'data'`` (the ``data`` argument) and
        ``'config'`` (the parsed pipeline configuration).
    """
    config = read_config(PIPELINES_PATH, pipeline)
    # NOTE(review): the pipeline object is built and then discarded, and
    # `data` is never fed into it -- presumably the actual indexing run is
    # still to be added; confirm with the author.
    Pipeline.load_from_config(config, pipeline_name='indexing')
    return {'data': data, 'config': config}
def main():
    """Run one ``index_docs`` task on the Ray cluster and print its result.

    The task is submitted for 'sparse.yaml' and 'case_corpus_10.jsonl'
    inside a ``RayConnection`` context, and the call blocks until the
    result is available.
    """
    with RayConnection():
        pending = index_docs.remote('sparse.yaml', 'case_corpus_10.jsonl')
        print(ray.get(pending))


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment