Last active
March 23, 2018 00:31
-
-
Save akirayou/72a32e1df9a8415a1743a1dd63d24a59 to your computer and use it in GitHub Desktop.
Class-random-choice iterator for Chainer. Here "class" does not mean "label": for example, when multiple samples are taken from one target object, that target's ID is the class.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*-
"""
Created on Tue Sep 27 02:01:52 2016
@author: a-noda
"""
import numpy as np | |
from chainer.dataset import iterator | |
class ClassRandomIterator(iterator.Iterator):
    """Iterator that builds batches by picking a class first, then a sample.

    For every element of a batch, one entry of the active class list is
    chosen at random, and then one dataset index belonging to that class is
    chosen at random. Epoch bookkeeping is purely count based:
    ``current_position`` advances by ``batch_size`` on every call and an
    epoch is counted once it reaches ``len(dataset)``.
    """

    def __init__(self, dataset, clsIds, batch_size, repeat=True,
                 classMaskRate=1, classMaskIdx=None):
        """Group dataset indices by class id and apply the initial mask.

        Args:
            dataset: Indexable dataset; ``dataset[i]`` and ``len(dataset)``
                are used.
            clsIds: Class id of each dataset element, passed to
                ``np.unique``. Presumably aligned with ``dataset`` and
                one-dimensional -- confirm against callers.
            batch_size: Number of samples returned by each ``__next__``.
            repeat: If false, iteration stops after the first epoch.
            classMaskRate: Forwarded to :meth:`remask`.
            classMaskIdx: Forwarded to :meth:`remask`.
        """
        self.dataset = dataset
        self.batch_size = batch_size
        self._repeat = repeat

        # classNamesOrg: sorted unique class ids.
        # scid: for each element, the position of its id in classNamesOrg.
        self.classNamesOrg, scid = np.unique(clsIds, return_inverse=True)
        self.classInvIdx = scid
        # One array of dataset indices per unique class. Using the number of
        # unique names (instead of max(scid) + 1) also works for empty input.
        self.classList = [
            np.nonzero(scid == c)[0] for c in range(len(self.classNamesOrg))
        ]
        self.classListOrg = self.classList.copy()

        self.remask(classMaskRate, classMaskIdx)

        self.current_position = 0
        self.epoch = 0
        self.is_new_epoch = False

    def rewind(self):
        """Reset the position and epoch counters to their initial values."""
        self.current_position = 0
        self.epoch = 0
        self.is_new_epoch = False

    def clsIdToInvalidIdx(self, clsIds):
        """Return positions in ``clsIds`` whose id is not an active class.

        Args:
            clsIds: Array-like of class ids.

        Returns:
            Array of indices ``i`` such that ``clsIds[i]`` does not appear
            in ``self.classNames``.
        """
        return np.nonzero(~np.isin(clsIds, self.classNames))[0]

    def remask(self, classMaskRate=1, classMaskIdx=None):
        """Select which classes are active for sampling.

        Starts again from the full class list and, if a mask is requested,
        restricts ``classList`` / ``classNames`` to the masked classes.

        Args:
            classMaskRate: Fraction of classes to keep when ``classMaskIdx``
                is not given; ``int(n_classes * classMaskRate)`` entries are
                drawn at random.
            classMaskIdx: Explicit indices into the original class list. When
                given, it is used as-is and ``classMaskRate`` is ignored.
        """
        self.classList = self.classListOrg.copy()

        if classMaskRate == 1 and classMaskIdx is None:
            # No masking: every class is active. Clear any earlier mask so
            # the attribute always exists and never holds a stale value.
            self.classMaskIdx = None
            self.classNames = self.classNamesOrg
            return

        if classMaskIdx is None:
            n_classes = len(self.classList)
            # NOTE(review): np.random.choice samples with replacement by
            # default, so the same class may be selected more than once.
            # Kept for compatibility -- confirm whether that is intended.
            self.classMaskIdx = np.random.choice(
                n_classes, int(n_classes * classMaskRate))
        else:
            self.classMaskIdx = classMaskIdx

        self.classList = [self.classList[i] for i in self.classMaskIdx]
        self.classNames = self.classNamesOrg[self.classMaskIdx]
        print("remask ", len(self.classList))

    def __next__(self):
        """Return the next batch as a list of dataset elements.

        Raises:
            StopIteration: If ``repeat`` is false and one epoch has already
                been completed.
        """
        if not self._repeat and self.epoch > 0:
            raise StopIteration

        N = len(self.dataset)
        # Pick a class slot for each batch element, then one sample from it.
        class_slots = np.random.choice(len(self.classList), self.batch_size)
        indices = [np.random.choice(self.classList[c]) for c in class_slots]
        batch = [self.dataset[i] for i in indices]

        self.current_position += self.batch_size
        if self.current_position >= N:
            # Repeating iterators start counting again from zero; otherwise
            # the position is pinned to N.
            self.current_position = 0 if self._repeat else N
            self.epoch += 1
            self.is_new_epoch = True
        else:
            self.is_new_epoch = False
        return batch

    next = __next__

    @property
    def epoch_detail(self):
        """Completed epochs plus the fraction of the current one."""
        return self.epoch + self.current_position / len(self.dataset)

    def serialize(self, serializer):
        """Save or load the iteration counters through ``serializer``."""
        self.current_position = serializer('current_position',
                                           self.current_position)
        self.epoch = serializer('epoch', self.epoch)
        self.is_new_epoch = serializer('is_new_epoch', self.is_new_epoch)
        # The original also read self._order here, but this class never
        # defines that attribute (sampling is random, there is no order
        # array), so the access raised AttributeError on every call.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment