Last active
January 27, 2021 16:14
-
-
Save edoakes/84fb783b05b5a256a3afdbefbd655cdd to your computer and use it in GitHub Desktop.
Ray for training + Ray Serve for inference
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse | |
# Command-line interface: three required positional arguments.
# NOTE(review): every argument is parsed as one plain string (no ``nargs``),
# even though two of the names are plural -- confirm the expected format
# (single URL, delimiter-separated list, ...) against the callers.
parser = argparse.ArgumentParser(
    description="Train a model with Ray, then run inference with Ray Serve.",
)
parser.add_argument(
    "train_s3_urls",
    help="S3 URL(s) of the data used for training.",
)
parser.add_argument(
    "inference_s3_urls",
    help="S3 URL(s) of the data to run inference on.",
)
parser.add_argument(
    "output_path",
    help="Destination handed to write_to_s3 for every inference result.",
)
import ray | |
from ray import serve | |
# Attach to an already-running Ray cluster instead of starting a local one
# (that is what address="auto" requests). This runs at import time, so merely
# importing this module requires a reachable cluster.
ray.init(address="auto")
@ray.remote(num_gpus=4)
class Trainer:
    """Ray actor that trains a model and hands back its weights.

    The actor is declared with ``num_gpus=4``.

    NOTE(review): ``init_model`` and ``self.fetch_s3_images`` are not defined
    anywhere in this file; they have to be supplied by the surrounding
    project before this class can run.
    """

    def __init__(self, initial_weights):
        """Build the model from ``initial_weights`` via ``init_model``."""
        # Configure this to use GPUs.
        self.model = init_model(initial_weights)

    def train(self, s3_urls):
        """Fetch the data behind ``s3_urls``, train on it, return the weights.

        Args:
            s3_urls: Passed unchanged to ``self.fetch_s3_images``.

        Returns:
            ``self.model.weights`` as it stands after ``self.model.train``.
        """
        # Fetch s3 data. If the same images are shared across multiple
        # processes, it's likely better to fetch them on the driver and
        # put them in the object store to optimize.
        self.model.train(self.fetch_s3_images(s3_urls))
        return self.model.weights
class Servable:
    """Callable class deployed as a Ray Serve backend for inference.

    NOTE(review): ``init_model`` and ``self.fetch_s3_images`` are not defined
    anywhere in this file; they have to be supplied by the surrounding
    project before this class can run.
    """

    def __init__(self, trained_weights):
        """Build the model from ``trained_weights`` via ``init_model``."""
        self.model = init_model(trained_weights)

    @serve.accept_batch
    def __call__(self, s3_urls):
        """Fetch the data behind ``s3_urls`` and return the model's output.

        Decorated with ``serve.accept_batch``, so ``s3_urls`` is presumably a
        batch of requests rather than a single one -- confirm against the
        Ray Serve version in use.
        """
        images = self.fetch_s3_images(s3_urls)
        predictions = self.model.inference(images)
        return predictions
def main(args):
    """Train a model, deploy it with Ray Serve, and write inference results.

    Steps: train on one ``Trainer`` actor, register the resulting weights as
    the "model1" backend/endpoint (20 replicas), send one request per
    inference URL, then pass every result to ``write_to_s3``.

    Args:
        args: Parsed command-line namespace providing ``train_s3_urls``,
            ``inference_s3_urls`` and ``output_path``.

    NOTE(review): ``initial_weights`` and ``write_to_s3`` are not defined in
    this file; they must be supplied by the surrounding project.
    """
    # Can do this with multiple model types or duplicate the Trainer.
    trainer = Trainer.remote(initial_weights)
    result_weights = ray.get(trainer.train.remote(args.train_s3_urls))

    # No HTTP proxy is needed: requests go through a Python handle only.
    serve.start(http_host=None)
    serve.create_backend(
        "model1",
        Servable,
        result_weights,
        config=serve.BackendConfig(num_replicas=20),
    )
    serve.create_endpoint("model1", backend="model1")
    handle = serve.get_handle("model1")

    # The parser names this argument ``inference_s3_urls``; the previous
    # ``args.inference_urls`` lookup raised AttributeError.
    inference_urls = args.inference_s3_urls
    if isinstance(inference_urls, str):
        # A bare string would otherwise be iterated character by character,
        # sending one request per character. Treat it as a single URL.
        inference_urls = [inference_urls]

    # Submit everything first so the replicas can work in parallel, then
    # block on the whole batch of results.
    refs = [handle.remote(s3_url) for s3_url in inference_urls]
    for result in ray.get(refs):
        write_to_s3(args.output_path, result)
if __name__ == "__main__":
    # Script entry point: parse argv with the module-level parser and run
    # the train-then-serve pipeline.
    main(args=parser.parse_args())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Ok, we found the problem. The `Trainer` class was defined using the `dataclass` decorator, and we had to put the `ray.remote` decorator first. After that it worked!