Skip to content

Instantly share code, notes, and snippets.

Show Gist options
Save indiejoseph/fe6abdfe2a0625f3185a1a573445da80 to your computer and use it in GitHub Desktop.
Whisper perplexity
import torch
import torchaudio
def eval(audio, text, sample_rate=44100):
    """Return the perplexity of ``text`` as a transcript of ``audio`` under a Whisper model.

    The transcript is teacher-forced through the decoder: each token is
    predicted from the audio features plus all preceding tokens, and the
    perplexity is ``exp`` of the mean negative log-likelihood of the
    ground-truth tokens. Lower values mean the model finds the transcript
    more likely for this audio.

    Relies on module-level ``model``, ``processor`` and ``device`` defined
    elsewhere. These are presumably a Hugging Face Whisper model/processor
    pair and a torch device; confirm this against the setup code.

    NOTE: the name shadows the builtin ``eval``. It is kept only for
    compatibility with existing callers.

    Args:
        audio: Waveform samples, as an array-like or tensor. A mono 1-D
            signal is assumed; multi-channel input is not handled
            (TODO confirm with callers).
        text: Reference transcript to score.
        sample_rate: Sampling rate of ``audio`` in Hz. Defaults to 44100,
            the value previously hard-coded.

    Returns:
        The perplexity as a Python float.

    Raises:
        ValueError: If the transcript tokenizes to fewer than two tokens,
            leaving nothing for the decoder to predict.
    """
    # Whisper's feature extractor operates on 16 kHz audio.
    target_rate = 16000

    # as_tensor avoids a copy/warning when `audio` is already a tensor.
    # float32 is needed because resampling does not accept integer PCM dtypes.
    waveform = torch.as_tensor(audio, dtype=torch.float32)
    if sample_rate != target_rate:
        waveform = torchaudio.functional.resample(
            waveform, orig_freq=sample_rate, new_freq=target_rate
        )

    # Include the special prefix/suffix tokens the decoder was trained with.
    # NOTE: the score therefore also covers predicting those special tokens.
    token_ids = processor.tokenizer(text, add_special_tokens=True).input_ids
    if len(token_ids) < 2:
        raise ValueError(
            "transcript must tokenize to at least two tokens to compute perplexity"
        )
    tokenized_seq = torch.tensor([token_ids], device=device)
    # Teacher forcing: feed tokens [0, n-1) and score the predictions of tokens [1, n).
    decoder_input_ids = tokenized_seq[:, :-1]
    labels = tokenized_seq[:, 1:]

    # The feature extractor expects numpy/list input, not a torch tensor.
    processed_in = processor(
        waveform.cpu().numpy(), sampling_rate=target_rate, return_tensors="pt"
    ).to(device)
    with torch.no_grad():
        output = model(
            input_features=processed_in.input_features,
            decoder_input_ids=decoder_input_ids,
        )

    # Mean cross-entropy over the ground-truth tokens equals the mean negative
    # log-probability. Compute it in float32 for stability under half-precision models.
    logits = output.logits.float()
    nll = torch.nn.functional.cross_entropy(
        logits.reshape(-1, logits.size(-1)), labels.reshape(-1)
    )
    return torch.exp(nll).item()
# Demo invocation. `audio` and `text` are not defined in this snippet and must
# be provided beforehand (e.g. in an earlier notebook cell). The returned
# perplexity is discarded here unless the environment displays the value of
# the last expression (as a notebook does).
eval(audio, text)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment