Skip to content

Instantly share code, notes, and snippets.

@SubhadityaMukherjee
Created August 12, 2022 17:34
Show Gist options
  • Save SubhadityaMukherjee/6c994868d215f9c0961f887d7e1d2b32 to your computer and use it in GitHub Desktop.
Save SubhadityaMukherjee/6c994868d215f9c0961f887d7e1d2b32 to your computer and use it in GitHub Desktop.
runmo
# Relative location of the Fruits-360 "Test" image directory used for inference.
predictions_path = (
    "../input/fruits/"
    "fruits-360_dataset/fruits-360/Test"
)
def predict_batch(self, item, rm_type_tfms=None, with_input=False): # this bit is slightly complicated. ignore it for now
dl = self.dls.test_dl(item, rm_type_tfms=rm_type_tfms, num_workers=15)
ret = self.get_preds(dl=dl, with_input=False, with_decoded=True)
return ret
import random
# Wrap the test-directory string in a Path object before listing images.
predictions_path = Path(predictions_path)
# Bind predict_batch as a method on every Learner; this must happen before
# learn.predict_batch(...) is called below.
Learner.predict_batch = predict_batch
# NOTE(review): load_learner deserializes a pickled model file; only load
# "model.pkl" from a trusted source.
learn = load_learner("model.pkl")
# List the image files under the test directory.
tst_files = get_image_files(predictions_path)
# Shuffle the file order. No seed is set in this block, so the order
# presumably differs between runs unless seeded elsewhere.
tst_files = tst_files.shuffle()
# Run batched inference over all test files.
preds = learn.predict_batch(tst_files)
# Training vocabulary: maps a class index to its category label.
classes = learn.dls.vocab
# Turn each decoded class index into its category name. Assumes preds[2]
# holds the decoded indices (get_preds with with_decoded=True and
# with_input=False); presumably aligned with the order of tst_files.
preds_mapped = [classes[int(idx)] for idx in preds[2]]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment