Last active
August 7, 2020 18:03
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
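"# The `in` operator doesn't work in TensorFlow\n", | |
"\n", | |
"This note is about the `in` operator, which returns whether an element is contained in a list.\n", | |
"\n", | |
"I had a validation dataset loaded with tf.data and wanted to work with only the examples that a model had predicted incorrectly." | |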
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import tensorflow as tf\n", | |
"valid_ds = tf.data.Dataset.range(10)\n", | |
"mistake_idx = [2, 3, 5]" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
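"## Trying the `in` operator directly\n", | |
"\n", | |
"The indices are known, so the intuitive approach is to `enumerate` `valid_ds` and keep each element whose index appears in the list of misclassified indices." | |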
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def _in(idx, data):\n", | |
" return idx in mistake_idx\n", | |
"\n", | |
"valid_ds.enumerate().filter(_in)\n", | |
"# valid_ds.enumerate().filter(tf.function(lambda idx, data: idx in mistake_idx))" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
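"Running this raises the error below.\n", | |
"Decorating the function with `@tf.function`, as the error message suggests, does not help either.\n", | |
"Python's `in` compares `idx` with each list element using `==` and then converts the result to a `bool`, which is not allowed for a symbolic `tf.Tensor`; AutoGraph apparently does not support the `in` operator.\n", | |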
"\n", | |
"\n", | |
"```\n", | |
"OperatorNotAllowedInGraphError: in converted code:\n", | |
"\n", | |
" <ipython-input-38-0a35c9047825>:4 None *\n", | |
" valid_ds.enumerate().filter(lambda idx, data: idx in mistake_idx)\n", | |
" /Users/uniikura/miniconda3/envs/py38/lib/python3.7/site-packages/tensorflow_core/python/framework/ops.py:765 __bool__\n", | |
" self._disallow_bool_casting()\n", | |
" /Users/uniikura/miniconda3/envs/py38/lib/python3.7/site-packages/tensorflow_core/python/framework/ops.py:528 _disallow_bool_casting\n", | |
" \"using a `tf.Tensor` as a Python `bool`\")\n", | |
" /Users/uniikura/miniconda3/envs/py38/lib/python3.7/site-packages/tensorflow_core/python/framework/ops.py:513 _disallow_when_autograph_disabled\n", | |
" \" Try decorating it directly with @tf.function.\".format(task))\n", | |
"\n", | |
" OperatorNotAllowedInGraphError: using a `tf.Tensor` as a Python `bool` is not allowed: AutoGraph is disabled in this function. Try decorating it directly with @tf.function.\n", | |
"```" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Checking membership with a for loop" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def _in(idx, data):\n", | |
" result = False\n", | |
" for miss in mistake_idx:\n", | |
" if miss == idx:\n", | |
" result = True\n", | |
" return result" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"(<tf.Tensor: id=862, shape=(), dtype=int64, numpy=2>, <tf.Tensor: id=863, shape=(), dtype=int64, numpy=2>)\n", | |
"(<tf.Tensor: id=864, shape=(), dtype=int64, numpy=3>, <tf.Tensor: id=865, shape=(), dtype=int64, numpy=3>)\n", | |
"(<tf.Tensor: id=866, shape=(), dtype=int64, numpy=5>, <tf.Tensor: id=867, shape=(), dtype=int64, numpy=5>)\n" | |
] | |
} | |
], | |
"source": [ | |
"for v in valid_ds.enumerate().filter(_in):\n", | |
" print(v)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
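"## Doing it with tensor operations\n", | |
"\n", | |
"Yesterday the for-loop approach did not work for me, but this time it somehow did.\n", | |
"I must have made a mistake somewhere.\n", | |
"\n", | |
"Another approach is to compare the index list with `idx` using the `==` operator, which gives a boolean tensor, and then `reduce_sum` it to see whether anything matched." | |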
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"(<tf.Tensor: id=930, shape=(), dtype=int64, numpy=2>, <tf.Tensor: id=931, shape=(), dtype=int64, numpy=2>)\n", | |
"(<tf.Tensor: id=932, shape=(), dtype=int64, numpy=3>, <tf.Tensor: id=933, shape=(), dtype=int64, numpy=3>)\n", | |
"(<tf.Tensor: id=934, shape=(), dtype=int64, numpy=5>, <tf.Tensor: id=935, shape=(), dtype=int64, numpy=5>)\n" | |
] | |
} | |
], | |
"source": [ | |
"def _in(idx, data):\n", | |
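"    # Compare idx with every entry of mistake_idx; a nonzero match count means idx is in the list\n", | |
"    return tf.cast(tf.math.reduce_sum(tf.cast(mistake_idx == idx, tf.int8)), tf.bool)\n", | |
"\n", | |
"for i in valid_ds.enumerate().filter(_in):\n", | |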
" print(i)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Post-processing\n", | |
"The index is no longer needed, so drop it." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"tf.Tensor(2, shape=(), dtype=int64)\n", | |
"tf.Tensor(3, shape=(), dtype=int64)\n", | |
"tf.Tensor(5, shape=(), dtype=int64)\n" | |
] | |
} | |
], | |
"source": [ | |
"def _in(idx, data):\n", | |
"    return tf.cast(tf.math.reduce_sum(tf.cast(mistake_idx == idx, tf.int8)), tf.bool)\n", | |
"\n", | |
"def restore(idx, data):\n", | |
"    # Discard the index added by enumerate() and keep only the element\n", | |
"    return data\n", | |
"\n", | |
"for i in valid_ds.enumerate().filter(_in).map(restore):\n", | |
" print(i)" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.7.3" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
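
As a possible alternative to the `reduce_sum` check used in the notebook, the membership test could also be sketched with `tf.reduce_any`. This is only a sketch under stated assumptions, not the author's method: the helper name `is_mistake` and the conversion of `mistake_idx` into an int64 `tf.constant` are introduced here for illustration.

```python
import tensorflow as tf

valid_ds = tf.data.Dataset.range(10)

# Assumption: the misclassified indices are stored as an int64 tensor so that
# their dtype matches the index produced by Dataset.enumerate().
mistake_idx = tf.constant([2, 3, 5], dtype=tf.int64)


def is_mistake(idx, data):
    # Hypothetical helper: broadcast-compare the scalar idx against every
    # entry of mistake_idx, then OR the element-wise results into one bool.
    return tf.reduce_any(tf.equal(idx, mistake_idx))


# Keep only the flagged elements and drop the index while iterating.
picked = [int(data.numpy()) for _, data in valid_ds.enumerate().filter(is_mistake)]
print(picked)
```

Working with a boolean reduction avoids the round trip through `tf.int8` and back to `tf.bool`, so the match count cannot overflow when the index list is long.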