Last active
January 5, 2024 09:54
-
-
Save deeperunderstanding/470795e5492b6f7ed0c14aac0455b8fc to your computer and use it in GitHub Desktop.
Final_Public.ipynb
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "colab_type": "text", | |
| "id": "HlGfk2GNUzye" | |
| }, | |
| "source": [ | |
| "# Data Exploration with Adversarial Autoencoders" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "colab": {}, | |
| "colab_type": "code", | |
| "id": "GWtC9DxiAfJY" | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "# Install additional dependencies if needed.\n", | |
| "# tensorflow, keras, numpy, pandas, matplotlib and scikit-learn are assumed to be installed already.\n", | |
| "# %pip (instead of !pip) installs into the environment of the running kernel.\n", | |
| "%pip install requests\n", | |
| "%pip install tables\n", | |
| "%pip install arrow\n", | |
| "%pip install ta\n", | |
| "%pip install tqdm" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/", | |
| "height": 35 | |
| }, | |
| "colab_type": "code", | |
| "id": "YJDai8iT07ZZ", | |
| "outputId": "df4e5d91-c538-4c8a-f7fe-4a9eba260d8a" | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "import ta\n", | |
| "import matplotlib.pyplot as plt\n", | |
| "import pandas as pd\n", | |
| "import numpy as np\n", | |
| "import arrow as time\n", | |
| "import requests as http\n", | |
| "from tqdm import tqdm\n", | |
| "\n", | |
| "from keras.layers.advanced_activations import PReLU\n", | |
| "from keras.optimizers import Nadam\n", | |
| "from keras.models import Model\n", | |
| "from keras.layers import *" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "colab_type": "text", | |
| "id": "qd8ez9f-OiLN" | |
| }, | |
| "source": [ | |
| "## Loading the Data from Oanda" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "colab": {}, | |
| "colab_type": "code", | |
| "id": "syqusdQ_TZTu" | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "class InstrumentLoader:\n", | |
| "    \"\"\"Download Oanda candles one day at a time and cache each day in a local HDF5 store.\n", | |
| "\n", | |
| "    `config` must provide the value of the HTTP Authorization header under the key 'token'.\n", | |
| "    The loader keeps 'oanda_api_store.h5' open; call `close()` or use the loader as a\n", | |
| "    context manager to release the file handle.\n", | |
| "    \"\"\"\n", | |
| "\n", | |
| "    # Seconds to wait for the API before giving up instead of hanging indefinitely.\n", | |
| "    REQUEST_TIMEOUT = 30\n", | |
| "\n", | |
| "    def __init__(self, config):\n", | |
| "        self.config = config\n", | |
| "        self.store = pd.HDFStore('oanda_api_store.h5')\n", | |
| "        self.time_format = 'YYYY-MM-DDTHH:mm:ss.SSSSSSSSZ'\n", | |
| "        self.inst_base_url = 'https://api-fxpractice.oanda.com/v3/instruments/'\n", | |
| "\n", | |
| "    def close(self):\n", | |
| "        \"\"\"Close the underlying HDF5 store.\"\"\"\n", | |
| "        self.store.close()\n", | |
| "\n", | |
| "    def __enter__(self):\n", | |
| "        return self\n", | |
| "\n", | |
| "    def __exit__(self, exc_type, exc_value, traceback):\n", | |
| "        self.close()\n", | |
| "\n", | |
| "    def load_period(self, instrument, granularity, start, end):\n", | |
| "        \"\"\"Return a list with one candle DataFrame per day between `start` and `end`.\n", | |
| "\n", | |
| "        Days whose arrow 'd' (day of week) token formats to '6' are skipped.\n", | |
| "        \"\"\"\n", | |
| "        trading_days = [\n", | |
| "            day for day in time.Arrow.range('day', start, end)\n", | |
| "            if day.format('d') != '6'\n", | |
| "        ]\n", | |
| "        return [\n", | |
| "            self.load_into_df(day, instrument, granularity)\n", | |
| "            for day in tqdm(trading_days, desc=\"downloading data\")\n", | |
| "        ]\n", | |
| "\n", | |
| "    def load_into_df(self, day, instrument, granularity):\n", | |
| "        \"\"\"Return the candles of `day` as a DataFrame indexed by 'time_of_day' (HHmmss as int).\n", | |
| "\n", | |
| "        Columns are open, close, high, low (mid prices) and volume. The frame is read from\n", | |
| "        the HDF5 store when its key (instrument + granularity + YYYYMMDD) is present;\n", | |
| "        otherwise it is downloaded via `load_day` and written to the store first.\n", | |
| "        \"\"\"\n", | |
| "        day_key = instrument + granularity + day.format('YYYYMMDD')\n", | |
| "        if day_key in self.store:\n", | |
| "            return self.store[day_key]\n", | |
| "\n", | |
| "        candles = self.load_day(day, instrument, granularity, \"M\")\n", | |
| "\n", | |
| "        def column(name, convert):\n", | |
| "            # One named Series per candle attribute, in candle order.\n", | |
| "            return pd.Series([convert(candle) for candle in candles]).rename(name)\n", | |
| "\n", | |
| "        def mid_price(field):\n", | |
| "            return lambda candle: float(candle['mid'][field])\n", | |
| "\n", | |
| "        def time_of_day(candle):\n", | |
| "            return int(time.get(candle['time'], self.time_format).format('HHmmss'))\n", | |
| "\n", | |
| "        signals = [\n", | |
| "            column('time_of_day', time_of_day),\n", | |
| "            column('open', mid_price('o')),\n", | |
| "            column('close', mid_price('c')),\n", | |
| "            column('high', mid_price('h')),\n", | |
| "            column('low', mid_price('l')),\n", | |
| "            column('volume', lambda candle: int(candle['volume'])),\n", | |
| "        ]\n", | |
| "\n", | |
| "        raw_day = pd.concat(signals, axis=1).set_index('time_of_day')\n", | |
| "        self.store[day_key] = raw_day\n", | |
| "        return raw_day\n", | |
| "\n", | |
| "    def load_day(self, day, instrument, granularity, price):\n", | |
| "        \"\"\"Request the candles from `day` to `day` + 1 day and return the 'candles' list.\n", | |
| "\n", | |
| "        Raises:\n", | |
| "            KeyError: if the response body has no 'candles' entry; the message carries\n", | |
| "                the HTTP status code and the decoded body to make the failure diagnosable.\n", | |
| "        \"\"\"\n", | |
| "        time_format = 'YYYY-MM-DDTHH:mm:ssZ'\n", | |
| "        base_uri = self.inst_base_url + instrument + \"/candles\"\n", | |
| "        headers = {\"Authorization\": self.config['token']}\n", | |
| "        parameters = {\n", | |
| "            \"from\": day.format(time_format),\n", | |
| "            \"to\": day.shift(days=1).format(time_format),\n", | |
| "            \"price\": price,\n", | |
| "            \"granularity\": granularity,\n", | |
| "            \"includeFirst\": \"True\",\n", | |
| "        }\n", | |
| "        response = http.get(\n", | |
| "            base_uri, params=parameters, headers=headers,\n", | |
| "            timeout=self.REQUEST_TIMEOUT)\n", | |
| "        payload = response.json()\n", | |
| "        if 'candles' not in payload:\n", | |
| "            raise KeyError(\n", | |
| "                'candles missing from Oanda response (HTTP %s): %s'\n", | |
| "                % (response.status_code, payload))\n", | |
| "        return payload['candles']\n", | |
| "\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/", | |
| "height": 35 | |
| }, | |
| "colab_type": "code", | |
| "id": "W3UfSM4vJ9DL", | |
| "outputId": "67db532c-d6fe-4f53-ca3a-e54e4e099d1b" | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "config = {\n", | |
| "    'token': \"Bearer <YOUR API TOKEN HERE!>\",\n", | |
| "}\n", | |
| "\n", | |
| "# Parse the ISO-8601 timestamps without an explicit format string: the previous format\n", | |
| "# 'YYYY-MM-DDTHH:mm:ss' left the trailing 'Z' unmatched, which strict arrow parsers reject.\n", | |
| "start = time.get(\"2018-01-02T00:00:00Z\")\n", | |
| "end = time.get(\"2018-12-30T00:00:00Z\")\n", | |
| "api = InstrumentLoader(config)\n", | |
| "\n", | |
| "# One DataFrame of 5-minute EUR/USD candles per day in the period.\n", | |
| "days = api.load_period('EUR_USD', 'M5', start, end)\n", | |
| "\n", | |
| "test = days[0]" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "colab_type": "text", | |
| "id": "u3yI_Ul9FTp2" | |
| }, | |
| "source": [ | |
| "## Enhancing and preprocessing data" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "colab": {}, | |
| "colab_type": "code", | |
| "id": "pIiz7eUgY9o9" | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "def roll_stats(data, window=32):\n", | |
| "    \"\"\"Rolling mean, variance and skewness of a named Series.\n", | |
| "\n", | |
| "    Returns a DataFrame with the columns '<name>_mean', '<name>_var' and '<name>_skew',\n", | |
| "    where <name> is `data.name`.\n", | |
| "    \"\"\"\n", | |
| "    prefix = data.name\n", | |
| "    columns = [\n", | |
| "        data.rolling(window).mean().rename(prefix + '_mean'),\n", | |
| "        data.rolling(window).var().rename(prefix + '_var'),\n", | |
| "        data.rolling(window).skew().rename(prefix + '_skew'),\n", | |
| "    ]\n", | |
| "    return pd.concat(columns, axis=1)\n", | |
| "\n", | |
| "def log_returns(frame):\n", | |
| "    \"\"\"Difference between the natural log of each row and that of the previous row.\"\"\"\n", | |
| "    return np.log(frame).diff()\n", | |
| "\n", | |
| "def enhance_day(day):\n", | |
| "    \"\"\"Build the feature frame for one day of candles.\n", | |
| "\n", | |
| "    Features: log returns of close/high/low, rolling statistics of the close returns,\n", | |
| "    raw volume, RSI, ATR and the MACD signal line. Rows containing NaN are dropped.\n", | |
| "    \"\"\"\n", | |
| "    high, low, close = day['high'], day['low'], day['close']\n", | |
| "    close_returns = log_returns(close)\n", | |
| "\n", | |
| "    # NOTE(review): the n / n_fast / n_slow / n_sign keywords depend on the installed `ta`\n", | |
| "    # release; newer releases appear to name them window* - confirm against your version.\n", | |
| "    features = [\n", | |
| "        close_returns,\n", | |
| "        log_returns(high),\n", | |
| "        log_returns(low),\n", | |
| "        roll_stats(close_returns),\n", | |
| "        day['volume'],\n", | |
| "        ta.momentum.rsi(close, n=13, fillna=False).rename('rsi'),\n", | |
| "        ta.volatility.average_true_range(high, low, close, n=13, fillna=False).rename('atr'),\n", | |
| "        ta.trend.macd_signal(close, n_fast=13, n_slow=35, n_sign=8, fillna=False).rename('macd_signal'),\n", | |
| "    ]\n", | |
| "    return pd.concat(features, axis=1).dropna()\n", | |
| "\n", | |
| "\n", | |
| "enhanced_days = [enhance_day(day) for day in days]" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "colab": { | |
| "base_uri": "https://localhost:8080/", | |
| "height": 35 | |
| }, | |
| "colab_type": "code", | |
| "id": "rJNsTW7RZJIj", | |
| "outputId": "a3106c35-ded4-4631-dae4-bfaef9338a43" | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "def make_windows(signal, window_size, step_size):\n", | |
| "    \"\"\"Cut `signal` into overlapping windows and z-score every window column-wise.\n", | |
| "\n", | |
| "    Returns a pair (raw_windows, normalised_windows): the first holds the DataFrame\n", | |
| "    slices, the second the matching normalised value arrays. Signals shorter than\n", | |
| "    `window_size` yield two empty lists.\n", | |
| "    \"\"\"\n", | |
| "    raw_windows = []\n", | |
| "    normalised_windows = []\n", | |
| "    # '+ 1' so that the window ending exactly on the last row is not dropped.\n", | |
| "    for start in range(0, len(signal) - window_size + 1, step_size):\n", | |
| "        window_x = signal.iloc[start:start + window_size]\n", | |
| "        # A column that is constant inside the window has zero std; dividing by 1\n", | |
| "        # keeps it at 0 instead of producing NaN that would poison training.\n", | |
| "        window_std = window_x.std().replace(0, 1)\n", | |
| "        raw_windows.append(window_x)\n", | |
| "        normalised_windows.append(((window_x - window_x.mean()) / window_std).values)\n", | |
| "    return raw_windows, normalised_windows\n", | |
| "\n", | |
| "\n", | |
| "window_size = 32\n", | |
| "step_size = 16\n", | |
| "\n", | |
| "# Global statistics; computed for inspection, normalisation itself is per window.\n", | |
| "all_days = pd.concat(enhanced_days)\n", | |
| "all_mean = all_days.mean()\n", | |
| "all_std = all_days.std()\n", | |
| "\n", | |
| "windowed_signals = []\n", | |
| "normalised_windows = []\n", | |
| "\n", | |
| "for signal in enhanced_days:\n", | |
| "    raw, normalised = make_windows(signal, window_size, step_size)\n", | |
| "    windowed_signals.extend(raw)\n", | |
| "    normalised_windows.extend(normalised)\n", | |
| "\n", | |
| "data_x = np.array(normalised_windows)\n", | |
| "\n", | |
| "# Chronological 70/30 split.\n", | |
| "split_train = int(len(data_x) * 0.7)\n", | |
| "\n", | |
| "train_x = data_x[:split_train]\n", | |
| "originals_x = windowed_signals[:split_train]\n", | |
| "\n", | |
| "test_x = data_x[split_train:]\n", | |
| "\n", | |
| "print(train_x.shape)\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "colab": {}, | |
| "colab_type": "code", | |
| "id": "GtLo3FmgeYGf" | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "def sample_normal(latent_dim, batch_size, window_size=None):\n", | |
| "    \"\"\"Draw standard-normal noise.\n", | |
| "\n", | |
| "    The result has shape (batch_size, latent_dim), or\n", | |
| "    (batch_size, window_size, latent_dim) when `window_size` is given.\n", | |
| "    \"\"\"\n", | |
| "    if window_size is None:\n", | |
| "        shape = (batch_size, latent_dim)\n", | |
| "    else:\n", | |
| "        shape = (batch_size, window_size, latent_dim)\n", | |
| "    return np.random.normal(size=shape)\n", | |
| " \n", | |
| "def sample_categories(cat_dim, batch_size):\n", | |
| "    \"\"\"Return a (batch_size, cat_dim) array whose rows are random one-hot vectors.\"\"\"\n", | |
| "    cats = np.zeros((batch_size, cat_dim))\n", | |
| "    for row in cats:\n", | |
| "        # each row is a view into `cats`, so this sets the hot entry in place\n", | |
| "        row[np.random.randint(0, cat_dim)] = 1\n", | |
| "    return cats\n", | |
| " \n", | |
| "def create_encoder(latent_dim, cat_dim, window_size, input_dim):\n", | |
| " \"\"\"Build the autoencoder network (encoder and decoder in a single Keras Model).\n", | |
| "\n", | |
| " The input has shape (window_size, input_dim). The returned Model has four outputs:\n", | |
| " the reconstruction, the latent_dim-wide linear code, the cat_dim-wide softmax\n", | |
| " category vector, and the reconstruction error (input minus reconstruction).\n", | |
| " \"\"\"\n", | |
| " input_layer = Input(shape=(window_size, input_dim))\n", | |
| " \n", | |
| " # Encoder trunk: per-timestep Dense, then two bidirectional LSTMs; the second one is\n", | |
| " # created without return_sequences.\n", | |
| " code = TimeDistributed(Dense(64, activation='linear'))(input_layer)\n", | |
| " code = Bidirectional(LSTM(128, return_sequences=True))(code)\n", | |
| " code = BatchNormalization()(code)\n", | |
| " code = ELU()(code)\n", | |
| " code = Bidirectional(LSTM(64))(code)\n", | |
| " code = BatchNormalization()(code)\n", | |
| " code = ELU()(code)\n", | |
| " \n", | |
| " # Categorical head: softmax over cat_dim classes.\n", | |
| " cat = Dense(64)(code)\n", | |
| " cat = BatchNormalization()(cat)\n", | |
| " cat = PReLU()(cat)\n", | |
| " cat = Dense(cat_dim, activation='softmax')(cat)\n", | |
| " \n", | |
| " # Continuous head: linear code of width latent_dim.\n", | |
| " latent_repr = Dense(64)(code)\n", | |
| " latent_repr = BatchNormalization()(latent_repr)\n", | |
| " latent_repr = PReLU()(latent_repr)\n", | |
| " latent_repr = Dense(latent_dim, activation='linear')(latent_repr)\n", | |
| " \n", | |
| " # Decoder: both heads are concatenated, repeated for every timestep and mapped back\n", | |
| " # to input_dim features per timestep.\n", | |
| " decode = Concatenate()([latent_repr, cat])\n", | |
| " decode = RepeatVector(window_size)(decode)\n", | |
| " decode = Bidirectional(LSTM(64, return_sequences=True))(decode)\n", | |
| " decode = ELU()(decode)\n", | |
| " decode = Bidirectional(LSTM(128, return_sequences=True))(decode)\n", | |
| " decode = ELU()(decode)\n", | |
| " decode = TimeDistributed(Dense(64))(decode)\n", | |
| " decode = ELU()(decode)\n", | |
| " decode = TimeDistributed(Dense(input_dim, activation='linear'))(decode)\n", | |
| " \n", | |
| " # Exposed as an extra output so the reconstruction error can be inspected directly.\n", | |
| " error = Subtract()([input_layer, decode])\n", | |
| " \n", | |
| " return Model(input_layer, [decode, latent_repr, cat, error])\n", | |
| "\n", | |
| " \n", | |
| "def create_discriminator(latent_dim):\n", | |
| "    \"\"\"Build a small MLP that maps a latent_dim-wide vector to one sigmoid score.\"\"\"\n", | |
| "    input_layer = Input(shape=(latent_dim,))\n", | |
| "\n", | |
| "    disc = input_layer\n", | |
| "    for units in (128, 64):\n", | |
| "        disc = Dense(units)(disc)\n", | |
| "        disc = ELU()(disc)  # alternative kept from the original: LeakyReLU(alpha=0.2)\n", | |
| "    disc = Dense(1, activation=\"sigmoid\")(disc)\n", | |
| "\n", | |
| "    return Model(input_layer, disc)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "colab_type": "text", | |
| "id": "xYdQKItBhcbG" | |
| }, | |
| "source": [ | |
| "## Encoding Data to find Clusters\n", | |
| "\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "colab": {}, | |
| "colab_type": "code", | |
| "id": "bbD3wNhkhYaU" | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "window_size = train_x.shape[1]\n", | |
| "input_dim = train_x.shape[2]\n", | |
| "latent_dim = 32\n", | |
| "cat_dim = 8\n", | |
| "\n", | |
| "\n", | |
| "def build_frozen_discriminator(dim):\n", | |
| "    \"\"\"Create and compile a discriminator for `dim`-wide inputs, then set trainable to False.\n", | |
| "\n", | |
| "    The flag is flipped after compile(), in the same order as the stand-alone statements\n", | |
| "    this helper replaces.\n", | |
| "    \"\"\"\n", | |
| "    discriminator = create_discriminator(dim)\n", | |
| "    discriminator.compile(loss='binary_crossentropy',\n", | |
| "                          optimizer=Nadam(0.0002, 0.5),\n", | |
| "                          metrics=['accuracy'])\n", | |
| "    discriminator.trainable = False\n", | |
| "    return discriminator\n", | |
| "\n", | |
| "\n", | |
| "prior_discriminator = build_frozen_discriminator(latent_dim)\n", | |
| "cat_discriminator = build_frozen_discriminator(cat_dim)\n", | |
| "\n", | |
| "encoder = create_encoder(latent_dim, cat_dim, window_size, input_dim)\n", | |
| "\n", | |
| "# Combined model: signal -> (reconstruction, prior verdict, category verdict).\n", | |
| "signal_in = Input(shape=(window_size, input_dim))\n", | |
| "reconstructed_signal, encoded_repr, category, _ = encoder(signal_in)\n", | |
| "\n", | |
| "is_real_prior = prior_discriminator(encoded_repr)\n", | |
| "is_real_cat = cat_discriminator(category)\n", | |
| "\n", | |
| "autoencoder = Model(signal_in, [reconstructed_signal, is_real_prior, is_real_cat])\n", | |
| "autoencoder.compile(loss=['mse', 'binary_crossentropy', 'binary_crossentropy'],\n", | |
| "                    loss_weights=[0.99, 0.005, 0.005],\n", | |
| "                    optimizer=Nadam(0.0002, 0.5))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "colab": {}, | |
| "colab_type": "code", | |
| "id": "ljwfe7b6TTig" | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "# Print the layer and parameter overview of the combined autoencoder model.\n", | |
| "autoencoder.summary()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "colab": {}, | |
| "colab_type": "code", | |
| "id": "O-qaYIKahpxU" | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "batches = 10000\n", | |
| "batch_size = 64\n", | |
| "\n", | |
| "losses_disc = []\n", | |
| "losses_disc_cat = []\n", | |
| "losses_ae = []\n", | |
| "losses_val = []\n", | |
| "\n", | |
| "# Discriminator targets: 1 for samples drawn from the prior, 0 for encoder outputs.\n", | |
| "real = np.ones((batch_size, 1))\n", | |
| "fake = np.zeros((batch_size, 1))\n", | |
| "\n", | |
| "def discriminator_training(discriminator, real, fake):\n", | |
| "    \"\"\"Return a train(real_samples, fake_samples) function bound to `discriminator`.\n", | |
| "\n", | |
| "    The returned function runs one train_on_batch call per sample set and returns\n", | |
| "    the element-wise average of the two results.\n", | |
| "    \"\"\"\n", | |
| "    def train(real_samples, fake_samples):\n", | |
| "        discriminator.trainable = True\n", | |
| "\n", | |
| "        loss_real = discriminator.train_on_batch(real_samples, real)\n", | |
| "        loss_fake = discriminator.train_on_batch(fake_samples, fake)\n", | |
| "        loss = np.add(loss_real, loss_fake) * 0.5\n", | |
| "\n", | |
| "        discriminator.trainable = False\n", | |
| "\n", | |
| "        return loss\n", | |
| "    return train\n", | |
| "\n", | |
| "train_prior_discriminator = discriminator_training(prior_discriminator, real, fake)\n", | |
| "train_cat_discriminator = discriminator_training(cat_discriminator, real, fake)\n", | |
| "\n", | |
| "pbar = tqdm(range(batches))\n", | |
| "\n", | |
| "for _ in pbar:\n", | |
| "\n", | |
| "    ids = np.random.randint(0, train_x.shape[0], batch_size)\n", | |
| "    signals = train_x[ids]\n", | |
| "\n", | |
| "    # Encoder outputs act as the 'fake' samples for both discriminators.\n", | |
| "    _, latent_fake, category_fake, _ = encoder.predict(signals)\n", | |
| "\n", | |
| "    latent_real = sample_normal(latent_dim, batch_size)\n", | |
| "    category_real = sample_categories(cat_dim, batch_size)\n", | |
| "\n", | |
| "    prior_loss = train_prior_discriminator(latent_real, latent_fake)\n", | |
| "    cat_loss = train_cat_discriminator(category_real, category_fake)\n", | |
| "\n", | |
| "    losses_disc.append(prior_loss)\n", | |
| "    losses_disc_cat.append(cat_loss)\n", | |
| "\n", | |
| "    encoder_loss = autoencoder.train_on_batch(signals, [signals, real, real])\n", | |
| "    losses_ae.append(encoder_loss)\n", | |
| "\n", | |
| "    # Validate on held-out windows; evaluating the batch that was just trained on\n", | |
| "    # would only repeat the training loss.\n", | |
| "    val_ids = np.random.randint(0, test_x.shape[0], batch_size)\n", | |
| "    val_signals = test_x[val_ids]\n", | |
| "    val_loss = autoencoder.test_on_batch(val_signals, [val_signals, real, real])\n", | |
| "    losses_val.append(val_loss)\n", | |
| "\n", | |
| "    pbar.set_description(\"[Acc. Prior/Cat: %.2f%% / %.2f%%] [MSE train/val: %f / %f]\"\n", | |
| "        % (100*prior_loss[1], 100*cat_loss[1], encoder_loss[1], val_loss[1]))" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "colab": {}, | |
| "colab_type": "code", | |
| "id": "FTc-3lrmWY3_" | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "# Persist the weights of the combined model to 'aae.hdf'.\n", | |
| "autoencoder.save_weights('aae.hdf')" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "colab_type": "text", | |
| "id": "XAbmtQ5JlTXc" | |
| }, | |
| "source": [ | |
| "## Show Loss and Result Samples" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "colab": {}, | |
| "colab_type": "code", | |
| "id": "vgD0vTjqhwi9" | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "# Training curves: index 1 of every recorded entry is the value shown in the progress bar\n", | |
| "# (accuracy for the discriminators, MSE for the autoencoder).\n", | |
| "fig, axes = plt.subplots(nrows=1, ncols=3)\n", | |
| "fig.set_size_inches(30, 6)\n", | |
| "\n", | |
| "axes[0].plot([loss[1] for loss in losses_disc])\n", | |
| "axes[0].set_title('Prior discriminator accuracy')\n", | |
| "axes[1].plot([loss[1] for loss in losses_disc_cat])\n", | |
| "axes[1].set_title('Category discriminator accuracy')\n", | |
| "axes[2].plot([loss[1] for loss in losses_ae], label='train')\n", | |
| "axes[2].plot([loss[1] for loss in losses_val], label='val')\n", | |
| "axes[2].set_title('Reconstruction MSE')\n", | |
| "axes[2].legend()\n", | |
| "for ax in axes:\n", | |
| "    ax.set_xlabel('batch')\n", | |
| "\n", | |
| "# plt.show() renders with the notebook's inline backend; fig.show() needs a GUI backend.\n", | |
| "plt.show()\n", | |
| "\n", | |
| "size = 5\n", | |
| "offset = 5\n", | |
| "\n", | |
| "fig, axes = plt.subplots(nrows=size, ncols=5)\n", | |
| "fig.set_size_inches(20, 3 * size)\n", | |
| "\n", | |
| "(dec, rep, cat, error) = encoder.predict(test_x[offset:size+offset])\n", | |
| "\n", | |
| "for i in range(size):\n", | |
| "    axes[i,0].plot(test_x[i+offset])\n", | |
| "    # -1 lets the row count follow latent_dim (8 rows for latent_dim = 32).\n", | |
| "    axes[i,1].imshow(rep[i].reshape(-1, 4))\n", | |
| "    axes[i,2].imshow(cat[i].reshape(cat_dim, 1))\n", | |
| "    axes[i,3].plot(dec[i])\n", | |
| "    axes[i,4].plot(error[i])\n", | |
| "\n", | |
| "column_titles = ['input', 'latent code', 'category', 'reconstruction', 'error']\n", | |
| "for ax, title in zip(axes[0], column_titles):\n", | |
| "    ax.set_title(title)\n", | |
| "\n", | |
| "plt.show()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "colab_type": "text", | |
| "id": "67J6j5B9_C0f" | |
| }, | |
| "source": [ | |
| "## Inspect Categories" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "colab": {}, | |
| "colab_type": "code", | |
| "id": "w1bQ9LiwozJp" | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "# Run every window (train and test) through the network; the outputs are the reconstruction,\n", | |
| "# the latent code, the category vector and the reconstruction error.\n", | |
| "(dec, rep, cat, error) = encoder.predict(data_x)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "colab": {}, | |
| "colab_type": "code", | |
| "id": "vD4v_pDlo1dy" | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "from collections import Counter\n", | |
| "\n", | |
| "# Index of the strongest category for every window.\n", | |
| "categories = list(np.argmax(cat, axis=1))\n", | |
| "counts = Counter(categories)\n", | |
| "print(counts)\n", | |
| "\n", | |
| "# Split the (label, count) pairs into two parallel tuples for the bar chart.\n", | |
| "labels, count = zip(*counts.items())\n", | |
| "\n", | |
| "plt.figure(figsize=(14,7))\n", | |
| "plt.bar(labels, count)\n", | |
| "plt.show()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "colab_type": "text", | |
| "id": "uBYzN5-Xo5ar" | |
| }, | |
| "source": [ | |
| "## Latent Distribution" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "colab": {}, | |
| "colab_type": "code", | |
| "id": "1zdvcq-fo48T" | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "from sklearn.manifold import TSNE\n", | |
| "\n", | |
| "fig, axes = plt.subplots(nrows=1, ncols=3)\n", | |
| "fig.set_size_inches(30,7)\n", | |
| "\n", | |
| "# One 2-D embedding of the latent codes per perplexity, coloured by predicted category.\n", | |
| "for ax, perplexity in zip(axes, (10, 30, 50)):\n", | |
| "    X = TSNE(n_components=2, perplexity=perplexity).fit_transform(rep)\n", | |
| "    ax.scatter(X[:, 0], X[:, 1], c=categories, cmap='viridis')\n", | |
| "\n", | |
| "fig.show()" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "metadata": { | |
| "colab_type": "text", | |
| "id": "JC2LSJbzpCrT" | |
| }, | |
| "source": [ | |
| "## Categories in the Input Data\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": { | |
| "colab": {}, | |
| "colab_type": "code", | |
| "id": "Fuj8zjvxpGIN" | |
| }, | |
| "outputs": [], | |
| "source": [ | |
| "from sklearn.manifold import TSNE\n", | |
| "\n", | |
| "fig, axes = plt.subplots(nrows=1, ncols=3)\n", | |
| "fig.set_size_inches(30,7)\n", | |
| "\n", | |
| "# Flatten every window into a single feature vector.\n", | |
| "input_data = data_x.reshape(data_x.shape[0], data_x.shape[1] * data_x.shape[2])\n", | |
| "\n", | |
| "# One 2-D embedding of the raw windows per perplexity, coloured by predicted category.\n", | |
| "for ax, perplexity in zip(axes, (10, 30, 50)):\n", | |
| "    X = TSNE(n_components=2, perplexity=perplexity).fit_transform(input_data)\n", | |
| "    ax.scatter(X[:, 0], X[:, 1], c=categories, cmap='viridis')" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [] | |
| } | |
| ], | |
| "metadata": { | |
| "accelerator": "TPU", | |
| "colab": { | |
| "collapsed_sections": [], | |
| "name": "Final_Public.ipynb", | |
| "provenance": [], | |
| "version": "0.3.2" | |
| }, | |
| "kernelspec": { | |
| "display_name": "Python 3", | |
| "language": "python", | |
| "name": "python3" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.7.3" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 2 | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment