Skip to content

Instantly share code, notes, and snippets.

@att288
Created May 4, 2019 17:06
Show Gist options
  • Save att288/7e4bcfd16d3197171e936117c4a78828 to your computer and use it in GitHub Desktop.
Save att288/7e4bcfd16d3197171e936117c4a78828 to your computer and use it in GitHub Desktop.
load_keras_django_naive.py
import json
import os

import numpy as np
import pandas as pd
import tensorflow as tf
from keras import backend as K
from keras.models import load_model
from rest_framework import status
from rest_framework.decorators import api_view
from rest_framework.response import Response
@api_view(['POST'])
def score_segment(request):
    """Score the POSTed segment with the Keras model and return predictions.

    ``request.data`` is passed straight to ``pd.DataFrame``; its exact shape
    (list of row dicts vs. dict of columns) is presumably fixed by the client
    contract -- TODO confirm.

    Returns:
        ``Response({'predictions': [...]})`` on success, or
        ``Response({'message': <error text>}, status=400)`` if building the
        frame or scoring it raises.
    """
    # NOTE: the model is re-loaded from disk on every request (the "naive"
    # approach this module demonstrates); cache it at startup if latency or
    # memory churn matters.
    graph, model = _load_model_from_path('path_to_keras_model')
    try:
        # Build the frame inside the try so a malformed payload produces a
        # 400 with a message instead of an unhandled 500.
        df_segment = pd.DataFrame(request.data)
        predictions = _score_segment_by_model(df_segment, graph, model)
    except Exception as error:
        return Response({'message': str(error)}, status=status.HTTP_400_BAD_REQUEST)
    # Convert to plain Python lists so serialization does not depend on the
    # renderer knowing about numpy arrays.
    return Response({
        'predictions': np.asarray(predictions).tolist()
    })
def _load_model_from_path(path):
    """Return ``(graph, model)`` for the Keras model saved at ``path``.

    The current default TensorFlow graph is captured first, then the model
    is loaded with Keras' ``load_model``. The graph is handed back so callers
    can re-enter it with ``graph.as_default()`` when predicting.
    """
    # Tuple elements are evaluated left to right, so the graph is still
    # fetched before the model is loaded.
    return tf.get_default_graph(), load_model(path)
def _score_segment_by_model(df_segment, graph, model):
X = preprocess_df_segment() # This is the preprocessing function, depending on your business need. I don't provide it here
try:
with graph.as_default():
predictions = model.predict(X)
return predictions
except Exception as err:
raise(err)
return None
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment