Instantly share code, notes, and snippets.
Last active
August 5, 2021 01:11
-
Star
0
(0)
You must be signed in to star a gist -
Fork
0
(0)
You must be signed in to fork a gist
-
Save austinzh/ad4c4508bb3793ad1815e6f845b7deaa to your computer and use it in GitHub Desktop.
shared_embedding.ipynb
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"name": "Untitled86.ipynb", | |
"provenance": [], | |
"collapsed_sections": [], | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"language_info": { | |
"name": "python" | |
} | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/austinzh/ad4c4508bb3793ad1815e6f845b7deaa/untitled86.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "E1e4lKDUo1nA" | |
}, | |
"source": [ | |
"import tensorflow as tf\n" | |
], | |
"execution_count": 1, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "-cAa9Yd_xSgd" | |
}, | |
"source": [ | |
"" | |
], | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "i1tg7dfRUaKw", | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"outputId": "fab308a5-d543-4aa9-c6d3-de3cf660aa6d" | |
}, | |
"source": [ | |
"import tensorflow as tf\n", | |
"from tensorflow import feature_column\n", | |
"import numpy as np\n", | |
"# Use TF1-style graph execution so the Session-based helper below applies.\n", | |
"tf.compat.v1.disable_eager_execution()\n", | |
"\n", | |
"def _initialized_session(config=None):\n", | |
"    \"\"\"Create a tf.compat.v1.Session and run the variable and table initializers.\n", | |
"\n", | |
"    Args:\n", | |
"        config: optional session config forwarded to tf.compat.v1.Session.\n", | |
"\n", | |
"    Returns:\n", | |
"        The initialized session. The caller owns it and must close it,\n", | |
"        e.g. by entering it with a with-statement.\n", | |
"    \"\"\"\n", | |
"    sess = tf.compat.v1.Session(config=config)\n", | |
"    try:\n", | |
"        sess.run(tf.compat.v1.initializers.global_variables())\n", | |
"        # Lookup tables (e.g. for vocabulary-list columns) have a separate initializer.\n", | |
"        sess.run(tf.compat.v1.tables_initializer())\n", | |
"    except Exception:\n", | |
"        # Close the session instead of leaking it when initialization fails.\n", | |
"        sess.close()\n", | |
"        raise\n", | |
"    return sess\n", | |
"\n", | |
"with tf.Graph().as_default():\n", | |
"    # Categorical column for the text feature, which is preprocessed into a sequence of tokens.\n", | |
"    text_column = feature_column.sequence_categorical_column_with_vocabulary_list(\n", | |
"        key='text', vocabulary_list=['asd', 'asdf'])\n", | |
"\n", | |
"    max_length = 6\n", | |
"    embedding_dim = 64\n", | |
"    sequence_feature_layer_inputs = {\n", | |
"        'text': tf.keras.Input(shape=(max_length,), name='text', dtype=tf.string),\n", | |
"    }\n", | |
"\n", | |
"    # shared_embeddings returns a list of embedding columns, one per categorical column.\n", | |
"    text_embedding = feature_column.shared_embeddings([text_column], dimension=embedding_dim)\n", | |
"\n", | |
"    # Non-shared alternative, noted by the author as OK to save:\n", | |
"    # text_embedding = feature_column.embedding_column(text_column, dimension=8)\n", | |
"\n", | |
"    # The SequenceFeatures layer feeds the feature columns into the Keras model.\n", | |
"    sequence_feature_layer = tf.keras.experimental.SequenceFeatures(text_embedding)\n", | |
"\n", | |
"    # SequenceFeatures outputs a tuple of two tensors; only the first (the embedded\n", | |
"    # sequence) is passed on to the next layer.\n", | |
"    sequence_feature_layer_outputs, _ = sequence_feature_layer(\n", | |
"        sequence_feature_layer_inputs)\n", | |
"    x = tf.keras.layers.Conv1D(8, 4)(sequence_feature_layer_outputs)\n", | |
"    x = tf.keras.layers.MaxPooling1D(2)(x)\n", | |
"    x = tf.keras.layers.Dense(256, activation='relu')(x)\n", | |
"    x = tf.keras.layers.Dropout(0.2)(x)\n", | |
"    x = tf.keras.layers.GlobalAveragePooling1D()(x)\n", | |
"    # Binary classification head: labels are 0 or 1.\n", | |
"    x = tf.keras.layers.Dense(1, activation='sigmoid')(x)\n", | |
"\n", | |
"    model = tf.keras.models.Model(\n", | |
"        inputs=list(sequence_feature_layer_inputs.values()), outputs=x)\n", | |
"\n", | |
"    model.summary()\n", | |
"\n", | |
"    model.compile(optimizer='adam',\n", | |
"                  loss='binary_crossentropy',\n", | |
"                  metrics=['accuracy'])\n", | |
"\n", | |
"    # One dummy sample: a full max_length sequence of tokens (all out of vocabulary)\n", | |
"    # plus an explicit 0/1 label, so fit() has targets for binary_crossentropy.\n", | |
"    train_x = {'text': np.array([['hello'] * max_length])}\n", | |
"    train_y = np.array([[1.0]])\n", | |
"\n", | |
"    with _initialized_session():\n", | |
"        model.fit(train_x, train_y)\n", | |
"        # Save inside the session used for training; once this with-block\n", | |
"        # exits, that session is closed.\n", | |
"        model.save(\"models\")" | |
], | |
"execution_count": 10, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Model: \"model\"\n", | |
"_________________________________________________________________\n", | |
"Layer (type) Output Shape Param # \n", | |
"=================================================================\n", | |
"text (InputLayer) [(None, 6)] 0 \n", | |
"_________________________________________________________________\n", | |
"sequence_features (SequenceF ((None, None, 64), (None, 0 \n", | |
"_________________________________________________________________\n", | |
"conv1d (Conv1D) (None, None, 8) 2056 \n", | |
"_________________________________________________________________\n", | |
"max_pooling1d (MaxPooling1D) (None, None, 8) 0 \n", | |
"_________________________________________________________________\n", | |
"dense (Dense) (None, None, 256) 2304 \n", | |
"_________________________________________________________________\n", | |
"dropout (Dropout) (None, None, 256) 0 \n", | |
"_________________________________________________________________\n", | |
"global_average_pooling1d (Gl (None, 256) 0 \n", | |
"_________________________________________________________________\n", | |
"dense_1 (Dense) (None, 1) 257 \n", | |
"=================================================================\n", | |
"Total params: 4,617\n", | |
"Trainable params: 4,617\n", | |
"Non-trainable params: 0\n", | |
"_________________________________________________________________\n", | |
"Train on 1 samples\n", | |
"1/1 [==============================] - 1s 627ms/sample - loss: 0.6931 - accuracy: 0.0000e+00\n", | |
"INFO:tensorflow:Assets written to: models/assets\n" | |
], | |
"name": "stdout" | |
} | |
] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment