Skip to content

Instantly share code, notes, and snippets.

@whilo
Created August 28, 2017 20:19
Show Gist options
  • Save whilo/2a6ca937a53b8a1403afde313c7f123a to your computer and use it in GitHub Desktop.
Save whilo/2a6ca937a53b8a1403afde313c7f123a to your computer and use it in GitHub Desktop.
SGHMC (stochastic gradient Hamiltonian Monte Carlo) port from MATLAB.
import torch
import torch.nn.functional as F
import torch.autograd as autograd
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os
import math
from torch.autograd import Variable
def sghmc(U, m, dt, nstep, x, C, V, grad_noise=2.0):
    """Run ``nstep`` SGHMC steps starting at ``x`` and return the end position.

    Stochastic gradient Hamiltonian Monte Carlo (SGHMC): a fresh momentum is
    drawn at the start, then each step applies a noisy gradient of ``U``, a
    friction term ``C`` and injected Gaussian noise.  The injected noise is
    reduced by the estimated gradient-noise contribution ``B = 0.5 * V * dt``.

    Args:
        U: callable mapping a tensor shaped like ``x`` to its potential
            energy.  It may return per-element energies; they are summed
            before differentiation, so any input shape works.
        m: scalar mass.
        dt: step size.
        nstep: number of integration steps to run.
        x: starting position (floating-point tensor of any shape).  It is
            not modified, and gradients are never accumulated into it.
        C: friction coefficient; must satisfy ``C >= 0.5 * V * dt``.
        V: estimated variance of the gradient noise.
        grad_noise: standard deviation of the Gaussian noise added to the
            exact autograd gradient to simulate a stochastic gradient
            (default 2.0, the value that used to be hard-coded).

    Returns:
        A tensor with the same shape as ``x`` and ``requires_grad=True``.

    Raises:
        ValueError: if ``C < 0.5 * V * dt``, because the injected noise
            variance ``2 * (C - B) * dt`` would then be negative.
    """
    B = 0.5 * V * dt
    if C < B:
        raise ValueError(
            "friction C={} must be >= B = 0.5 * V * dt = {}".format(C, B)
        )
    D = math.sqrt(2 * (C - B) * dt)

    # Work on a detached copy: the caller's tensor does not need
    # requires_grad, and its .grad is never accumulated into.
    x = x.detach()
    p = torch.randn_like(x) * math.sqrt(m)
    for _ in range(nstep):
        x.requires_grad_(True)
        # Summing lets U return per-element energies (non-scalar output).
        U(x).sum().backward()
        # Independent noise per coordinate, not one scalar broadcast to all.
        grad_u = x.grad + grad_noise * torch.randn_like(x)
        p = p - grad_u * dt - p * C * dt + D * torch.randn_like(p)
        # New leaf tensor for the next step; p carries no autograd history.
        x = x.detach() + p / m * dt
    return x.requires_grad_(True)
# Sampler configuration for the double-well experiment.
nsample = 80000      # number of SGHMC draws collected by the driver loop
xStep = 0.1          # x-grid spacing (presumably for the histogram bins)
m, C, V = 1, 3, 4    # mass, friction, estimated gradient-noise variance
dt, nstep = 0.1, 50  # step size, integration steps per draw

# Seed torch's global RNG so repeated runs produce the same chain.
torch.manual_seed(10)
# Target distribution: a double-well potential.  The original MATLAB setup was
#   U = @(x) (-2* x.^2 + x.^4);
#   gradU = @(x) ( -4* x + 4*x.^3) + randn(1) * 2;
#   gradUPerfect = @(x) ( - 4*x + 4*x.^3 );
#   fgname = 'figure/func4';
#   hmccmp;
# Here the gradient comes from autograd, and the randn(1) * 2 style noise is
# added inside sghmc.
def U(x):
    """Return the double-well energy ``x**4 - 2*x**2``, elementwise."""
    return x ** 4 - 2 * x ** 2
# Draw nsample states. Each sghmc call runs nstep steps from the previous
# state, so consecutive samples form one Markov chain that starts at x = 0.
samples = torch.zeros(nsample, 1)
x = Variable(torch.Tensor([[0]]), requires_grad=True)
for i in range(nsample):
    x = sghmc(U, m, dt, nstep, x, C, V)
    # x keeps the (1, 1) shape of its initial value; flatten it explicitly
    # to fill the (1,) row.
    samples[i] = x.data.reshape(-1)

# Bin edges spaced xStep apart over [-3, 3], so the xStep setting actually
# controls the bin width (the previous 60-point linspace gave ~0.1017).
xGrid = np.linspace(-3, 3, int(round(6 / xStep)) + 1)
foo = plt.hist(samples.numpy(), xGrid)
# Needed to display the figure when run outside an interactive session.
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment