@hamelsmu
Created February 20, 2022 07:15
How to make batch predictions in fastai
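from fastai.learner import Learner
from fastcore.basics import patch, tuplify

@patch
def predict_batch(self:Learner, item, rm_type_tfms=None, with_input=False):
    "Predict on a list of items; return (probabilities, decoded labels, decoded class indices)."
    # Build a test DataLoader from the raw items, reusing the Learner's transforms
    dl = self.dls.test_dl(item, rm_type_tfms=rm_type_tfms, num_workers=0)
    inp,preds,_,dec_preds = self.get_preds(dl=dl, with_input=True, with_decoded=True)
    i = getattr(self.dls, 'n_inp', -1)
    inp = (inp,) if i==1 else tuplify(inp)
    # decode_batch decodes only the first 9 samples by default, so pass max_n
    dec = self.dls.decode_batch(inp + tuplify(dec_preds), max_n=len(dl.dataset))
    dec_inp, nm = zip(*dec)
    res = preds,nm,dec_preds
    if with_input: res = (dec_inp,) + res
    return res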
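The final reshaping step assumes decode_batch returns one decoded tuple per sample, e.g. (text, label) for a single-input, single-target classifier; zip(*...) then transposes that list into one tuple of inputs and one tuple of labels. A plain-Python sketch of just that transpose, with made-up samples standing in for real decoded output:

```python
# Hypothetical decoded samples, shaped like the per-sample (input, label)
# tuples decode_batch yields for a single-input, single-target text classifier.
decoded = [
    ("hello world", "pos"),
    ("terrible film", "neg"),
    ("loved it", "pos"),
]

# zip(*...) transposes the list of pairs into all inputs and all labels.
dec_inp, nm = zip(*decoded)

print(dec_inp)  # ('hello world', 'terrible film', 'loved it')
print(nm)       # ('pos', 'neg', 'pos')
```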
@hamelsmu (Author)

This monkey-patches a new method, predict_batch, onto fastai's Learner class using fastcore's @patch decorator.
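Roughly speaking, @patch reads the type annotation on the first parameter (self:Learner) and attaches the function to that class. The sketch below is a hypothetical, simplified stand-in written for illustration (not fastcore's actual implementation), applied to a toy Learner class:

```python
import inspect

def patch(f):
    """Hypothetical, simplified stand-in for fastcore's @patch: attach f to
    the class named in the type annotation of its first parameter."""
    first_param = next(iter(inspect.signature(f).parameters.values()))
    cls = first_param.annotation
    setattr(cls, f.__name__, f)
    return f

class Learner:  # toy class standing in for fastai's Learner
    def __init__(self, items):
        self.items = items

@patch
def predict_batch(self: Learner, n: int):
    # Once attached, this runs as an ordinary bound method on Learner instances
    return self.items[:n]

learn = Learner(["a", "b", "c"])
print(learn.predict_batch(2))  # ['a', 'b']
```

Because the method is attached when the decorated function is defined, importing the module that defines predict_batch is enough to make learn.predict_batch available.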

>>> from fastai.text.all import *
>>> from predict_batch import predict_batch # this gist saved as predict_batch.py; alternatively, define the function directly in your script
>>> dls = TextDataLoaders.from_folder(untar_data(URLs.IMDB), valid='test')
>>> learn = text_classifier_learner(dls, AWD_LSTM, drop_mult=0.5, metrics=accuracy)
>>> learn.fine_tune(4, 1e-2)
>>> learn.predict_batch(["hello world"]*4)
(TensorText([[0.0029, 0.9971],
         [0.0029, 0.9971],
         [0.0029, 0.9971],
         [0.0029, 0.9971]]),
 ('pos', 'pos', 'pos', 'pos'),
 TensorText([1, 1, 1, 1]))
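The decoded indices in the last tensor follow from the probabilities: for single-label classification, fastai's default CrossEntropyLossFlat decodes each row with an argmax. A plain-Python sketch of that step on the probability rows shown above (argmax here is a small helper written for illustration):

```python
# Probability rows copied from the predict_batch output above.
probs = [
    [0.0029, 0.9971],
    [0.0029, 0.9971],
    [0.0029, 0.9971],
    [0.0029, 0.9971],
]

def argmax(row):
    """Index of the largest value in a row (the first one wins on ties)."""
    return max(range(len(row)), key=row.__getitem__)

dec_preds = [argmax(row) for row in probs]
print(dec_preds)  # [1, 1, 1, 1]
```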

@hamelsmu (Author) commented Feb 20, 2022

Actually, this prediction method works for both:

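from fastcore.basics import tuplify, detuplify

# As written this is a plain function: call it as predict(learn, items), or attach
# it with @patch (annotating self:Learner), noting that under this name it would
# override the built-in Learner.predict.
def predict(self, item, rm_type_tfms=None, with_input=False):
    dl = self.dls.test_dl(item, rm_type_tfms=rm_type_tfms, num_workers=0)
    inp,preds,_,dec_preds = self.get_preds(dl=dl, with_input=True, with_decoded=True)
    i = getattr(self.dls, 'n_inp', -1)
    inp = (inp,) if i==1 else tuplify(inp)
    # max_n: decode every sample, not just the default first 9
    dec = self.dls.decode_batch(inp + tuplify(dec_preds), max_n=len(dl.dataset))
    # Split each decoded sample into its first i elements (inputs) and the rest (targets)
    dec_inp,dec_targ = (tuple(map(detuplify, d)) for d in zip(*dec.map(lambda x: (x[:i], x[i:]))))
    res = dec_targ,dec_preds,preds
    if with_input: res = (dec_inp,) + res
    return res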
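The line that splits each decoded sample can be illustrated in plain Python. This sketch assumes a hypothetical model with two inputs and one target per sample, and uses a simplified detuplify that only unwraps one-element tuples (fastcore's version handles more cases):

```python
def detuplify(x):
    """Simplified stand-in for fastcore's detuplify: unwrap 1-element tuples."""
    return x[0] if len(x) == 1 else x

n_inp = 2  # hypothetical: two model inputs per sample
decoded = [
    ("title 1", "body 1", "pos"),
    ("title 2", "body 2", "neg"),
]

# Split every sample into (inputs, targets) at position n_inp ...
pairs = [(x[:n_inp], x[n_inp:]) for x in decoded]
# ... transpose into all-inputs / all-targets, then unwrap singletons.
dec_inp, dec_targ = (tuple(map(detuplify, d)) for d in zip(*pairs))

print(dec_inp)   # (('title 1', 'body 1'), ('title 2', 'body 2'))
print(dec_targ)  # ('pos', 'neg')
```

With the getattr fallback of -1, the slices become x[:-1] and x[-1:], so the last element of each sample is treated as the target.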

@hamelsmu (Author)

Other notes (h/t Zach):

learn.dls.vocab or learn.dls.categorize.vocab is another way to get the class names.
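Given the vocab, mapping decoded class indices to class names is a plain index lookup. A minimal sketch with a hard-coded list standing in for learn.dls.vocab; the ['neg', 'pos'] ordering is an assumption consistent with the output above, where index 1 decodes to 'pos', and the indices are made up:

```python
# Hypothetical stand-in for learn.dls.vocab on the IMDB classifier.
vocab = ["neg", "pos"]
dec_preds = [1, 1, 0, 1]  # made-up decoded class indices

# Look up each index in the vocab to recover the class name.
class_names = tuple(vocab[i] for i in dec_preds)
print(class_names)  # ('pos', 'pos', 'neg', 'pos')
```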
