Last active
September 7, 2022 15:45
-
-
Save jeffvestal/551cb07928dd01adec060521e21c3612 to your computer and use it in GitHub Desktop.
Elastic - NLP - Load HuggingFace Model with Zero Shot Example
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"name": " Elastic - NLP - Load HuggingFace Model with Zero Shot Example", | |
"provenance": [], | |
"collapsed_sections": [], | |
"authorship_tag": "ABX9TyOroRibhoorNRL0/+AfddoE", | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"language_info": { | |
"name": "python" | |
} | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/jeffvestal/551cb07928dd01adec060521e21c3612/-elastic-nlp-load-huggingface-model-with-zero-shot-example.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"Use this code to load an NLP model from Hugging Face for use inside Elastic's Elasticsearch. \n", | |
"\n", | |
"You can set up a [free trial Elasticsearch deployment in Elastic Cloud](https://cloud.elastic.co/registration) and run the code below in [Google's Colab](https://colab.research.google.com) for free.\n", | |
"\n", | |
"Requires Elastic version 8.0+ with a Platinum or Enterprise license (or a trial license).\n", | |
"\n", | |
"The example here loads a [Zero Shot model](https://huggingface.co/typeform/distilbert-base-uncased-mnli)\n", | |
"\n", | |
"[Elastic NLP Model Support Docs](https://www.elastic.co/guide/en/machine-learning/current/ml-nlp-model-ref.html) \n", | |
"\n", | |
"[Code summarized from the eland docs](https://github.com/elastic/eland)\n", | |
"\n", | |
"Disclaimer: presented as is with no guarantee." | |
], | |
"metadata": { | |
"id": "6xoLDtS_6Df1" | |
} | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"# Install eland and *elasticsearch*" | |
], | |
"metadata": { | |
"id": "Ly1f1P-l9ri8" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "rUedSzQW9FIF" | |
}, | |
"outputs": [], | |
"source": [ | |
"pip install eland" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"pip install elasticsearch" | |
], | |
"metadata": { | |
"id": "NK3Wx1I199yB" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"pip install transformers" | |
], | |
"metadata": { | |
"id": "cEfiiFXakzdP" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"pip install sentence_transformers" | |
], | |
"metadata": { | |
"id": "I20mDmJboKZw" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"pip install torch==1.11" | |
], | |
"metadata": { | |
"id": "uqcpWrbkBEB9" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"import getpass\n", | |
"from pathlib import Path\n", | |
"from eland.ml.pytorch import PyTorchModel\n", | |
"from eland.ml.pytorch.transformers import TransformerModel\n", | |
"from elasticsearch import Elasticsearch\n", | |
"from elasticsearch.client import MlClient\n" | |
], | |
"metadata": { | |
"id": "-dqhRCBUe1U-" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"# Set Connection and Auth" | |
], | |
"metadata": { | |
"id": "r7nMIbHke37Q" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"es_url = getpass.getpass('Enter elasticsearch endpoint: ') # endpoint https://<esurl>:<port>\n", | |
"es_user = getpass.getpass('Enter username: ') # username\n", | |
"es_pass = getpass.getpass('Enter password: ') # password" | |
], | |
"metadata": { | |
"id": "SSGgYHome69o" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"# Connect to Elastic and Load a Hugging Face Model" | |
], | |
"metadata": { | |
"id": "jL4VDnVp96lf" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"es = Elasticsearch(es_url, basic_auth=(es_user,es_pass))" | |
], | |
"metadata": { | |
"id": "I8mVJkKmetXo" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"[Supported `task_type` values](https://github.com/elastic/eland/blob/15a300728876022b206161d71055c67b500a0192/eland/ml/pytorch/transformers.py#L41)" | |
], | |
"metadata": { | |
"id": "QmZ1fkwYM5er" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Download a Hugging Face Zero Shot model directly from the model hub\n", | |
"# https://huggingface.co/typeform/distilbert-base-uncased-mnli\n", | |
"#\n", | |
"# The task type given here must match the inference_config sent at inference time.\n", | |
"# This notebook sends a zero_shot_classification request, so load an MNLI model with that task type.\n", | |
"#\n", | |
"# Other task types are loaded the same way, e.g.:\n", | |
"# tm = TransformerModel(\"sentence-transformers/all-MiniLM-L12-v2\", \"text_embedding\")\n", | |
"tm = TransformerModel(\"typeform/distilbert-base-uncased-mnli\", \"zero_shot_classification\")" | |
], | |
"metadata": { | |
"id": "zPV3oFsKiYFL" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Export the model in a TorchScript representation which Elasticsearch uses\n", | |
"tmp_path = \"models\"\n", | |
"Path(tmp_path).mkdir(parents=True, exist_ok=True)\n", | |
"# model_path, config_path, vocab_path = tm.save(tmp_path) # pre 8.2.0\n", | |
"model_path, config, vocab_path = tm.save(tmp_path)" | |
], | |
"metadata": { | |
"id": "GsSpvvP-nbCK" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Import model into Elasticsearch\n", | |
"ptm = PyTorchModel(es, tm.elasticsearch_model_id())\n", | |
"# ptm.import_model(model_path, config_path, vocab_path) # pre 8.2.0\n", | |
"ptm.import_model(model_path=model_path, config_path=None, vocab_path=vocab_path, config=config)" | |
], | |
"metadata": { | |
"id": "Z4QD71Apnj4j" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"# Deploy the model" | |
], | |
"metadata": { | |
"id": "oMGw3sk-pbaN" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# List models in elasticsearch\n", | |
"m = es.ml.get_trained_models()\n", | |
"m.body" | |
], | |
"metadata": { | |
"id": "b4Wv8EJvpfZI" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Deploy the model\n", | |
"\n", | |
"# Reuse the id that was used for the import so the deployment always targets the model loaded above\n", | |
"# (it appears in the listing as 'model_id': 'typeform__distilbert-base-uncased-mnli')\n", | |
"model_id = tm.elasticsearch_model_id()\n", | |
"\n", | |
"# start trained model deployment\n", | |
"s = es.ml.start_trained_model_deployment(model_id=model_id)\n", | |
"s.body\n", | |
"\n", | |
"# You can see model state in Kibana -> Machine Learning -> Model Management -> Trained Models" | |
], | |
"metadata": { | |
"id": "w5muJ1rLqvUW" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"# Zero Shot Time!" | |
], | |
"metadata": { | |
"id": "6Hu2n4bmGYkG" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# future reference do not use yet\n", | |
"#z = MlClient.infer_trained_model_deployment(es, model_id =model_id, docs=docs, )" | |
], | |
"metadata": { | |
"id": "ZsWg7XPSGbiu" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Using requests until MlClient.infer_trained_model_deployment is updated to accept inference extra configs\n", | |
"import requests\n", | |
"from requests.auth import HTTPBasicAuth\n", | |
"import urllib.parse\n", | |
"\n", | |
"endpoint = '_ml/trained_models/%s/deployment/_infer' % model_id\n", | |
"url = urllib.parse.urljoin(es_url, endpoint)\n", | |
"\n", | |
"# One document to classify plus the candidate labels for zero-shot classification\n", | |
"body = {\n", | |
"    \"docs\": [\n", | |
"        {\n", | |
"            \"text_field\": \"Last week I upgraded my iOS version and ever since then my phone has been overheating whenever I use your app.\"\n", | |
"        }\n", | |
"    ],\n", | |
"    \"inference_config\": {\n", | |
"        \"zero_shot_classification\": {\n", | |
"            \"labels\": [\n", | |
"                \"mobile\",\n", | |
"                \"website\",\n", | |
"                \"billing\",\n", | |
"                \"account access\"\n", | |
"            ],\n", | |
"            \"multi_label\": True\n", | |
"        }\n", | |
"    }\n", | |
"}\n", | |
"\n", | |
"resp = requests.post(url, auth=HTTPBasicAuth(es_user, es_pass), json=body)\n", | |
"# Fail here with the HTTP error instead of a KeyError below when the request is rejected\n", | |
"resp.raise_for_status()\n", | |
"r = resp.json()\n", | |
"print('Predicted value is: %s with a probability of %0.2f%%' % (r['predicted_value'], r['prediction_probability'] * 100))\n", | |
"print('=-=-=-=')\n", | |
"print('Full Probability output:')\n", | |
"for c in r['top_classes']:\n", | |
"    print('%s probability of %0.5f%%' % (c['class_name'], c['class_probability'] * 100))" | |
], | |
"metadata": { | |
"id": "tf9c-XkrQTM3" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Just to see the full doc\n", | |
"resp.json()" | |
], | |
"metadata": { | |
"id": "f3JRG4SeaESo" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment