[Implicit Maximum Likelihood Estimation](https://arxiv.org/abs/1809.09087) in 100 lines
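The idea, from Li & Malik (2018): draw a batch of samples from the generator and, for each *training* point, pull the nearest generated sample towards it. Since every data point must end up with some sample nearby, the generator cannot collapse onto a few modes. The objective this script minimises (a slight variation on the paper's, which uses squared distances) is

$$ \frac{1}{n}\sum_{i=1}^{n} \min_{j} \lVert x_i - T_\theta(z_j) \rVert $$

where $T_\theta$ is the generator and $z_j$ are latent codes drawn fresh each step; `R` and `loss` below compute exactly this.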
import logging

import mxnet as mx
from mxnet import nd, autograd
from mxnet.gluon import nn
def d(a, b):
    '''Euclidean distance between two vectors.'''
    return (a - b).norm()
def R(a, b):
    '''For each point in a, the distance to its nearest neighbour in b.'''
    return [min([d(a0, b0) for b0 in b]) for a0 in a]
def loss(fake, x):
    '''IMLE loss: mean distance from each data point to its nearest fake.'''
    R_value = R(x, fake)
    return sum(R_value) / len(R_value)
def visualize(x):
    '''
    x: [n, dim_x]
    Scatter-plot a batch of 2-D points.
    '''
    assert x.shape[1] == 2
    import matplotlib.pyplot as plt
    x = x.asnumpy()
    plt.scatter(x[:, 0], x[:, 1])
    plt.show(block=False)
class Model(nn.Block):
    '''Generator: a stack of residual dense blocks mapping z into x-space.'''
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.blks = []
        self.input = nn.Sequential()
        with self.input.name_scope():
            self.input.add(nn.Dense(5), nn.BatchNorm(), nn.LeakyReLU(0.1))
        self.input.initialize('orthogonal')
        for i in range(10):
            blk = nn.Sequential()
            with blk.name_scope():
                blk.add(nn.Dense(5),
                        nn.BatchNorm(),
                        nn.LeakyReLU(0.1))
            # blocks kept in a plain Python list must be registered by hand
            # so that collect_params() sees their parameters
            self.register_child(blk)
            blk.initialize('orthogonal')
            self.blks.append(blk)
        self.output = nn.Dense(2)
        self.output.initialize('orthogonal')

    def forward(self, z):
        z = self.input(z)
        for blk in self.blks:
            z = blk(z) + z  # residual connection
        return self.output(z)
def get_x(n=30):
    '''
    return [3*n, 2]
    sample data of dimension 2: three noisy 1-D curves
    '''
    functions = [nd.sin, lambda x: nd.abs(1 - x * x), lambda x: -nd.cos(x)]

    def gen():
        for f in functions:
            x1 = nd.random.uniform(-1, 1, n)
            x2 = f(x1) + 0.05 * nd.random.uniform(-1, 1, n)
            yield nd.transpose(nd.stack(x1, x2))

    return nd.shuffle(nd.reshape(nd.stack(*list(gen())), [len(functions) * n, 2]))
if __name__ == '__main__':
    log = logging.getLogger()
    log.setLevel(logging.DEBUG)
    n = 30
    x_dim = 2
    z_dim = 1
    steps = 20000
    model = Model()
    # halve the learning rate at each of these step counts
    schedule = mx.lr_scheduler.MultiFactorScheduler(step=[2000, 7500, 10000], factor=0.5)
    optimizer = mx.optimizer.Adam(learning_rate=0.03, lr_scheduler=schedule)
    trainer = mx.gluon.Trainer(params=model.collect_params(), optimizer=optimizer)
    x = get_x(n)
    train_loss = 0.
    logging.debug("start training")
    for step in range(steps):
        with autograd.record():
            # draw fresh latent codes and push them through the generator
            z = nd.random.uniform(-1, 1, (n, z_dim))
            fake = model(z)
            train_loss = loss(fake=fake, x=x)
        train_loss.backward()
        trainer.step(n)
        print("step: {}, train_loss: {}".format(step, train_loss))
        if step % 100 == 0:
            visualize(fake)
            x = get_x(n)  # resample the training set periodically
            visualize(x)
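A note on `R`: the nested Python loops launch one tiny NDArray op per real/fake pair, which dominates the runtime here. A broadcast-based sketch (my assumption, not part of the gist; `R_vectorized` is a hypothetical helper with the same nearest-neighbour semantics):

```python
def R_vectorized(a, b):
    '''Distance from each row of a to its nearest row in b, without Python loops.'''
    diff = a.expand_dims(1) - b.expand_dims(0)  # [n, m, dim] pairwise differences
    dist = nd.sqrt((diff ** 2).sum(axis=2))     # [n, m] Euclidean distances
    return nd.min(dist, axis=1)                 # [n] distance to nearest fake

# loss(fake, x) then reduces to R_vectorized(x, fake).mean()
```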