Last active
May 10, 2021 15:39
-
-
Save azarnyx/7bda05dc2e95a3c47b7d353e85ed259c to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# MIME type this endpoint accepts for requests and produces for responses.
JSON_CONTENT_TYPE = "application/json"
def predict_fn(input, model):
    """Score the deserialized input with the loaded classifier.

    Args:
        input: Value passed straight to ``model.predict_proba``. As produced
            by ``input_fn`` in this script it is one sample reshaped to
            ``(1, -1)``.
        model: Object exposing ``predict_proba`` (the value returned by
            ``model_fn``).

    Returns:
        A JSON string ``{"proba": "<list>"}``. ``<list>`` is the probability
        row of the first sample, itself encoded as a JSON array string such as
        ``"[0.25, 0.75]"``. It is kept as a string for compatibility with the
        original response format.
    """
    proba = model.predict_proba(input)
    # Convert to plain Python floats before encoding. Calling str() on a list
    # of NumPy scalars renders "[np.float64(0.25), ...]" under NumPy >= 2,
    # which clients cannot parse. json.dumps of plain floats yields the same
    # "[0.25, 0.75]" text that older NumPy versions produced.
    first_row = [float(p) for p in proba[0]]
    return json.dumps({"proba": json.dumps(first_row)})
def model_fn(model_dir):
    """Load the persisted classifier from the model directory.

    Args:
        model_dir: Directory expected to contain ``sklearnclf.joblib``.

    Returns:
        Whatever ``load`` deserializes from that file. ``load`` is presumably
        ``joblib.load``, whose import is not visible here; joblib files are
        pickle-based, so only load trusted artifacts.
    """
    artifact_path = os.path.join(model_dir, 'sklearnclf.joblib')
    return load(artifact_path)
def input_fn(request_body, content_type=JSON_CONTENT_TYPE):
    """Deserialize a JSON request into a single-sample embedding.

    Args:
        request_body: Raw JSON payload (``str`` or ``bytes``; ``json.loads``
            accepts both) containing a ``"text"`` field.
        content_type: MIME type of the payload. Media-type parameters such as
            ``; charset=utf-8`` and letter case are ignored.

    Returns:
        ``get_embedding(text)`` reshaped to ``(1, -1)``, i.e. one row ready
        for ``predict_proba``.

    Raises:
        ValueError: If the content type is not JSON.
        KeyError: If the decoded body has no ``"text"`` key.
    """
    logger.info('Deserializing the input data.')
    # Compare only the bare media type so that, for example,
    # "application/json; charset=utf-8" is accepted too.
    media_type = (content_type or '').split(';', 1)[0].strip().lower()
    if media_type != JSON_CONTENT_TYPE:
        raise ValueError('Requested unsupported ContentType in content_type: {}'.format(content_type))
    # Process a single JSON document posted to the endpoint.
    payload = json.loads(request_body)
    text = payload["text"]
    return get_embedding(text).reshape(1, -1)
def output_fn(prediction, accept=JSON_CONTENT_TYPE):
    """Serialize a prediction for the HTTP response.

    Args:
        prediction: Value returned by ``predict_fn``. A ``str`` is treated as
            an already-encoded JSON document and sent unchanged. Anything else
            is encoded with ``json.dumps``.
        accept: Requested response MIME type. Media-type parameters and letter
            case are ignored when checking it.

    Returns:
        A ``(body, accept)`` tuple.

    Raises:
        ValueError: If ``accept`` is not JSON.
    """
    logger.info('Serializing the generated output.')
    media_type = (accept or '').split(';', 1)[0].strip().lower()
    if media_type != JSON_CONTENT_TYPE:
        raise ValueError('Requested unsupported ContentType in Accept: {}'.format(accept))
    # predict_fn in this script already returns a JSON string. Dumping it
    # again would wrap the whole document in quotes (double encoding).
    body = prediction if isinstance(prediction, str) else json.dumps(prediction)
    return body, accept
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment