# author: vlad niculae <[email protected]>
# license: mit

import torch
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as colors

from entmax import sparsemax, entmax15
from entmax.losses import sparsemax_loss, entmax15_loss


# unify API for output layers: softmax, sparsemax, entmax.

def _extend_2d(z):
    # lift scalar scores z to two-class scores [0, z], so the binary
    # problem can reuse the multi-class entmax mappings and losses.
    z = z.unsqueeze(dim=-1)
    z = torch.cat((torch.zeros_like(z), z), dim=-1)
    return z
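
# sanity check (illustrative addition, not part of the original gist):
#   _extend_2d(torch.tensor([1.5, -0.3]))
#   -> tensor([[ 0.0000,  1.5000],
#              [ 0.0000, -0.3000]])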


class SoftmaxOut(object):
    # standard binary logistic output: sigmoid probabilities,
    # binary cross-entropy loss on the logits.
    def __init__(self):
        self.loss_obj = torch.nn.BCEWithLogitsLoss()

    def loss(self, z, y_true):
        return self.loss_obj(z, y_true)

    def yhat(self, z):
        return torch.sigmoid(z)


class SparsemaxOut(object):
    # sparsemax output: predicted probabilities can be exactly 0 or 1.
    def loss(self, z, y_true):
        return sparsemax_loss(_extend_2d(z), y_true.long()).mean()

    def yhat(self, z):
        return sparsemax(_extend_2d(z))[..., 1]


class Entmax15Out(object):
    # 1.5-entmax output: interpolates between softmax and sparsemax.
    def loss(self, z, y_true):
        return entmax15_loss(_extend_2d(z), y_true.long()).mean()

    def yhat(self, z):
        return entmax15(_extend_2d(z))[..., 1]
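
# illustration (my addition): for two-class scores [0, z], sparsemax
# saturates once |z| >= 1, e.g. sparsemax(torch.tensor([0., 3.])) gives
# [0., 1.] exactly, whereas sigmoid(3.) ~= 0.953 only approaches 1
# asymptotically. this is why sparsemax can produce truly sparse maps below.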


# evaluate a function over a mesh for making contour plots

class MeshEval(object):
    def __init__(self, xlim, ylim, n_points):
        self.n_points = n_points
        x_min, x_max = xlim
        y_min, y_max = ylim
        grid_x = np.linspace(x_min, x_max, n_points)
        grid_y = np.linspace(y_min, y_max, n_points)
        mesh_x, mesh_y = np.meshgrid(grid_x, grid_y)
        grid_pts = np.column_stack([mesh_x.ravel(), mesh_y.ravel()])
        self.mesh_x = mesh_x
        self.mesh_y = mesh_y
        self.grid_pts = torch.from_numpy(grid_pts).float()

    def __call__(self, net):
        z = (net(self.grid_pts)
             .reshape(self.n_points, self.n_points)
             .detach())
        return z
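
# usage sketch (illustrative addition): evaluate any R^2 -> R module over
# the grid; the result has shape (n_points, n_points).
#   me = MeshEval(xlim=(-1, 1), ylim=(-1, 1), n_points=10)
#   z = me(torch.nn.Linear(2, 1))  # z.shape == (10, 10)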


OUT_LAYERS = {
    'softmax': SoftmaxOut(),
    'sparsemax': SparsemaxOut(),
    'entmax15': Entmax15Out()
}


def train(net, X, y, out, n_epochs, callback):
    optim = torch.optim.SGD(params=net.parameters(), lr=0.05,
                            momentum=.9, nesterov=True)
    for it in range(n_epochs):
        optim.zero_grad()
        z = net(X).squeeze()
        loss = out.loss(z, y)
        # the callback fires before the update, so frame `it` reflects the
        # model after exactly `it` gradient steps.
        callback(it, z.detach(), loss.item(), net)
        loss.backward()
        optim.step()


def main(out_name):
    hid = 100
    n_epochs = 300
    torch.manual_seed(42)

    # xor-like data: label is 1 iff both coordinates have the same sign
    def sample_X_y(batch_size):
        X = 2 * torch.rand(size=(batch_size, 2)) - 1
        y = (torch.prod(X, dim=1) > 0).float()
        return X, y

    X_train, y_train = sample_X_y(batch_size=300)

    net = torch.nn.Sequential(
        torch.nn.Linear(2, hid),
        torch.nn.ReLU(),
        torch.nn.Linear(hid, 1))

    its = []
    loss_vals = []
    acc_vals = []

    mesh_eval = MeshEval(xlim=(-1.5, 1.5),
                         ylim=(-1.5, 1.5),
                         n_points=50)

    out = OUT_LAYERS[out_name]

    def callback(it, y_pred, loss, net):
        # .item() turns the 0-dim tensor into a plain python float
        accuracy = torch.mean(((y_pred > 0) == (y_train > 0)).double()).item()

        its.append(it)
        loss_vals.append(loss)
        acc_vals.append(accuracy)

        print("Iter {:3d} Loss {:.3f} Acc {:.3f}".format(it, loss, accuracy))

        mesh_z = mesh_eval(net)

        fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(14, 4),
                                            tight_layout=True)

        # plot the decision function z
        max_val = mesh_z.abs().max().item()
        # DivergingNorm was renamed to TwoSlopeNorm in matplotlib 3.2
        divnorm = colors.TwoSlopeNorm(vmin=-max_val, vcenter=0, vmax=max_val)
        contour = ax1.contourf(mesh_eval.mesh_x,
                               mesh_eval.mesh_y,
                               mesh_z,
                               levels=np.linspace(-max_val, max_val, 30),
                               norm=divnorm,
                               cmap=plt.cm.PuOr)
        ax1.axvline(0, ls=":", color="k")
        ax1.axhline(0, ls=":", color="k")
        ax1.set_title('z')
        plt.colorbar(contour, ax=ax1)

        # plot the positive-class probability sigma(z)
        ax2.set_title('$\\sigma(z)$')
        divnorm = colors.TwoSlopeNorm(vmin=0, vcenter=0.5, vmax=1)
        contour = ax2.contourf(mesh_eval.mesh_x,
                               mesh_eval.mesh_y,
                               out.yhat(mesh_z),
                               levels=np.linspace(0, 1, 30),
                               norm=divnorm,
                               cmap=plt.cm.PuOr)
        ax2.axvline(0, ls=":", color="k")
        ax2.axhline(0, ls=":", color="k")
        plt.colorbar(contour, ax=ax2)

        # plot the training loss so far
        ax3.plot(its, loss_vals)
        ax3.set_xlim(-1, n_epochs + 1)
        ax3.set_ylim(0, loss_vals[0])
        ax3.set_xlabel("iteration")
        ax3.set_title("loss")

        plt.suptitle(f"{out_name} Iter {it:03d} Accuracy {accuracy * 100:0.02f}")
        plt.savefig(f"{out_name}_{it:03d}.png")
        plt.close(fig)

    train(net, X_train, y_train, out, n_epochs, callback)


if __name__ == '__main__':
    main('softmax')
    main('sparsemax')
    main('entmax15')
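
# note (my addition): each run writes one PNG per iteration, e.g.
# softmax_000.png ... softmax_299.png; the frames can be stitched into an
# animation, for instance with ImageMagick:
#   convert -delay 5 softmax_*.png softmax.gif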