Skip to content

Instantly share code, notes, and snippets.

@hrzn
Created February 8, 2020 11:14
Show Gist options
  • Save hrzn/38914e0f2d557b7959285258b2849bcf to your computer and use it in GitHub Desktop.
Save hrzn/38914e0f2d557b7959285258b2849bcf to your computer and use it in GitHub Desktop.
class QNet(nn.Module):
    """The neural net that parameterizes the state-action value function Q(s, a).

    The input is the state (containing the graph and the visited nodes), and
    the output is a vector of size num_nodes containing Q(s, a) for each of
    the num_nodes possible actions a (i.e. which node to visit next).
    """

    def __init__(self, emb_dim, T=4):
        """
        emb_dim: embedding dimension p
        T: number of message-passing iterations for the graph embedding
        """
        super(QNet, self).__init__()
        self.emb_dim = emb_dim
        self.T = T
        # We use 5 dimensions for representing the nodes' states:
        # * A binary variable indicating whether the node has been visited
        # * A binary variable indicating whether the node is the first of the visited sequence
        # * A binary variable indicating whether the node is the last of the visited sequence
        # * The (x, y) coordinates of the node.
        self.node_dim = 5
        # We can have an extra layer after theta1 (for the sake of example, to make the network deeper)
        nr_extra_layers_1 = 1
        # Build the learnable affine maps:
        self.theta1 = nn.Linear(self.node_dim, self.emb_dim, True)
        self.theta2 = nn.Linear(self.emb_dim, self.emb_dim, True)
        self.theta3 = nn.Linear(self.emb_dim, self.emb_dim, True)
        self.theta4 = nn.Linear(1, self.emb_dim, True)
        self.theta5 = nn.Linear(2 * self.emb_dim, 1, True)
        self.theta6 = nn.Linear(self.emb_dim, self.emb_dim, True)
        self.theta7 = nn.Linear(self.emb_dim, self.emb_dim, True)
        # BUG FIX: use nn.ModuleList instead of a plain Python list, so the
        # extra layers' parameters are registered with the module. With a plain
        # list they are invisible to .parameters() (hence never trained by the
        # optimizer) and are not moved by .to(device).
        self.theta1_extras = nn.ModuleList(
            [nn.Linear(self.emb_dim, self.emb_dim, True) for _ in range(nr_extra_layers_1)]
        )

    def forward(self, xv, Ws):
        """
        xv: the node features, shape (batch_size, num_nodes, node_dim)
        Ws: the (weighted) graph adjacency matrices, shape (batch_size, num_nodes, num_nodes)
        Returns: Q-values, shape (batch_size, num_nodes)
        """
        num_nodes = xv.shape[1]
        batch_size = xv.shape[0]
        # Compute on the same device as the input rather than relying on a
        # module-level global `device`, so the network is self-contained.
        device = xv.device
        # Pre-compute 1-0 connection matrix masks (batch_size, num_nodes, num_nodes)
        conn_matrices = torch.where(Ws > 0, torch.ones_like(Ws), torch.zeros_like(Ws)).to(device)
        # Graph embedding
        # Note: we compute s1 and s3 once up front, as they do not depend on mu
        mu = torch.zeros(batch_size, num_nodes, self.emb_dim, device=device)
        s1 = self.theta1(xv)  # (batch_size, num_nodes, emb_dim)
        for layer in self.theta1_extras:
            s1 = layer(F.relu(s1))  # we apply the extra layer(s)
        # Each edge weight is mapped to a p-dim vector, then summed per node
        s3_1 = F.relu(self.theta4(Ws.unsqueeze(3)))  # (batch_size, num_nodes, num_nodes, emb_dim)
        s3_2 = torch.sum(s3_1, dim=1)  # (batch_size, num_nodes, emb_dim) - the embedding for each node
        s3 = self.theta3(s3_2)  # (batch_size, num_nodes, emb_dim)
        for _ in range(self.T):
            s2 = self.theta2(conn_matrices.matmul(mu))
            mu = F.relu(s1 + s2 + s3)
        # Prediction:
        # we repeat the global state (summed over nodes) for each node,
        # in order to concatenate it to the local states
        global_state = self.theta6(torch.sum(mu, dim=1, keepdim=True).repeat(1, num_nodes, 1))
        local_action = self.theta7(mu)  # (batch_size, num_nodes, emb_dim)
        out = F.relu(torch.cat([global_state, local_action], dim=2))
        return self.theta5(out).squeeze(dim=2)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment