Skip to content

Instantly share code, notes, and snippets.

@SubhadityaMukherjee
Last active August 13, 2022 07:15
Show Gist options
Save SubhadityaMukherjee/722938d063ddcb975dc1b123e72912d6 to your computer and use it in GitHub Desktop.
fastaiv2predicts_full
# Assuming you have set up your DataLoaders and Learner as `dls` and `learn`.
from pathlib import Path

learn.fine_tune(1, wd=0.5)  # one epoch of fine-tuning, weight decay 0.5
# Save the model. fastai's Learner.export writes to `learn.path / fname`, while
# load_learner("model.pkl") further down reads relative to the working
# directory. Passing an absolute path makes pathlib discard `learn.path`, so
# the file is written exactly where it is later loaded from (and not into a
# possibly read-only dataset directory).
learn.export(Path.cwd() / "model.pkl")
predictions_path = "../input/fruits/fruits-360_dataset/fruits-360/Test"
def predict_batch(self, item, rm_type_tfms=None, with_input=False): # this bit is slightly complicated. ignore it for now
dl = self.dls.test_dl(item, rm_type_tfms=rm_type_tfms, num_workers=15)
ret = self.get_preds(dl=dl, with_input=False, with_decoded=True)
return ret
import random  # NOTE(review): unused in the visible code; kept in case other parts of the file rely on it

# Work with the test directory as a Path.
predictions_path = Path(predictions_path)
# Make the batch-prediction helper available on every Learner, including the
# one loaded below.
Learner.predict_batch = predict_batch

# Load the model. load_learner unpickles the file, so only load files you trust.
learn = load_learner("model.pkl")
tst_files = get_image_files(predictions_path)  # same as before
tst_files = tst_files.shuffle()

# Run the predictions in one batched pass.
preds = learn.predict_batch(tst_files)
classes = learn.dls.vocab  # the original categories

# predict_batch asks get_preds for decoded predictions, which come last in the
# returned tuple: index 2 by default, index 3 if with_input is enabled.
# preds[-1] therefore selects them in both cases.
decoded = preds[-1]
preds_mapped = [classes[int(idx)] for idx in decoded]  # one category label per test file
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment