Created July 24, 2020 06:43
Save akash-ch2812/87d55210503361eb1dfc85ac134fc105 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np
from keras.preprocessing.sequence import pad_sequences
from keras.utils import to_categorical
# generator function to generate inputs for model
def create_trianing_data(captions, images, tokenizer, max_caption_length, vocab_len, photos_per_batch):
    """Yield training batches for an image-captioning model, indefinitely.

    Every caption is expanded into several "partial caption -> next word"
    samples: a caption encoded as ``[w0, w1, ..., wk]`` produces
    ``[w0] -> w1``, ``[w0, w1] -> w2``, ..., ``[w0, ..., w(k-1)] -> wk``.
    Words missing from ``tokenizer.word_index`` are dropped before encoding,
    so captions with fewer than two known words contribute no samples.

    Args:
        captions: mapping from image key to an iterable of caption strings
            (words separated by single spaces).
        images: mapping from image key to that image's feature value; every
            key in ``captions`` must be present here.
        tokenizer: object exposing ``word_index`` (word -> integer id),
            e.g. a fitted Keras ``Tokenizer``.
        max_caption_length: length each partial caption is padded/truncated
            to via ``pad_sequences``.
        vocab_len: number of classes of the one-hot next-word target.
        photos_per_batch: number of images whose samples are grouped into
            one yielded batch.

    Yields:
        ``([image_features, padded_sequences], one_hot_targets)`` as numpy
        arrays, one batch per ``photos_per_batch`` images. The loop over
        ``captions`` restarts forever; a batch may straddle the end of one
        pass and the start of the next because the image counter is not
        reset between passes.

    Raises:
        ValueError: on the first ``next()`` if ``captions`` is empty or
            ``photos_per_batch`` < 1 -- either would otherwise make the
            generator loop forever without yielding.
    """
    if not captions:
        raise ValueError("captions must not be empty")
    if photos_per_batch < 1:
        raise ValueError("photos_per_batch must be at least 1")

    # Hoisted dict lookup: membership in the dict is O(1), whereas the old
    # `word in list(word_index.keys())` rebuilt the full vocabulary per word.
    word_index = tokenizer.word_index
    X1, X2, y = [], [], []
    n = 0
    # loop through every image, forever (Keras-style infinite generator)
    while True:
        for key, cap in captions.items():
            n += 1
            # retrieve the photo feature
            image = images[key]
            for c in cap:
                # encode the caption, dropping out-of-vocabulary words
                sequence = [word_index[word] for word in c.split(' ') if word in word_index]
                # split one sequence into multiple (input prefix, next word) pairs
                for i in range(1, len(sequence)):
                    inp, out = sequence[:i], sequence[i]
                    # padding input
                    input_seq = pad_sequences([inp], maxlen=max_caption_length)[0]
                    # encode output word as a one-hot vector
                    output_seq = to_categorical([out], num_classes=vocab_len)[0]
                    # store
                    X1.append(image)
                    X2.append(input_seq)
                    y.append(output_seq)
            # yield the batch data once enough images have been processed
            if n == photos_per_batch:
                yield ([np.array(X1), np.array(X2)], np.array(y))
                X1, X2, y = [], [], []
                n = 0
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.