Created August 13, 2019 22:49
-
-
Save rrmistry/a6095f530b998b2af3cf95ec0be0643c to your computer and use it in GitHub Desktop.
Complex model weights with TensorFlow 2.0 and Keras
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Install exactly ONE TensorFlow build. The CPU and GPU nightlies both provide the\n", | |
"# `tensorflow` module, so installing both just leaves whichever was installed last.\n", | |
"# On a machine without a GPU, install tf-nightly-2-0-preview==2.0.0.dev20190812 instead.\n", | |
"# %pip (unlike !pip) installs into the environment of the running kernel.\n", | |
"%pip install tf-nightly-gpu-2-0-preview==2.0.0.dev20190812" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import numpy as np\n", | |
"import tensorflow as tf\n", | |
"from tensorflow.keras.layers import Dense, Input, LSTM" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Every input and target batch has shape (batch, time steps, features).\n", | |
"BATCH_SIZE = 5\n", | |
"TIME_STEP_SIZE = 50\n", | |
"FEATURE_SIZE = 1024\n", | |
"BATCH_SHAPE = (BATCH_SIZE, TIME_STEP_SIZE, FEATURE_SIZE)\n", | |
"\n", | |
"# Seed NumPy (random batches) and TensorFlow (weight initialisation) so that\n", | |
"# Restart & Run All reproduces the same run.\n", | |
"SEED = 42\n", | |
"np.random.seed(SEED)\n", | |
"tf.random.set_seed(SEED)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# NOTE(review): Keras' default weight initializers and the 'relu' activation may not\n", | |
"# accept complex64; if building or training fails here, that may be exactly what this\n", | |
"# notebook demonstrates. Confirm before changing the layers.\n", | |
"model = tf.keras.Sequential([\n", | |
"    # Reuse BATCH_SHAPE so the model input always matches the generated batches.\n", | |
"    Input(batch_shape=BATCH_SHAPE, dtype=tf.complex64),\n", | |
"    LSTM(units=FEATURE_SIZE, dtype=tf.complex64, return_sequences=True),\n", | |
"    Dense(units=FEATURE_SIZE, dtype=tf.complex64, activation='relu'),\n", | |
"    LSTM(units=FEATURE_SIZE, dtype=tf.complex64, return_sequences=True),\n", | |
"    Dense(units=FEATURE_SIZE, dtype=tf.complex64, activation='relu'),\n", | |
"])\n", | |
"model.compile(optimizer='adam', loss='MAPE')\n", | |
"model.summary(line_length=75)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def _random_complex_batch():\n", | |
"    \"\"\"Return a complex64 tensor of shape BATCH_SHAPE with U[0, 1) real and imaginary parts.\"\"\"\n", | |
"    # np.random.rand returns float64 and tf.complex(float64, float64) is complex128,\n", | |
"    # which would not match the model's complex64 Input; cast to float32 first.\n", | |
"    return tf.complex(\n", | |
"        real=np.random.rand(*BATCH_SHAPE).astype(np.float32),\n", | |
"        imag=np.random.rand(*BATCH_SHAPE).astype(np.float32),\n", | |
"    )\n", | |
"\n", | |
"\n", | |
"def batch_generator():\n", | |
"    \"\"\"Yield a single random (inputs, targets) batch, both complex64 of shape BATCH_SHAPE.\"\"\"\n", | |
"    x_batch = _random_complex_batch()\n", | |
"    y_batch = _random_complex_batch()\n", | |
"    yield x_batch, y_batch" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# batch_generator yields exactly one batch, so train for exactly one step.\n", | |
"model.fit_generator(generator=batch_generator(), steps_per_epoch=1)" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.5" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment