Created
May 4, 2019 17:24
-
-
Save att288/17fa403973d6f718183b44f9eb21ab04 to your computer and use it in GitHub Desktop.
load_multiple_keras_models_django.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
###################################################
# settings.py (or whichever module loads the models once at startup)
###################################################
def load_model_from_path(path):
    """Load the Keras model stored at *path* and pair it with the default graph.

    The graph is fetched before the model is loaded, so it is whatever graph
    is the default at call time. Callers keep it so they can later run
    predictions inside that graph.

    NOTE(review): ``tf.get_default_graph`` is the TF 1.x API. In TF 2 it lives
    at ``tf.compat.v1.get_default_graph``. Confirm the pinned TF version.

    Args:
        path: Location of the saved model, passed straight to ``load_model``.

    Returns:
        tuple: ``(graph, model)``.
    """
    current_graph = tf.get_default_graph()
    loaded_model = load_model(path)
    return (current_graph, loaded_model)
def load_all_models(model_paths=None):
    """Load every configured model and publish them in the global ``gModelObjs``.

    Args:
        model_paths: Optional mapping of model name -> model file path. When
            omitted, the four built-in entries ``model_1`` .. ``model_4`` are
            used, in that order.

    Returns:
        dict: The mapping now stored in ``gModelObjs``. Each value is the
        ``(graph, model)`` tuple produced by ``load_model_from_path``.
    """
    global gModelObjs  # each value is a (graph, model) tuple
    if model_paths is None:
        model_paths = {
            'model_1': 'path_for_model_1',
            'model_2': 'path_for_model_2',
            'model_3': 'path_for_model_3',
            'model_4': 'path_for_model_4',
        }
    # Build locally, then publish with a single assignment. If any load
    # raises, gModelObjs keeps its previous value instead of being reset.
    loaded = {
        name: load_model_from_path(path)
        for name, path in model_paths.items()
    }
    gModelObjs = loaded
    return loaded
#########
# views.py
#########
@api_view(['POST'])
def score_segment(request):
    """Score the POSTed segment with the preloaded ``model_1`` and return predictions.

    No model is loaded per request. The ``(graph, model)`` pair is read from
    ``gModelObjs``, which is populated once at startup.

    NOTE(review): ``gModelObjs`` must be reachable from this module (e.g.
    imported from the module that runs ``load_all_models``). Confirm.

    Returns:
        Response: ``{'predictions': ...}`` on success. If building the
        DataFrame or scoring raises, returns ``{'message': <error text>}``
        with HTTP 400.
    """
    # gModelObjs is a dict, so index it. Calling it raises TypeError.
    graph, model = gModelObjs['model_1']
    try:
        # Parse inside the try so a malformed payload yields a 400, not a 500.
        df_segment = pd.DataFrame(request.data)
        predictions = _score_segment_by_model(df_segment, graph, model)
    except Exception as error:
        return Response({'message': str(error)}, status=status.HTTP_400_BAD_REQUEST)
    return Response({
        'predictions': predictions
    })
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment