Skip to content

Instantly share code, notes, and snippets.

@gaphex
Last active February 29, 2020 14:55
Show Gist options
  • Select an option

  • Save gaphex/61522b6e35b786685c6252daa75ed2e1 to your computer and use it in GitHub Desktop.

Select an option

Save gaphex/61522b6e35b786685c6252daa75ed2e1 to your computer and use it in GitHub Desktop.
class TripletGenerator:
    """Endlessly yield batches of (anchor, positive, negative) triplets.

    Only the keys of ``datadict`` are used directly by this class: they form
    the pool of anchor ids that each batch samples from, without replacement.
    Turning ids into model inputs is delegated to ``get_anchors``,
    ``get_positives``, ``get_hard_negatives`` and ``get_random_negatives``.
    Those methods are NOT defined here; presumably a subclass (or code outside
    this snippet) supplies them -- TODO confirm.

    Args:
        datadict: Mapping whose keys are the anchor ids.
        hard_frac: Fraction of each batch whose negatives come from
            ``get_hard_negatives``; the remainder come from
            ``get_random_negatives``.
        batch_size: Number of triplets per batch. It must not exceed
            ``len(datadict)``, otherwise ``np.random.choice`` raises
            ``ValueError`` when the first batch is drawn.
    """

    def __init__(self, datadict, hard_frac=0.2, batch_size=256):
        self.datadict = datadict
        self._anchor_idx = np.array(list(self.datadict.keys()))
        self._hard_frac = hard_frac
        # A generator object: its body does not run until the first next().
        self._generator = self.generate_batch(batch_size)

    def generate_batch(self, size):
        """Yield ``([anchors, positives, negatives], labels)`` forever.

        Each batch samples ``size`` distinct anchor ids. The first
        ``int(size * hard_frac)`` negatives are hard negatives and the rest
        are random negatives. ``labels`` is a ``(size, 1)`` array of ones.
        """
        while True:
            # Re-read every batch so callers may adjust _hard_frac over time.
            hards = int(size * self._hard_frac)
            anchor_ids = np.random.choice(self._anchor_idx, size, replace=False)
            anchors = self.get_anchors(anchor_ids)
            positives = self.get_positives(anchor_ids)
            negatives = self._concat_rows(
                self.get_hard_negatives(anchor_ids[:hards]),
                self.get_random_negatives(anchor_ids[hards:]),
            )
            labels = np.ones((size, 1))
            assert (
                len(anchors) == len(positives) == len(negatives)
                == len(labels) == size
            ), "triplet components have mismatched batch lengths"
            yield [anchors, positives, negatives], labels

    @staticmethod
    def _concat_rows(*parts):
        """Concatenate ``parts`` along axis 0 (one row per triplet).

        ``np.hstack`` joins arrays of rank >= 2 along axis 1, which breaks
        batches of N-D examples. Axis-0 concatenation matches ``hstack`` for
        1-D inputs and is correct for higher ranks. Zero-length parts (e.g. a
        helper returning ``[]`` when hard_frac is 0) are skipped so they cannot
        cause a rank mismatch or an unwanted dtype promotion.
        """
        arrays = [np.asarray(part) for part in parts]
        non_empty = [arr for arr in arrays if len(arr)]
        return np.concatenate(non_empty or arrays[:1], axis=0)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment