Skip to content

Instantly share code, notes, and snippets.

@fehiepsi
Created November 29, 2018 18:26
Show Gist options
  • Save fehiepsi/b15ac2978f1045d6d96b1d35b640d742 to your computer and use it in GitHub Desktop.
Save fehiepsi/b15ac2978f1045d6d96b1d35b640d742 to your computer and use it in GitHub Desktop.
Performance of MultivariateNormal (MVN) log_prob after expand()
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.distributions as dist\n",
"\n",
"class MultivariateNormalFixed(dist.MultivariateNormal):\n",
" def expand(self, batch_shape, _instance=None):\n",
" new = self._get_checked_instance(MultivariateNormalFixed, _instance)\n",
" batch_shape = torch.Size(batch_shape)\n",
" loc_shape = batch_shape + self.event_shape\n",
" cov_shape = batch_shape + self.event_shape + self.event_shape\n",
" new.loc = self.loc.expand(loc_shape)\n",
" # this is where to fix\n",
" new._unbroadcasted_scale_tril = self._unbroadcasted_scale_tril\n",
" if 'covariance_matrix' in self.__dict__:\n",
" new.covariance_matrix = self.covariance_matrix.expand(cov_shape)\n",
" if 'scale_tril' in self.__dict__:\n",
" new.scale_tril = self.scale_tril.expand(cov_shape)\n",
" if 'precision_matrix' in self.__dict__:\n",
" new.precision_matrix = self.precision_matrix.expand(cov_shape)\n",
" super(dist.MultivariateNormal, new).__init__(batch_shape,\n",
" self.event_shape,\n",
" validate_args=False)\n",
" new._validate_args = self._validate_args\n",
" return new"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"loc = torch.randn(3)\n",
"A = torch.randn(3, 3)\n",
"cov = A.matmul(A.t())\n",
"d1 = dist.MultivariateNormal(loc, cov).expand([1000, 3])\n",
"d2 = MultivariateNormalFixed(loc, cov).expand([1000, 3])\n",
"x = torch.randn(3)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"70.2 ms ± 3.09 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
]
}
],
"source": [
"%timeit d1.log_prob(x)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"488 µs ± 28.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n"
]
}
],
"source": [
"%timeit d2.log_prob(x)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.6"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment