Skip to content

Instantly share code, notes, and snippets.

@dedcode
Last active December 6, 2023 08:30
Show Gist options
  • Save dedcode/ca81d0f344ac9bece144f93f7853fd3a to your computer and use it in GitHub Desktop.
@torch.no_grad()
def buildSubgraph(self, head_index: Tensor, rel_type: Tensor, tail_index: Tensor) -> "tuple[Tensor, Tensor]":
    """Build a per-triple neighborhood using the KGE model's scores.

    For every (head, relation, tail) triple in the batch, score the head and the
    tail against every node under every relation (in both subject and object
    position), keep the top-k candidates of each query, and trim/pad the union
    to exactly ``self.neighbors_size`` neighbor ids. The neighbor ids, the full
    relation vocabulary, and the head/tail ids are then mapped through
    ``self.mapping_dict`` into rows of ``self.combined_embedding``.

    Args:
        head_index: 1-D long tensor of head entity ids, one per triple.
        rel_type:   1-D long tensor of relation ids, one per triple.
        tail_index: 1-D long tensor of tail entity ids, one per triple.

    Returns:
        (final_embeddings_batch, neighbors_indices):
        final_embeddings_batch has shape
        (batch, 2 + 2 + num_relations + M + M, input_dim) laid out as
        [head/tail real, head/tail imaginary, relations,
         neighbors real, neighbors imaginary];
        neighbors_indices has shape (batch, M).
    """
    k = self.neighbors_topk      # top-k candidate nodes kept per (relation, query)
    M = self.neighbors_size      # fixed number of neighbors per triple
    device = head_index.device
    num_nodes = self.num_nodes

    # One candidate set per *unique* triple in the batch.
    # NOTE(review): duplicate (h, r, t) triples collapse into a single dict
    # entry, which would make neighbors_indices shorter than the batch and
    # break the reshapes below — verify callers never pass duplicates.
    unique_nodes = {(h.item(), r.item(), t.item()): set()
                    for h, r, t in zip(head_index, rel_type, tail_index)}

    def batch_predict(model, s, p, o, mode='head'):
        """Score each batch query against every candidate node.

        mode='head': (s_i, p_i) paired with every candidate tail in ``o``.
        mode='tail': every candidate head in ``s`` paired with (p_i, o_i).
        Returns sigmoid probabilities, flat over (batch * num_nodes).
        """
        if mode == 'head':
            s_expanded = s.view(-1, 1).expand(-1, num_nodes).reshape(-1)
            p_expanded = p.view(-1, 1).expand(-1, num_nodes).reshape(-1)
            o_expanded = o.repeat(len(s))
        elif mode == 'tail':
            s_expanded = s.repeat(len(o))
            p_expanded = p.view(-1, 1).expand(-1, num_nodes).reshape(-1)
            o_expanded = o.view(-1, 1).expand(-1, num_nodes).reshape(-1)
        else:
            # Original code silently left the tensors unbound here.
            raise ValueError(f"unknown mode: {mode!r}")
        scores = model(s_expanded, p_expanded, o_expanded)
        return torch.sigmoid(scores).squeeze()

    all_nodes = torch.arange(num_nodes, dtype=torch.long, device=device)

    def accumulate_topk(query_index, mode):
        """Score ``query_index`` against all nodes under every relation and add
        each query's top-k node ids to every triple's candidate set."""
        for rel_id in range(self.num_relations):
            rel_tensor = torch.full((len(query_index),), rel_id,
                                    dtype=torch.long, device=device)
            if mode == 'head':
                probabilities = batch_predict(self.kge_model, query_index,
                                              rel_tensor, all_nodes, mode='head')
            else:
                probabilities = batch_predict(self.kge_model, all_nodes,
                                              rel_tensor, query_index, mode='tail')
            probabilities = probabilities.view(len(query_index), num_nodes)
            top_indices = torch.topk(probabilities, k, dim=1).indices
            for i, (h, r, t) in enumerate(zip(head_index, rel_type, tail_index)):
                unique_nodes[(h.item(), r.item(), t.item())].update(
                    top_indices[i].tolist())

    # Gather candidates from all four query directions.
    accumulate_topk(head_index, 'head')   # head -> all tails
    accumulate_topk(head_index, 'tail')   # all heads -> head
    accumulate_topk(tail_index, 'tail')   # all heads -> tail
    accumulate_topk(tail_index, 'head')   # tail -> all tails

    # Trim/pad every candidate set to exactly M neighbor ids.
    final_neighbors_per_triple = {}
    for triple, neighbors in unique_nodes.items():
        neighbors.discard(triple[0])  # head is concatenated back explicitly below
        neighbors.discard(triple[2])  # tail likewise
        if len(neighbors) > M:
            # random.sample() on a set raises TypeError from Python 3.11;
            # sample from a sorted list so the draw is well-defined.
            neighbors = set(random.sample(sorted(neighbors), M))
        while len(neighbors) < M:
            # Pad with random node ids. NOTE(review): padding may re-introduce
            # the head or tail id; kept as-is to preserve original behavior
            # (excluding them could loop forever when num_nodes - 2 < M).
            neighbors.add(random.randint(0, num_nodes - 1))
        final_neighbors_per_triple[triple] = list(neighbors)

    # (batch, M) tensor of neighbor ids, rows in batch order.
    neighbors_indices = torch.tensor(list(final_neighbors_per_triple.values()),
                                     dtype=torch.long, device=device)

    # Prepend the full relation vocabulary to every row.
    fixed_relation_list = list(self.mapping_dict['relation'].values())
    fixed_relation_tensor = torch.tensor(fixed_relation_list,
                                         dtype=torch.long, device=device)
    num_rel = fixed_relation_tensor.numel()
    relation_block = fixed_relation_tensor.unsqueeze(0).repeat(
        neighbors_indices.size(0), 1)
    concatenated_tensor = torch.cat((relation_block, neighbors_indices), dim=1)

    # Prepend head/tail ids: column layout is [head, tail, relations..., neighbors...].
    head_index = head_index.view(-1, 1)
    tail_index = tail_index.view(-1, 1)
    concatenated_tensor = torch.cat((head_index, tail_index, concatenated_tensor), dim=1)

    batch_size = concatenated_tensor.size(0)
    # Split back by column. The relation width was previously hardcoded to 11
    # (WordNet); derive it from the mapping so other KGs work too.
    headtail_indices = concatenated_tensor[:, :2]
    relation_indices = concatenated_tensor[:, 2:2 + num_rel]
    neighbors_indices = concatenated_tensor[:, 2 + num_rel:]

    def entity_rows(flat_idx, part):
        """Map entity ids to rows of combined_embedding via the per-part
        ('real' / 'imaginary') entity mapping."""
        mapping = self.mapping_dict['entity'][part]
        rows = torch.tensor([mapping[idx.item()] for idx in flat_idx],
                            dtype=torch.long, device=device)
        return self.combined_embedding[rows]

    # Flatten for mapping, then reshape back to (batch, cols, input_dim).
    flat_ht = headtail_indices.reshape(-1)
    flat_nb = neighbors_indices.reshape(-1)
    headtail_real_embeddings = entity_rows(flat_ht, 'real').view(
        batch_size, headtail_indices.size(1), self.input_dim)
    headtail_imaginary_embeddings = entity_rows(flat_ht, 'imaginary').view(
        batch_size, headtail_indices.size(1), self.input_dim)
    neighbors_real_embeddings = entity_rows(flat_nb, 'real').view(
        batch_size, neighbors_indices.size(1), self.input_dim)
    neighbors_imaginary_embeddings = entity_rows(flat_nb, 'imaginary').view(
        batch_size, neighbors_indices.size(1), self.input_dim)
    # Relation ids are already rows of combined_embedding (no extra mapping).
    relation_embeddings = self.combined_embedding[relation_indices.reshape(-1)].view(
        batch_size, relation_indices.size(1), self.input_dim)

    # Final per-triple embedding sequence for the batch.
    final_embeddings_batch = torch.cat((headtail_real_embeddings,
                                        headtail_imaginary_embeddings,
                                        relation_embeddings,
                                        neighbors_real_embeddings,
                                        neighbors_imaginary_embeddings), dim=1)
    return final_embeddings_batch, neighbors_indices
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment