Last active
January 27, 2021 16:14
-
-
Save edoakes/84fb783b05b5a256a3afdbefbd655cdd to your computer and use it in GitHub Desktop.
Ray for training + Ray Serve for inference
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse

# Command-line interface: three required positional arguments, each parsed
# as a single string.
parser = argparse.ArgumentParser()
parser.add_argument(
    "train_s3_urls",
    help="S3 URL(s) of the data handed to Trainer.train",
)
parser.add_argument(
    "inference_s3_urls",
    help="S3 URL(s) of the data sent to the serving endpoint",
)
parser.add_argument(
    "output_path",
    help="destination passed to write_to_s3 for each inference result",
)
import ray
from ray import serve

# Connect to an already-running Ray cluster instead of starting a local one.
# NOTE: this runs at import time, so importing this module requires a
# reachable cluster.
ray.init(address="auto")
@ray.remote(num_gpus=4)
class Trainer:
    """Ray actor that trains a model and returns the resulting weights.

    The actor is declared with ``num_gpus=4``, so Ray reserves four GPUs for
    each instance.

    NOTE(review): ``init_model`` and ``self.fetch_s3_images`` are not defined
    in this file; both must be provided before the class can run.
    """

    def __init__(self, initial_weights):
        """Build the model from ``initial_weights``."""
        # Configure this to use GPUs.
        self.model = init_model(initial_weights)

    def train(self, s3_urls):
        """Fetch the data behind ``s3_urls``, train on it, return the weights.

        Args:
            s3_urls: Passed unchanged to ``self.fetch_s3_images``.

        Returns:
            ``self.model.weights`` after ``self.model.train`` has run.
        """
        # Fetch s3 data. If the same images are shared across multiple
        # processes, it's likely better to fetch them on the driver and
        # put them in the object store to optimize.
        data = self.fetch_s3_images(s3_urls)
        self.model.train(data)
        return self.model.weights
class Servable:
    """Ray Serve backend that runs inference with a set of trained weights.

    NOTE(review): ``init_model`` and ``self.fetch_s3_images`` are not defined
    in this file; both must be provided before the class can run.
    """

    def __init__(self, trained_weights):
        """Build the model from ``trained_weights``."""
        self.model = init_model(trained_weights)

    @serve.accept_batch
    def __call__(self, s3_urls):
        """Fetch the data behind ``s3_urls`` and return the model's output.

        NOTE(review): with ``serve.accept_batch`` this presumably receives a
        batch of requests and must return one result per request -- confirm
        against the Ray Serve version in use.
        """
        data = self.fetch_s3_images(s3_urls)
        return self.model.inference(data)
def main(args):
    """Train a model on a Ray actor, then run inference through Ray Serve.

    Steps: train on ``args.train_s3_urls``, deploy the resulting weights as
    the ``"model1"`` backend/endpoint with 20 replicas, submit one request per
    inference URL, and write each result to ``args.output_path``.

    Args:
        args: Parsed CLI namespace with ``train_s3_urls``,
            ``inference_s3_urls`` and ``output_path``.
    """
    # Can do this with multiple model types or duplicate the Trainer.
    # NOTE(review): `initial_weights` is not defined anywhere in this file;
    # it has to be supplied before this function can run.
    trainer = Trainer.remote(initial_weights)
    result_weights = ray.get(trainer.train.remote(args.train_s3_urls))

    # Keep the returned client bound for the whole function.
    # NOTE(review): some Ray Serve releases shut the instance down when the
    # client object is garbage collected -- confirm for the version in use.
    client = serve.start(http_host=None)
    serve.create_backend(
        "model1",
        Servable,
        result_weights,
        config=serve.BackendConfig(num_replicas=20),
    )
    serve.create_endpoint("model1", backend="model1")
    handle = serve.get_handle("model1")

    # The parser defines `inference_s3_urls` (not `inference_urls`) and yields
    # a single string; iterating a string would submit one request per
    # character. A string is therefore split into individual URLs.
    # NOTE(review): comma-separated input is an assumption -- confirm the
    # intended CLI format.
    inference_urls = args.inference_s3_urls
    if isinstance(inference_urls, str):
        inference_urls = [url for url in inference_urls.split(",") if url]

    # Submit every request before blocking so the replicas work in parallel.
    refs = [handle.remote(s3_url) for s3_url in inference_urls]
    for result in ray.get(refs):
        write_to_s3(args.output_path, result)
# Script entry point: parse the command line and run the pipeline.
if __name__ == "__main__":
    main(parser.parse_args())
OK, we found the problem. The `Trainer` class was defined using the `dataclass` decorator, and we had to make `ray.remote` the first decorator. After that it worked!
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Hey Edward, we're having trouble when importing the `Trainer` class. We make the import using `importlib` like this:
The error we get is `'ActorClass(AdaptiveMLP)' object has no attribute '__mro__'`.
What would you recommend here? Maybe import the module in some other way?