This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Build the image encoder and the caption decoder used for inference.
# NOTE(review): encoder_model is not moved to `device` here while the decoder
# is; if Encoder is an nn.Module and device is 'cuda', confirm its inputs and
# outputs end up on the same device as decoder_model.
encoder_model = Encoder()
decoder_model = CaptionModel(vocab_size).to(device)
# map_location remaps tensors saved on another device (e.g. a GPU) onto the
# current one, so a GPU-trained checkpoint can be loaded on a CPU-only host.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
decoder_model.load_state_dict(torch.load(args.checkpoint, map_location=device))
# Inference only: switch layers with train/eval-specific behavior (such as
# dropout or batch-norm, if CaptionModel has any) to evaluation mode.
decoder_model.eval()
for image_name in os.listdir("evaluate/images"): | |
image = load_image(image_name, size=224) | |
# convert the image pixels to a numpy array | |
image = transforms.ToTensor()(image) | |
# reshape data for the model | |
image = image.unsqueeze(0) | |
# prepare the image for the VGG model |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def evaluate_model(model, descriptions, photos, tokenizer, max_length): | |
actual, predicted = list(), list() | |
# step over the whole set | |
for key, desc_list in descriptions.items(): | |
# generate description | |
yhat = generate_desc(model, tokenizer, photos[key], max_length) | |
# store actual and predicted | |
references = [d.split() for d in desc_list] | |
actual.append(references) | |
predicted.append(yhat.split()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def word_for_id(integer, tokenizer):
    """Map an integer token id back to its word.

    Scans ``tokenizer.word_index`` (a word -> index mapping) in iteration
    order and returns the first word whose index equals *integer*.

    Returns:
        The matching word, or ``None`` when no word has that index.
    """
    candidates = (
        word
        for word, index in tokenizer.word_index.items()
        if index == integer
    )
    return next(candidates, None)
# generate a description for an image | |
def generate_desc(model, tokenizer, photo, max_length): | |
# seed the generation process |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Use the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = CaptionModel(vocab_size).to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# Optionally resume from saved weights. load_state_dict copies values into the
# existing parameter tensors, so the optimizer built above still tracks them.
# NOTE(review): only model weights are restored here; Adam's optimizer state
# starts fresh on every resume.
if args.checkpoint is not None:
    print("Loading the checkpoint")
    # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
    model.load_state_dict(torch.load(args.checkpoint, map_location=device))
print("Number of epochs ", args.num_epochs)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# create sequences of images, input sequences and output words for an image | |
def create_sequences(tokenizer, max_length, descriptions, photos, vocab_size): | |
X1, X2, y = list(), list(), list() | |
# walk through each image identifier | |
for key, desc_list in descriptions.items(): | |
# walk through each description for the image | |
for desc in desc_list: | |
# encode the sequence | |
seq = tokenizer.texts_to_sequences([desc])[0] | |
# split one sequence into multiple X,y pairs |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
X1      X2 (text sequence)                            y (word)
photo   startseq,                                     little
photo   startseq, little,                             girl
photo   startseq, little, girl,                       running
photo   startseq, little, girl, running,              in
photo   startseq, little, girl, running, in,          field
photo   startseq, little, girl, running, in, field,   endseq
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
X1      X2 (text sequence)                            y (word)
photo   startseq,                                     little
photo   startseq, little,                             girl
photo   startseq, little, girl,                       running
photo   startseq, little, girl, running,              in
photo   startseq, little, girl, running, in,          field
photo   startseq, little, girl, running, in, field,   endseq
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# convert a dictionary of clean descriptions to a list of descriptions | |
def to_lines(descriptions): | |
all_desc = list() | |
for key in descriptions.keys(): | |
[all_desc.append(d) for d in descriptions[key]] | |
return all_desc | |
# fit a tokenizer given caption descriptions | |
def create_tokenizer(descriptions): | |
lines = to_lines(descriptions) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def load_photo_features(filename, dataset):
    """Load pickled photo features and keep only the ids in *dataset*.

    Args:
        filename: Path to a pickle file holding a mapping from image id to
            that image's features.
        dataset: Iterable of image ids to keep.

    Returns:
        A dict mapping each id in *dataset* to its stored features.

    Raises:
        KeyError: If an id in *dataset* is missing from the pickled mapping.
    """
    # The `with` block guarantees the file handle is closed; the original
    # left it open until garbage collection.
    # NOTE(review): unpickling can execute arbitrary code; only load feature
    # files you produced yourself.
    with open(filename, 'rb') as handle:
        all_features = load(handle)
    # Filter down to the requested ids.
    return {k: all_features[k] for k in dataset}
Newer / Older