Skip to content

Instantly share code, notes, and snippets.

@austinzh
Last active August 5, 2021 01:18
Show Gist options
  • Save austinzh/361f7758b56dfa41b538c9b0c16feeb5 to your computer and use it in GitHub Desktop.
Save austinzh/361f7758b56dfa41b538c9b0c16feeb5 to your computer and use it in GitHub Desktop.
shared_embedding.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "shared_embedding.ipynb",
"provenance": [],
"collapsed_sections": [],
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/austinzh/361f7758b56dfa41b538c9b0c16feeb5/untitled86.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"metadata": {
"id": "E1e4lKDUo1nA"
},
"source": [
"import tensorflow as tf\n"
],
"execution_count": 1,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "-cAa9Yd_xSgd"
},
"source": [
""
],
"execution_count": 1,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "i1tg7dfRUaKw",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "1979f1be-f1e3-42d3-c503-a51bc82692a9"
},
"source": [
"import tensorflow as tf\n",
"from tensorflow import feature_column\n",
"import numpy as np\n",
"\n",
"\n",
"def _initialized_session(config=None):\n",
"    \"\"\"Create a tf.compat.v1.Session on the current default graph, with initializers run.\n",
"\n",
"    Runs the global-variable initializer and the table initializer (the\n",
"    vocabulary-list feature column below is backed by a lookup table). The\n",
"    caller owns the session and should close it, e.g. by using it in a `with` block.\n",
"\n",
"    Args:\n",
"        config: optional tf.compat.v1.ConfigProto passed through to the Session.\n",
"    \"\"\"\n",
"    sess = tf.compat.v1.Session(config=config)\n",
"    sess.run(tf.compat.v1.initializers.global_variables())\n",
"    sess.run(tf.compat.v1.tables_initializer())\n",
"    return sess\n",
"\n",
"\n",
"# NOTE(review): the model is built inside an explicit TF1-style graph, presumably\n",
"# because shared_embeddings is not supported under eager execution -- TODO confirm.\n",
"with tf.Graph().as_default():\n",
"    # Define categorical column for our text feature, which is preprocessed into a sequence of tokens.\n",
"    text_column = feature_column.sequence_categorical_column_with_vocabulary_list(\n",
"        key='text', vocabulary_list=['asd', 'asdf'])\n",
"\n",
"    max_length = 6\n",
"    sequence_feature_layer_inputs = {\n",
"        'text': tf.keras.Input(shape=(max_length,), name='text', dtype=tf.string),\n",
"    }\n",
"\n",
"    # shared_embeddings is the variant under test; the plain embedding_column\n",
"    # alternative below is noted as saving fine.\n",
"    text_embedding = feature_column.shared_embeddings([text_column], dimension=64)\n",
"\n",
"    # below is ok to save\n",
"    # text_embedding = feature_column.embedding_column(text_column, dimension=8)\n",
"\n",
"    # Define SequenceFeatures layer to pass feature_columns into Keras model\n",
"    sequence_feature_layer = tf.keras.experimental.SequenceFeatures(text_embedding)\n",
"\n",
"    # SequenceFeatures produces a tuple of two tensors; only the first (the embedded\n",
"    # sequence) is passed on to the rest of the network.\n",
"    sequence_feature_layer_outputs, _ = sequence_feature_layer(\n",
"        sequence_feature_layer_inputs)\n",
"    x = tf.keras.layers.Conv1D(8, 4)(sequence_feature_layer_outputs)\n",
"    x = tf.keras.layers.MaxPooling1D(2)(x)\n",
"    x = tf.keras.layers.Dense(256, activation='relu')(x)\n",
"    x = tf.keras.layers.Dropout(0.2)(x)\n",
"    x = tf.keras.layers.GlobalAveragePooling1D()(x)\n",
"    # This example supposes binary classification, as labels are 0 or 1\n",
"    x = tf.keras.layers.Dense(1, activation='sigmoid')(x)\n",
"\n",
"    model = tf.keras.models.Model(\n",
"        inputs=list(sequence_feature_layer_inputs.values()), outputs=x)\n",
"\n",
"    model.summary()\n",
"\n",
"    model.compile(optimizer='adam',\n",
"                  loss='binary_crossentropy',\n",
"                  metrics=['accuracy']\n",
"                  # run_eagerly=True\n",
"                  )\n",
"\n",
"    # One dummy sequence of max_length tokens, shaped to match the Input above.\n",
"    # NOTE(review): 'hello' is not in vocabulary_list and no labels are passed to\n",
"    # fit -- confirm this is intended for the reproduction.\n",
"    sample_features = {'text': np.array([[\"hello\"] * max_length])}\n",
"    with _initialized_session():\n",
"        model.fit(sample_features)\n",
"        model.save(\"models\")"
],
"execution_count": 2,
"outputs": [
{
"output_type": "stream",
"text": [
"Model: \"model\"\n",
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
"text (InputLayer) [(None, 6)] 0 \n",
"_________________________________________________________________\n",
"sequence_features (SequenceF ((None, None, 64), (None, 0 \n",
"_________________________________________________________________\n",
"conv1d (Conv1D) (None, None, 8) 2056 \n",
"_________________________________________________________________\n",
"max_pooling1d (MaxPooling1D) (None, None, 8) 0 \n",
"_________________________________________________________________\n",
"dense (Dense) (None, None, 256) 2304 \n",
"_________________________________________________________________\n",
"dropout (Dropout) (None, None, 256) 0 \n",
"_________________________________________________________________\n",
"global_average_pooling1d (Gl (None, 256) 0 \n",
"_________________________________________________________________\n",
"dense_1 (Dense) (None, 1) 257 \n",
"=================================================================\n",
"Total params: 4,617\n",
"Trainable params: 4,617\n",
"Non-trainable params: 0\n",
"_________________________________________________________________\n",
"Train on 1 samples\n",
"1/1 [==============================] - 0s 157ms/sample - loss: 0.6931 - accuracy: 0.0000e+00\n",
"INFO:tensorflow:Assets written to: models/assets\n"
],
"name": "stdout"
}
]
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment