Skip to content

Instantly share code, notes, and snippets.

@ebernhardson
Last active February 7, 2019 06:13
Show Gist options
  • Save ebernhardson/ead904851dc3118234ee14d692ff2461 to your computer and use it in GitHub Desktop.
Tensorflow in SWAP
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Install python dependencies\n",
"===\n",
"And package up a zip file to send to executors."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# %pip (not !pip) installs into the environment of the running kernel\n",
"%pip install tensorflow tensorflow_hub"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# '&&' instead of ';' so we never zip the notebook directory when venv/ is missing\n",
"!cd venv && zip -qur ../spark_venv.zip ."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Download ELMo\n",
"==="
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"# Make sure tensorflow downloads to a place we expect\n",
"# (TFHUB_CACHE_DIR must be set before the module is loaded; the ./tf_hub\n",
"# directory is zipped by a later cell and shipped to the executors)\n",
"os.environ['TFHUB_CACHE_DIR'] = os.path.join(os.getcwd(), 'tf_hub')\n",
"\n",
"import tensorflow_hub as hub\n",
"import tensorflow as tf\n",
"\n",
"# TF1-style graph session: hub.Module adds variables to the default graph,\n",
"# which the initializer run below populates\n",
"session = tf.Session()\n",
"embed = hub.Module(\"https://tfhub.dev/google/elmo/2\")\n",
"session.run(tf.global_variables_initializer())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Quick verification embedding works locally"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sanity check: embed two sentences locally with the module loaded above\n",
"embeddings = embed([\n",
"\"The quick brown fox jumps over the lazy dog.\",\n",
"\"I am a sentence for which I would like to get its embedding\"])\n",
"\n",
"# running the op yields one fixed-size vector per input sentence\n",
"print(session.run(embeddings))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Bundle ELMo up to send to executors as well"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# '&&' instead of ';' so we never zip the notebook directory when tf_hub/ is missing\n",
"!cd tf_hub && zip -qur ../tf_hub.zip ."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Start up spark\n",
"==="
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import findspark\n",
"findspark.init('/usr/lib/spark2')\n",
"from pyspark.ml.linalg import Vectors, VectorUDT\n",
"from pyspark.sql import SparkSession, functions as F, types as T\n",
"import os\n",
"\n",
"# Both env vars must be set before getOrCreate().  The '--archives'\n",
"# 'archive.zip#name' syntax unpacks each zip under ./name in every\n",
"# executor's working directory, so executors get the virtualenv and the\n",
"# cached ELMo module built above.\n",
"os.environ['PYSPARK_SUBMIT_ARGS'] = '--archives spark_venv.zip#venv,tf_hub.zip#tf_hub pyspark-shell'\n",
"# executors run python out of the shipped virtualenv (the 'venv' alias)\n",
"os.environ['PYSPARK_PYTHON'] = 'venv/bin/python'\n",
"\n",
"spark = (\n",
" SparkSession.builder\n",
" .appName('tensorflow: ELMo embedding')\n",
" .master('yarn')\n",
" # extra off-heap headroom, presumably for TF's native allocations -- TODO confirm sizing\n",
" .config('spark.executor.memoryOverhead', '2g')\n",
" .config('spark.dynamicAllocation.maxExecutors', 50)\n",
" .getOrCreate()\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Function to perform embedding on remote executors\n",
"==="
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def batch(iterable, n):\n",
"    \"\"\"Yield lists of up to n consecutive items from iterable.\"\"\"\n",
"    chunk = []\n",
"    for x in iterable:\n",
"        chunk.append(x)\n",
"        # BUGFIX: compare the accumulated chunk's size, not len(x) -- the\n",
"        # original tested the row itself, so batching never triggered and\n",
"        # the whole partition went to TF as a single batch.  Renamed the\n",
"        # local from 'batch' to 'chunk' so it no longer shadows this function.\n",
"        if len(chunk) == n:\n",
"            yield chunk\n",
"            chunk = []\n",
"    if chunk:\n",
"        yield chunk\n",
"\n",
"def embed_partition(rows, col):\n",
"    \"\"\"Embed the `col` field of each row with ELMo on an executor.\n",
"\n",
"    Yields (sentence, pyspark Vector) pairs, 100 sentences per TF run.\n",
"    \"\"\"\n",
"    import tensorflow as tf\n",
"    import tensorflow_hub as hub\n",
"    # executors read the module from the tf_hub.zip archive unpacked into cwd\n",
"    os.environ['TFHUB_CACHE_DIR'] = os.path.join(os.getcwd(), 'tf_hub')\n",
"    sentences = [row[col] for row in rows]\n",
"    with tf.Session() as session:\n",
"        embed = hub.Module(\"https://tfhub.dev/google/elmo/2\")\n",
"        session.run(tf.global_variables_initializer())\n",
"        for batch_sentences in batch(sentences, 100):\n",
"            embeddings = session.run(embed(batch_sentences))\n",
"            for sentence, embedding in zip(batch_sentences, embeddings):\n",
"                yield sentence, Vectors.dense(embedding)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Read some sentences and run them\n",
"==="
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from functools import partial\n",
"\n",
"# Pull a tiny sample of enwiki headings and embed them on the executors\n",
"rdd_embedding = (\n",
" spark.table('ebernhardson.cirrus2hive')\n",
" .where(F.col('wikiid') == 'enwiki')\n",
" .where(F.col('dump_date') == '20190121')\n",
" # tiny fraction so the demo finishes quickly\n",
" .sample(withReplacement=False, fraction=0.00001)\n",
" # 'heading' is exploded to one row per heading string\n",
" .select(F.explode(F.col('heading')).alias('heading'))\n",
" .rdd.mapPartitions(partial(embed_partition, col='heading'))\n",
")\n",
"\n",
"df_embedding = (\n",
" spark.createDataFrame(rdd_embedding, T.StructType([\n",
" T.StructField('heading', T.StringType(), nullable=False),\n",
" T.StructField('embedding', VectorUDT(), nullable=False),\n",
" ]))\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Trigger the whole pipeline (Spark is lazy up to this point) and print a sample\n",
"df_embedding.show()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment