Skip to content

Instantly share code, notes, and snippets.

@funwarioisii
Created November 13, 2019 09:40
Show Gist options
  • Save funwarioisii/06f5ffb419ac2cc66d8bbcbe3197322b to your computer and use it in GitHub Desktop.
Save funwarioisii/06f5ffb419ac2cc66d8bbcbe3197322b to your computer and use it in GitHub Desktop.

tf.data.DatasetでLeakage発生してた話

あるデータセットがあり,それをtrainデータとvalidationデータに分けたかった

master_ds = tf.data.Dataset.from_tensor_slices([i for i in range(10000)]).shuffle(10000)

train_ds = master_ds.take(6000)
valid_ds = master_ds.skip(6000).take(4000)

master_dsは今回考える10000件のデータ.そのまま拾ってくると綺麗にソートされているので先頭5000件と,飛ばして5000件とするとラベルが非常に偏る

そのため1度シャッフルする必要がある

次に先頭から6000件,4000件と取得する.

train_val = [v for v in train_ds]
valid_val = [v for v in valid_ds]

for v in valid_val:
    if v in train_val:
        print(v)

train_dsvalid_dsを配列に直し,train_valvalid_valとした(名前良くないね)

ここでvalid_valの要素がtrain_valに入っていればリークが発生していることになる

実際に実行すると,いくつか出力されることが過確認できる

対策1

tf.data.Datasetになる前にマスターの要素をshuffleしておく

import random
x = [i for i in range(10000)]
random.shuffle(x)

master_ds = tf.data.Dataset.from_tensor_slices(x)

これが可能であれば,最初からシャッフルした状態で使える

対策2

def is_test(i, x):
  return (i % 2) == 0
  
def is_train(i, x):
  return not is_test(i, x)

def restore(i, x):
  return x

train = master.enumrate().filter(is_train).map(restore)

test = master.enumrate().filter(is_test).map(restore)

こういう方法も考えられる

やや柔軟性が欠けているかもしれない

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment