Skip to content

Instantly share code, notes, and snippets.

@akirayou
Last active March 23, 2018 00:31
Show Gist options
  • Save akirayou/72a32e1df9a8415a1743a1dd63d24a59 to your computer and use it in GitHub Desktop.
Save akirayou/72a32e1df9a8415a1743a1dd63d24a59 to your computer and use it in GitHub Desktop.
Class-random-choice iterator for Chainer. Here a "class" is not the training label: it is a grouping ID. For example, when several samples are taken from one target object, that target's ID is the class.
# -*- coding: utf-8 -*-
"""
Created on Tue Sep 27 02:01:52 2016
@author: a-noda
"""
import numpy as np
from chainer.dataset import iterator
class ClassRandomIterator(iterator.Iterator):
    """Dataset iterator that samples uniformly over classes, then over members.

    Each example in a batch is drawn in two steps: first a class is chosen
    uniformly at random (with replacement) from the currently active
    classes, then one dataset index belonging to that class is chosen
    uniformly at random. A "class" here is any grouping ID given in
    ``clsIds`` (e.g. the ID of the object several samples were taken from);
    it is not necessarily the training label.

    Epoch bookkeeping (``epoch``, ``is_new_epoch``, ``epoch_detail``) only
    counts how many examples have been drawn relative to ``len(dataset)``.
    Because sampling is random, an "epoch" does not guarantee that every
    example is visited.

    Args:
        dataset: Indexable dataset; ``dataset[i]`` is returned as a batch item.
        clsIds: Class ID of each example (``clsIds[i]`` is the class of
            ``dataset[i]``).
        batch_size: Number of examples per batch.
        repeat: If False, iteration stops after one epoch.
        classMaskRate: Fraction of classes to keep active (see ``remask``).
        classMaskIdx: Explicit indices (into the sorted unique class IDs) of
            the classes to keep active; takes precedence over classMaskRate.
    """

    def __init__(self, dataset, clsIds, batch_size, repeat=True,
                 classMaskRate=1, classMaskIdx=None):
        self.dataset = dataset
        self.batch_size = batch_size
        self._repeat = repeat
        # classNamesOrg: sorted unique class IDs.
        # classInvIdx[i]: position of clsIds[i] within classNamesOrg.
        self.classNamesOrg, scid = np.unique(clsIds, return_inverse=True)
        self.classInvIdx = scid
        # classListOrg[k]: dataset indices whose class is classNamesOrg[k].
        # Iterating over len(classNamesOrg) (rather than np.max(scid) + 1)
        # avoids calling np.max on an empty array when clsIds is empty.
        self.classListOrg = [np.nonzero(scid == c)[0]
                             for c in range(len(self.classNamesOrg))]
        # remask() derives the active classList / classNames from the
        # untouched classListOrg / classNamesOrg.
        self.remask(classMaskRate, classMaskIdx)
        self.current_position = 0
        self.epoch = 0
        self.is_new_epoch = False

    def rewind(self):
        """Reset the epoch counters to the start of iteration."""
        self.current_position = 0
        self.epoch = 0
        self.is_new_epoch = False

    def clsIdToInvalidIdx(self, clsIds):
        """Return positions in ``clsIds`` whose class is not currently active.

        Args:
            clsIds: Sequence of class IDs, e.g. for another split of the data.

        Returns:
            Integer array of indices ``i`` (along the first axis) such that
            ``clsIds[i]`` is not in ``self.classNames``, i.e. not among the
            classes kept by the most recent ``remask``.
        """
        # np.isin also handles plain Python lists and an empty classNames,
        # both of which broke the former element-wise comparison loop.
        valid = np.isin(np.asarray(clsIds), self.classNames)
        return np.nonzero(~valid)[0]

    def remask(self, classMaskRate=1, classMaskIdx=None):
        """Restrict sampling to a subset of the classes.

        Always starts again from the full class set, so successive calls do
        not compound. Prints the number of kept classes when masking.

        Args:
            classMaskRate: Fraction of classes to keep. When classMaskIdx is
                None and this is not 1, ``int(n_classes * classMaskRate)``
                distinct classes are chosen at random.
            classMaskIdx: Indices into ``classNamesOrg`` of the classes to
                keep. Takes precedence over classMaskRate.

        Raises:
            ValueError: (from numpy) if classMaskRate > 1, because more
                distinct classes than exist cannot be drawn.
        """
        # list(...) instead of .copy() so this also runs on Python 2.
        self.classList = list(self.classListOrg)
        if classMaskRate != 1 or classMaskIdx is not None:
            if classMaskIdx is None:
                n_keep = int(len(self.classList) * classMaskRate)
                # replace=False draws *distinct* classes. numpy's default
                # (replace=True) could pick a class twice, keeping fewer
                # classes than requested and over-sampling the duplicates
                # in __next__.
                self.classMaskIdx = np.random.choice(
                    len(self.classList), n_keep, replace=False)
            else:
                self.classMaskIdx = classMaskIdx
            self.classList = [self.classList[i] for i in self.classMaskIdx]
            self.classNames = self.classNamesOrg[self.classMaskIdx]
            print("remask ", len(self.classList))
        else:
            self.classNames = self.classNamesOrg

    def __next__(self):
        """Return a list of ``batch_size`` examples, sampled class-first.

        Raises:
            StopIteration: when ``repeat`` is False and one epoch has passed.
        """
        if not self._repeat and self.epoch > 0:
            raise StopIteration
        N = len(self.dataset)
        # Choose classes uniformly (with replacement), then one member each.
        cindex = np.random.choice(len(self.classList), self.batch_size)
        batch = [self.dataset[np.random.choice(self.classList[c])]
                 for c in cindex]

        self.current_position += self.batch_size
        if self.current_position >= N:
            if self._repeat:
                self.current_position = 0
            else:
                self.current_position = N
            self.epoch += 1
            self.is_new_epoch = True
        else:
            self.is_new_epoch = False
        return batch

    next = __next__  # Python 2 iterator protocol

    @property
    def epoch_detail(self):
        """Fractional epoch count: ``epoch + current_position / len(dataset)``."""
        # float() keeps this a true division on Python 2 as well.
        return self.epoch + self.current_position / float(len(self.dataset))

    def serialize(self, serializer):
        """Save/restore the epoch counters via a Chainer serializer.

        The former ``_order`` branch was removed: this iterator never defines
        ``_order`` (it samples randomly), so that branch raised
        AttributeError on every call.
        """
        self.current_position = serializer('current_position',
                                           self.current_position)
        self.epoch = serializer('epoch', self.epoch)
        self.is_new_epoch = serializer('is_new_epoch', self.is_new_epoch)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment