Skip to content

Instantly share code, notes, and snippets.

@dcolinmorgan
Created March 19, 2024 06:01
Show Gist options
  • Save dcolinmorgan/bd9a944a0a18385a9552279c6140f327 to your computer and use it in GitHub Desktop.
Save dcolinmorgan/bd9a944a0a18385a9552279c6140f327 to your computer and use it in GitHub Desktop.
Pseudocode for a message-passing algorithm utilizing spatial, ssRNA-seq, and ChIP data, via PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
class GraphConvolution(nn.Module):
    """Single graph-convolution layer computing ``adj @ (x @ weight) + bias``.

    The layer does not normalise ``adj`` itself; the caller is expected to pass
    whatever (e.g. symmetrically normalised) adjacency the model should use.

    Args:
        in_features: size of each input node feature vector.
        out_features: size of each output node feature vector.
        bias: if ``False`` the layer has no additive bias (``self.bias`` is None).
    """

    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # torch.empty / torch.FloatTensor(shape) return *uninitialised* memory,
        # so the values are filled explicitly in reset_parameters().
        self.weight = nn.Parameter(torch.empty(in_features, out_features))
        if bias:
            self.bias = nn.Parameter(torch.empty(out_features))
        else:
            # Registering None keeps `self.bias` a known attribute for forward().
            self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self):
        """(Re)initialise weight and bias uniformly in [-1/sqrt(out), 1/sqrt(out)]."""
        bound = 1.0 / self.out_features ** 0.5
        with torch.no_grad():
            self.weight.uniform_(-bound, bound)
            if self.bias is not None:
                self.bias.uniform_(-bound, bound)

    def forward(self, x, adj):
        """Propagate node features over the graph.

        Args:
            x: node feature matrix; its last dimension must equal ``in_features``.
            adj: adjacency matrix, dense or sparse COO; for sparse input it is
                assumed to be 2-D ``(num_nodes, num_nodes)``.

        Returns:
            ``adj @ (x @ weight)`` plus the bias when one is configured.
        """
        support = torch.matmul(x, self.weight)
        if adj.is_sparse:
            # torch.sparse.mm is the supported sparse x dense product.
            output = torch.sparse.mm(adj, support)
        else:
            output = torch.matmul(adj, support)
        if self.bias is not None:
            output = output + self.bias
        return output

    def extra_repr(self):
        return (
            f"in_features={self.in_features}, out_features={self.out_features}, "
            f"bias={self.bias is not None}"
        )
class MessagePassingNetwork(nn.Module):
    """Fuse ChIP, RNA and spatial node features propagated over one shared graph.

    Each modality has its own GraphConvolution; all three use the same ``adj``.
    The ReLU-activated per-modality embeddings are fused and passed through a
    final linear layer.

    Args:
        input_dim: input feature size used for every modality whose own
            ``*_dim`` override is not given.
        hidden_dim: per-modality embedding size.
        output_dim: size of the final output vector per node.
        fusion: ``"concat"`` (default) concatenates the three embeddings along
            the feature axis, so the final layer sees ``3 * hidden_dim``
            features; ``"mean"`` averages them, so it sees ``hidden_dim``.
        chip_dim, rna_dim, spatial_dim: optional per-modality input sizes,
            for modalities whose feature counts differ from ``input_dim``.

    Raises:
        ValueError: if ``fusion`` is not one of the supported modes.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, *, fusion="concat",
                 chip_dim=None, rna_dim=None, spatial_dim=None):
        super().__init__()
        if fusion not in ("concat", "mean"):
            raise ValueError(f"fusion must be 'concat' or 'mean', got {fusion!r}")
        self.fusion = fusion
        self.gc_chip = GraphConvolution(input_dim if chip_dim is None else chip_dim, hidden_dim)
        self.gc_rna = GraphConvolution(input_dim if rna_dim is None else rna_dim, hidden_dim)
        self.gc_spatial = GraphConvolution(
            input_dim if spatial_dim is None else spatial_dim, hidden_dim
        )
        # The fc input size must match what forward() produces for this mode;
        # the original always used 3 * hidden_dim but then mean-pooled to
        # hidden_dim, which made every forward pass fail with a shape error.
        fused_dim = hidden_dim * 3 if fusion == "concat" else hidden_dim
        self.fc = nn.Linear(fused_dim, output_dim)

    def forward(self, chip_data, rna_data, spatial_data, adj):
        """Run the three graph convolutions, fuse them, and apply the output layer.

        Args:
            chip_data, rna_data, spatial_data: node feature matrices, one row
                per node (assumed aligned to the same node ordering as ``adj``).
            adj: adjacency matrix shared by all three modalities.

        Returns:
            The output of ``self.fc`` applied to the fused embeddings.
        """
        x_chip = F.relu(self.gc_chip(chip_data, adj))
        x_rna = F.relu(self.gc_rna(rna_data, adj))
        x_spatial = F.relu(self.gc_spatial(spatial_data, adj))
        embeddings = (x_chip, x_rna, x_spatial)
        if self.fusion == "concat":
            # Feature-axis concatenation -> last dim is 3 * hidden_dim.
            x_combined = torch.cat(embeddings, dim=-1)
        else:
            # Element-wise average of the three embeddings -> last dim hidden_dim.
            x_combined = torch.stack(embeddings, dim=0).mean(dim=0)
        return self.fc(x_combined)
def normalize_adjacency(adj, add_self_loops=True):
    """Return the symmetrically normalised adjacency ``D^-1/2 (A [+ I]) D^-1/2``.

    Args:
        adj: dense square adjacency matrix ``(num_nodes, num_nodes)``; edge
            weights are assumed non-negative.
        add_self_loops: add the identity before normalising, so every node
            also aggregates its own features.

    Returns:
        A new dense tensor of the same shape. Rows/columns of nodes with zero
        degree are left as zeros instead of becoming inf/NaN.

    Raises:
        ValueError: if ``adj`` is not a square 2-D matrix.
    """
    if adj.dim() != 2 or adj.shape[0] != adj.shape[1]:
        raise ValueError(f"adj must be a square 2-D matrix, got shape {tuple(adj.shape)}")
    if add_self_loops:
        adj = adj + torch.eye(adj.shape[0], dtype=adj.dtype, device=adj.device)
    degree = adj.sum(dim=1)
    # degree 0 -> pow(-0.5) is inf; zero those entries so isolated nodes stay 0.
    inv_sqrt = degree.pow(-0.5).masked_fill(degree == 0, 0.0)
    return inv_sqrt.unsqueeze(1) * adj * inv_sqrt.unsqueeze(0)


def main(num_nodes=100, num_features=16, hidden_dim=64, output_dim=32, seed=0):
    """Build synthetic inputs and run one forward pass of the network.

    The random tensors below are placeholders: replace them with real
    ChIP-seq, ssRNA-seq and spatial feature matrices (one row per node, rows
    aligned across modalities) and a real adjacency matrix.

    Returns:
        The network output for the synthetic inputs.
    """
    torch.manual_seed(seed)
    # Define input data (ChIP-seq, ssRNAseq, and spatial data)
    chip_seq_data = torch.randn(num_nodes, num_features)  # Shape: (num_nodes, num_features)
    ssRNAseq_data = torch.randn(num_nodes, num_features)  # Shape: (num_nodes, num_features)
    spatial_data = torch.randn(num_nodes, num_features)  # Shape: (num_nodes, num_features)
    # Random symmetric graph without self-loops, standing in for a pre-defined
    # graph structure; normalised before use.
    edges = torch.triu((torch.rand(num_nodes, num_nodes) < 0.05).float(), diagonal=1)
    adjacency_matrix = normalize_adjacency(edges + edges.T)  # Shape: (num_nodes, num_nodes)
    # All modalities share one input dimension in this synthetic example.
    input_dim = chip_seq_data.shape[1]
    mpn = MessagePassingNetwork(input_dim, hidden_dim, output_dim)
    with torch.no_grad():
        return mpn(chip_seq_data, ssRNAseq_data, spatial_data, adjacency_matrix)


if __name__ == "__main__":
    output = main()
    print(tuple(output.shape))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment