Skip to content

Instantly share code, notes, and snippets.

@vacmar01
Last active November 20, 2022 12:41
Show Gist options
  • Save vacmar01/eb32efe05569fed9a74b39ac629b7a2e to your computer and use it in GitHub Desktop.
Save vacmar01/eb32efe05569fed9a74b39ac629b7a2e to your computer and use it in GitHub Desktop.
A quick-and-dirty re-implementation of a small subset of the fastai Interpretation class for computer vision
class Interpreter:
    """Inspect a classifier's per-sample losses and predictions.

    A small re-implementation of a subset of fastai's ``Interpretation``
    class for computer vision. On construction the model is run once over
    every batch of ``dl`` and the unreduced loss of each sample is stored in
    ``self.losses`` (a 1-D CPU tensor, in the order ``dl`` yields samples).

    NOTE(review): the plotting helpers look samples up via
    ``self.dl.dataset[idx]`` using positions in ``self.losses``; this is only
    correct when ``dl`` iterates its dataset in order (no shuffling) --
    confirm for your dataloader. Batches are passed to the model as-is; no
    device transfer is performed here.
    """

    def __init__(self, model, dl, loss_func=None):
        """Run ``model`` over ``dl`` and record the per-sample losses.

        Args:
            model: callable mapping a batch of inputs to logits; its
                ``eval()`` method is called before inference.
            dl: iterable of ``(x, y)`` batches that also exposes ``.dataset``
                (used by the plotting methods).
            loss_func: optional loss callable accepting
                ``(logits, targets, reduction="none")``. When omitted,
                ``model.loss`` is used.

        Raises:
            AttributeError: if ``loss_func`` is not given and ``model`` has no
                ``loss`` attribute.
        """
        self.model = model
        self.dl = dl
        if loss_func is None:
            if not hasattr(model, "loss"):
                raise AttributeError(
                    "No loss function available: pass `loss_func` or give the "
                    "model a `loss` attribute."
                )
            loss_func = model.loss
        self.loss_func = loss_func

        self.model.eval()
        batch_losses = []
        with torch.no_grad():
            for x, y in tqdm.tqdm(self.dl):
                logits = self.model(x)
                loss = self.loss_func(logits, y, reduction="none")
                # Keep losses on the CPU so batches computed on an accelerator
                # can be concatenated and indexed uniformly afterwards.
                batch_losses.append(loss.cpu())
        # A single concatenation avoids re-copying the growing tensor per batch.
        self.losses = torch.cat(batch_losses) if batch_losses else torch.empty(0)

    def top_losses(self, n=5):
        """Return the ``n`` largest per-sample losses and their indices.

        ``n`` is clamped to the number of recorded losses, so asking for more
        samples than exist does not raise.

        Returns:
            The ``(values, indices)`` result of ``torch.topk``.
        """
        return torch.topk(self.losses, min(n, self.losses.numel()))

    def plot_top_losses(self, n=5):
        """Plot the ``n`` samples with the highest loss as 2-D images."""
        _, idxs = self.top_losses(n)
        self._plot_indices(len(idxs), idxs)

    def plot_top_losses_3d(self, n=5):
        """Plot slices of the ``n`` highest-loss samples, one volume per row."""
        _, idxs = self.top_losses(n)
        self._plot_indices_3d(len(idxs), idxs)

    def plot_results(self, n=5):
        """Plot ``n`` randomly chosen samples (clamped to the dataset size)."""
        size = len(self.dl.dataset)
        n = min(n, size)
        idxs = torch.randperm(size)[:n]
        self._plot_indices(n, idxs)

    def _inspect_item(self, idx):
        """Return ``(image, loss, prediction, target)`` for dataset item ``idx``.

        The image is the item's input with singleton dimensions squeezed out;
        loss and prediction are Python numbers; a single-element tensor target
        is converted to a Python number so plot titles read cleanly.
        """
        item = self.dl.dataset[idx]
        x, target = item[0], item[1]
        with torch.no_grad():
            logits = self.model(x.unsqueeze(0))
        # softmax is monotonic, so the argmax over raw logits is the same class.
        pred = torch.argmax(logits, dim=1).item()
        if isinstance(target, torch.Tensor) and target.numel() == 1:
            target = target.item()
        loss = self.losses[idx].item()
        return x.squeeze(), loss, pred, target

    @staticmethod
    def _slice_indices(depth, count):
        """Return ``count`` evenly spaced integer indices spanning ``range(depth)``."""
        return torch.linspace(0, depth - 1, count).round().int().tolist()

    def _plot_indices(self, n, idxs):
        """Plot ``n`` items side by side, titled ``loss // prediction // target``."""
        # squeeze=False keeps a 2-D Axes array even when n == 1; otherwise
        # plt.subplots returns a single Axes that cannot be iterated.
        fig, axs = plt.subplots(ncols=n, figsize=(12, 3), squeeze=False)
        for ax, idx in zip(axs[0], idxs):
            img, loss, pred, target = self._inspect_item(idx)
            ax.imshow(img)
            ax.set_title(f"{loss:.4f} // {pred} // {target}")

    def _plot_indices_3d(self, n, idxs, ncols=4):
        """Plot ``ncols`` evenly spaced slices along dim 0 for each of ``n`` volumes.

        Raises:
            ValueError: if an item is not 3-D after squeezing singleton dims.
        """
        fig, axs = plt.subplots(
            nrows=n, ncols=ncols, figsize=(12, 3 * n), squeeze=False
        )
        for row, idx in zip(axs, idxs):
            img, loss, pred, target = self._inspect_item(idx)
            if img.dim() != 3:
                raise ValueError(
                    "Input is not a 3d image volume, use `plot_top_losses()` instead."
                )
            title = f"{loss:.4f} // {pred} // {target}"
            # Spread the slices over the volume's actual depth instead of a
            # fixed 0..100 range, which overruns shallow volumes.
            for ax, z in zip(row, self._slice_indices(img.shape[0], ncols)):
                ax.imshow(img[z])
                ax.set_title(title)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment