@lambdaofgod
Last active February 13, 2020 11:17
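import torch
import ot
from sklearn import metrics

# Load pretrained RoBERTa-large from the fairseq hub.
roberta = torch.hub.load('pytorch/fairseq', 'roberta.large')
roberta.eval()  # disable dropout (or leave in train mode to finetune)

def get_roberta_features(text):
    # Last-layer features for each BPE token: array of shape (num_tokens, 1024).
    tokens = roberta.encode(text)
    return roberta.extract_features(tokens).detach().numpy()[0]

def get_optimal_transport_distance(v, w, sinkhorn_method='sinkhorn', entropy_regularization=0.05):
    # Cost matrix: squared Euclidean distance between every token of v and every token of w.
    dists = metrics.pairwise.euclidean_distances(v, w) ** 2
    # Empty weight lists tell POT to use uniform weights over the tokens of each text.
    # With large distances and a small reg, plain 'sinkhorn' can underflow to NaN;
    # method='sinkhorn_stabilized' is a more robust choice in that case.
    transport = ot.sinkhorn([], [], dists, reg=entropy_regularization, method=sinkhorn_method, numItermax=100)
    # Transport cost of the entropic plan (the entropy term itself is not included).
    return (dists * transport).sum()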
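To see what the distance function computes without installing POT or downloading RoBERTa, here is a minimal NumPy-only sketch under stated assumptions. It is not the gist's code: it swaps in a log-domain Sinkhorn iteration (same entropic plan at convergence as Sinkhorn-Knopp, but it avoids the `exp(-M / reg)` underflow that large distances and a small `reg` cause), assumes uniform weights as the empty `[]` arguments do above, and the helper names `pairwise_sq_euclidean`, `sinkhorn_log` and `ot_distance` are hypothetical. Random arrays stand in for the per-token features that `get_roberta_features` would return.

```python
import numpy as np

# Sketch only: hypothetical helpers mirroring get_optimal_transport_distance
# with uniform weights, using a log-domain Sinkhorn loop instead of POT.


def pairwise_sq_euclidean(v: np.ndarray, w: np.ndarray) -> np.ndarray:
    """Squared Euclidean distances between rows of v (n, d) and rows of w (m, d)."""
    diff = v[:, None, :] - w[None, :, :]
    return (diff ** 2).sum(axis=-1)


def _logsumexp(x: np.ndarray, axis: int) -> np.ndarray:
    """Numerically stable log(sum(exp(x))) along one axis."""
    x_max = x.max(axis=axis, keepdims=True)
    out = x_max + np.log(np.exp(x - x_max).sum(axis=axis, keepdims=True))
    return out.squeeze(axis)


def sinkhorn_log(M: np.ndarray, reg: float, num_iter: int = 100) -> np.ndarray:
    """Entropic transport plan between uniform weights (log-domain Sinkhorn)."""
    n, m = M.shape
    log_a = np.full(n, -np.log(n))  # uniform weights on the rows of M
    log_b = np.full(m, -np.log(m))  # uniform weights on the columns of M
    f, g = np.zeros(n), np.zeros(m)  # dual potentials
    for _ in range(num_iter):
        # Update f so rows sum to a, then g so columns sum to b.
        f = reg * (log_a - _logsumexp((g[None, :] - M) / reg, axis=1))
        g = reg * (log_b - _logsumexp((f[:, None] - M) / reg, axis=0))
    # Plan P_ij = exp((f_i + g_j - M_ij) / reg).
    return np.exp((f[:, None] + g[None, :] - M) / reg)


def ot_distance(v: np.ndarray, w: np.ndarray, reg: float = 0.05, num_iter: int = 100) -> float:
    """Transport cost <M, P> of the entropic plan, as in the gist's function."""
    M = pairwise_sq_euclidean(v, w)
    P = sinkhorn_log(M, reg, num_iter)
    return float((M * P).sum())


rng = np.random.default_rng(0)
# Random stand-ins for per-token features of two texts (5 and 7 tokens).
text_a = rng.normal(size=(5, 8))
text_b = rng.normal(size=(7, 8))
print(ot_distance(text_a, text_b))
print(ot_distance(text_a, text_a))  # close to 0: the plan is nearly diagonal
```

Comparing a text with itself gives a cost near zero because the plan concentrates on matching each token to itself. Like `numItermax=100` in the gist, the loop runs a fixed number of iterations with no convergence check, so very small `reg` values may need more iterations.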