Last active
December 6, 2023 08:30
-
-
Save dedcode/ca81d0f344ac9bece144f93f7853fd3a to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
@torch.no_grad()
def buildSubgraph(self, head_index: Tensor, rel_type: Tensor, tail_index: Tensor) -> 'tuple[Tensor, Tensor]':
    """Build a fixed-size neighborhood subgraph embedding for each triple.

    For every (head, relation, tail) triple in the batch, the KGE model scores
    candidate neighbor nodes in four directions (head->all, all->head,
    all->tail, tail->all) for every relation id; the top-k candidates of each
    pass are pooled per triple, then the pool is sampled down or randomly
    padded to exactly ``self.neighbors_size`` nodes. Node and relation indices
    are mapped through ``self.mapping_dict`` into rows of
    ``self.combined_embedding`` and concatenated into one embedding tensor.

    Args:
        head_index: 1-D LongTensor of head entity ids, one per triple.
        rel_type:   1-D LongTensor of relation ids, one per triple.
        tail_index: 1-D LongTensor of tail entity ids, one per triple.

    Returns:
        (final_embeddings_batch, neighbors_indices) where
        final_embeddings_batch is [batch, 4 + num_relations + 2*M, input_dim]
        laid out as [head/tail real | head/tail imaginary | relations |
        neighbors real | neighbors imaginary], and neighbors_indices is the
        [batch, M] LongTensor of selected neighbor node ids.
    """
    k = self.neighbors_topk   # top-k candidates kept per relation/direction pass
    M = self.neighbors_size   # final fixed number of neighbors per triple
    device = head_index.device

    # One candidate pool per *unique* triple (dict keeps insertion order).
    # NOTE(review): if the batch contains duplicate triples this dict has fewer
    # rows than the batch and the torch.cat below would fail -- assumed the
    # caller never sends duplicates; verify upstream.
    unique_nodes = {(h.item(), r.item(), t.item()): set()
                    for h, r, t in zip(head_index, rel_type, tail_index)}
    # Per-batch-row keys, aligned with rows of the score matrices below.
    triple_keys = [(h.item(), r.item(), t.item())
                   for h, r, t in zip(head_index, rel_type, tail_index)]

    def batch_predict(model, s, p, o, mode='head'):
        """Score triples against every node; returns sigmoid probabilities.

        mode='head': s, p are per-row anchors and o enumerates all nodes.
        mode='tail': o, p are per-row anchors and s enumerates all nodes.
        (Gradients are already disabled by the method decorator.)
        """
        if mode == 'head':
            s_expanded = s.view(-1, 1).expand(-1, self.num_nodes).reshape(-1)
            p_expanded = p.view(-1, 1).expand(-1, self.num_nodes).reshape(-1)
            o_expanded = o.repeat(len(s))
        else:  # mode == 'tail'
            s_expanded = s.repeat(len(o))
            p_expanded = p.view(-1, 1).expand(-1, self.num_nodes).reshape(-1)
            o_expanded = o.view(-1, 1).expand(-1, self.num_nodes).reshape(-1)
        scores = model(s_expanded, p_expanded, o_expanded)
        return torch.sigmoid(scores).squeeze()

    def collect(anchor, mode):
        """Pool each relation's top-k scored nodes into every triple's set.

        mode='head': anchor entities act as subjects (score all objects);
        mode='tail': anchor entities act as objects (score all subjects).
        """
        all_nodes = torch.arange(self.num_nodes, dtype=torch.long, device=device)
        for rel_id in range(self.num_relations):
            rel_tensor = torch.full((len(anchor),), rel_id, dtype=torch.long, device=device)
            if mode == 'head':
                probabilities = batch_predict(self.kge_model, anchor, rel_tensor, all_nodes, mode='head')
            else:
                probabilities = batch_predict(self.kge_model, all_nodes, rel_tensor, anchor, mode='tail')
            probabilities = probabilities.view(len(anchor), self.num_nodes)
            top_indices = torch.topk(probabilities, k, dim=1).indices
            for i, key in enumerate(triple_keys):
                unique_nodes[key].update(top_indices[i].tolist())

    collect(head_index, 'head')   # head -> all (predict tails)
    collect(head_index, 'tail')   # all -> head (predict heads)
    collect(tail_index, 'tail')   # all -> tail (predict heads)
    collect(tail_index, 'head')   # tail -> all (predict tails)

    # Trim (random sample) or pad (random nodes) each pool to exactly M.
    final_neighbors_per_triple = {}
    for triple, neighbors in unique_nodes.items():
        neighbors.discard(triple[0])  # head is prepended separately below
        neighbors.discard(triple[2])  # tail is prepended separately below
        if len(neighbors) > M:
            # random.sample() requires a sequence (sampling from a set was
            # deprecated in 3.9 and removed in Python 3.11); sort first so
            # the population order is deterministic given a fixed seed.
            neighbors = random.sample(sorted(neighbors), M)
        else:
            while len(neighbors) < M:
                # May re-add head/tail; kept as-is from the original logic.
                neighbors.add(random.randint(0, self.num_nodes - 1))
            neighbors = list(neighbors)
        final_neighbors_per_triple[triple] = neighbors

    # [num_triples, M] neighbor ids, one row per unique triple.
    neighbors_indices = torch.tensor(list(final_neighbors_per_triple.values()),
                                     dtype=torch.long, device=device)

    # Prepend the full fixed relation vocabulary to every row.
    fixed_relation_list = list(self.mapping_dict['relation'].values())
    num_fixed_relations = len(fixed_relation_list)
    fixed_relation_tensor = torch.tensor(fixed_relation_list, dtype=torch.long, device=device)
    fixed_relation_tensor_repeated = fixed_relation_tensor.unsqueeze(0).repeat(neighbors_indices.size(0), 1)
    concatenated_tensor = torch.cat((fixed_relation_tensor_repeated, neighbors_indices), dim=1)

    # Prepend head and tail: row layout is [head, tail, relations..., neighbors...].
    # (rel_type is not prepended -- relations were already added above.)
    head_index = head_index.view(-1, 1)
    tail_index = tail_index.view(-1, 1)
    concatenated_tensor = torch.cat((head_index, tail_index, concatenated_tensor), dim=1)

    # Split each row back into its three segments. The relation width was
    # previously hard-coded to 11 (WordNet); it now follows the mapping size.
    headtail_indices = concatenated_tensor[:, :2]
    relation_indices = concatenated_tensor[:, 2:2 + num_fixed_relations]
    neighbors_indices = concatenated_tensor[:, 2 + num_fixed_relations:]

    # Flatten, map ids to embedding rows, and look the embeddings up.
    flat_headtail_indices = headtail_indices.reshape(-1)
    flat_relation_indices = relation_indices.reshape(-1)
    flat_neighbors_indices = neighbors_indices.reshape(-1)
    real_map = self.mapping_dict['entity']['real']
    imag_map = self.mapping_dict['entity']['imaginary']
    flat_headtail_real_embeddings = self.combined_embedding[
        torch.tensor([real_map[idx.item()] for idx in flat_headtail_indices])]
    flat_headtail_imaginary_embeddings = self.combined_embedding[
        torch.tensor([imag_map[idx.item()] for idx in flat_headtail_indices])]
    flat_neighbors_real_embeddings = self.combined_embedding[
        torch.tensor([real_map[idx.item()] for idx in flat_neighbors_indices])]
    flat_neighbors_imaginary_embeddings = self.combined_embedding[
        torch.tensor([imag_map[idx.item()] for idx in flat_neighbors_indices])]
    relation_embeddings = self.combined_embedding[flat_relation_indices]

    # Reshape embeddings back to batch format: [batch, segment_width, input_dim].
    batch = len(head_index)
    headtail_real_embeddings = flat_headtail_real_embeddings.view(batch, headtail_indices.size(1), self.input_dim)
    headtail_imaginary_embeddings = flat_headtail_imaginary_embeddings.view(batch, headtail_indices.size(1), self.input_dim)
    neighbors_real_embeddings = flat_neighbors_real_embeddings.view(batch, neighbors_indices.size(1), self.input_dim)
    neighbors_imaginary_embeddings = flat_neighbors_imaginary_embeddings.view(batch, neighbors_indices.size(1), self.input_dim)
    relation_embeddings = relation_embeddings.view(batch, relation_indices.size(1), self.input_dim)

    # Concatenate the segments into the final per-triple embedding tensor.
    final_embeddings_batch = torch.cat((headtail_real_embeddings,
                                        headtail_imaginary_embeddings,
                                        relation_embeddings,
                                        neighbors_real_embeddings,
                                        neighbors_imaginary_embeddings), dim=1)
    return final_embeddings_batch, neighbors_indices
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment