Created
March 3, 2017 08:39
-
-
Save izmailovpavel/7f692aadb20753eb79cb3a9fe986afdd to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"import sys\n", | |
"\n", | |
"import numpy as np\n", | |
"import tensorflow as tf" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"import t3f\n", | |
"from t3f import TensorTrain\n", | |
"from t3f.ops import *\n", | |
"import t3f.kronecker as kr\n", | |
"%load_ext autoreload\n", | |
"%autoreload 2" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"# A and B are random TT-matrices with the same factorized shape; they differ\n", | |
"# only in the requested TT-rank (2 for A, 3 for B).\n", | |
"matrix_shape = ((2, 2, 2), (3, 3, 3))\n", | |
"A = t3f.get_variable(\n", | |
"    name='A',\n", | |
"    initializer=t3f.random_matrix(shape=matrix_shape, tt_rank=2)\n", | |
")\n", | |
"B = t3f.get_variable(\n", | |
"    name='B',\n", | |
"    initializer=t3f.random_matrix(shape=matrix_shape, tt_rank=3)\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"# Build C = A + B directly in TT-format: every core of C is assembled from the\n", | |
"# matching cores of A and B, so the inner TT-ranks of C are rank(A) + rank(B).\n", | |
"# NOTE(review): tf.concat(concat_dim, values) is the pre-1.0 TensorFlow argument\n", | |
"# order; TensorFlow >= 1.0 expects tf.concat(values, axis) -- swap the arguments\n", | |
"# when running on a newer release.\n", | |
"i_shapes, j_shapes = A.get_raw_shape()\n", | |
"num_dims = A.ndims()\n", | |
"# Query the TT-ranks once instead of four times per loop iteration.\n", | |
"A_ranks = A.get_tt_ranks()\n", | |
"B_ranks = B.get_tt_ranks()\n", | |
"C_cores = []\n", | |
"for core_idx in range(num_dims):\n", | |
"    core_A = A.tt_cores[core_idx]\n", | |
"    core_B = B.tt_cores[core_idx]\n", | |
"    A_r, A_r_next = A_ranks[core_idx], A_ranks[core_idx + 1]\n", | |
"    B_r, B_r_next = B_ranks[core_idx], B_ranks[core_idx + 1]\n", | |
"\n", | |
"    if num_dims == 1:\n", | |
"        # A one-core train has both boundary ranks equal to 1, so concatenating\n", | |
"        # along a rank axis would give a wrong core; the sum is element-wise.\n", | |
"        core_C = core_A + core_B\n", | |
"    elif core_idx == 0:\n", | |
"        # First core (left rank 1): stack A and B along the right-rank axis.\n", | |
"        core_C = tf.concat(3, [core_A, core_B])\n", | |
"    elif core_idx == num_dims - 1:\n", | |
"        # Last core (right rank 1): stack A and B along the left-rank axis.\n", | |
"        core_C = tf.concat(0, [core_A, core_B])\n", | |
"    else:\n", | |
"        # Inner core: block-diagonal layout [[core_A, 0], [0, core_B]] over the\n", | |
"        # two rank axes. The zero padding takes the dtype of the cores, because\n", | |
"        # tf.zeros defaults to float32 and tf.concat needs matching dtypes.\n", | |
"        pad_dtype = core_A.dtype.base_dtype\n", | |
"        pad_right_of_A = tf.zeros((A_r, i_shapes[core_idx], j_shapes[core_idx], B_r_next), dtype=pad_dtype)\n", | |
"        pad_left_of_B = tf.zeros((B_r, i_shapes[core_idx], j_shapes[core_idx], A_r_next), dtype=pad_dtype)\n", | |
"        core_C_1 = tf.concat(3, [core_A, pad_right_of_A])\n", | |
"        core_C_2 = tf.concat(3, [pad_left_of_B, core_B])\n", | |
"        core_C = tf.concat(0, [core_C_1, core_C_2])\n", | |
"    C_cores.append(core_C)\n", | |
"\n", | |
"C_shape = A.get_raw_shape()\n", | |
"# Inner ranks add up; the two boundary ranks stay equal to 1.\n", | |
"C_ranks = [rank_A.value + rank_B.value for rank_A, rank_B in zip(A_ranks, B_ranks)]\n", | |
"C_ranks[0] = 1\n", | |
"C_ranks[-1] = 1\n", | |
"C = t3f.TensorTrain(C_cores, C_shape, C_ranks)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"# Op that initializes all global variables of the graph, i.e. the variables\n", | |
"# created above through t3f.get_variable for A and B.\n", | |
"init_op = tf.global_variables_initializer()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"# Open a session and run the initializer so that the variables behind A and B\n", | |
"# hold values before anything is evaluated.\n", | |
"sess = tf.InteractiveSession()\n", | |
"sess.run(init_op)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"1.847928e-06" | |
] | |
}, | |
"execution_count": 7, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Sanity check: materialize A, B and C as dense arrays in a single session run\n", | |
"# and measure how far the TT-sum C is from the dense sum A + B. The recorded\n", | |
"# output is ~1.8e-06, i.e. the two agree up to round-off.\n", | |
"dense_A, dense_B, dense_C = sess.run([t3f.ops.full(A), t3f.ops.full(B), t3f.ops.full(C)])\n", | |
"np.linalg.norm(dense_A + dense_B - dense_C)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"# Release the resources held by the session now that the check is done.\n", | |
"sess.close()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.4.3" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 0 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment