Skip to content

Instantly share code, notes, and snippets.

@tomonori-masui
Last active September 22, 2022 05:53
Show Gist options
  • Save tomonori-masui/89306c7c7b9fb70e011507be3ee45838 to your computer and use it in GitHub Desktop.
Save tomonori-masui/89306c7c7b9fb70e011507be3ee45838 to your computer and use it in GitHub Desktop.
Visualizing Graph
import random
from torch_geometric.utils import to_networkx
import networkx as nx
def convert_to_networkx(graph, n_sample=None):
    """Convert a graph to NetworkX, optionally keeping a random node sample.

    Args:
        graph: Object accepted by ``to_networkx`` that also exposes ``y``
            with a ``.numpy()`` method (presumably a
            ``torch_geometric.data.Data`` whose ``y`` holds one label per
            node, indexed by node id -- confirm against the caller).
        n_sample: Number of nodes to draw at random without replacement.
            ``None`` (the default) keeps the whole graph.

    Returns:
        A ``(g, y)`` tuple: the NetworkX graph (the subgraph induced by the
        sampled nodes when ``n_sample`` is given) and the ``graph.y`` values
        for its nodes, ordered to match iteration over ``g.nodes``.

    Raises:
        ValueError: If ``n_sample`` is negative or larger than the number of
            nodes (raised by ``random.sample``).
    """
    g = to_networkx(graph, node_attrs=["x"])
    y = graph.y.numpy()
    if n_sample is None:
        return g, y

    # random.sample requires a real sequence; a NodeView is set-like and is
    # rejected with TypeError on Python 3.11+ (deprecated on 3.9/3.10).
    sampled_nodes = random.sample(list(g.nodes), n_sample)
    g = g.subgraph(sampled_nodes)

    # Index the labels in the subgraph's own iteration order rather than the
    # sampling order: the subgraph view does not preserve the order of
    # `sampled_nodes`, and consumers such as node_color pair entries with
    # nodes positionally.
    y = y[list(g.nodes)]
    return g, y
def plot_graph(g, y):
    """Draw ``g`` with a spring layout, colouring each node by ``y``.

    Args:
        g: NetworkX graph to draw.
        y: Per-node colour values passed to ``node_color``; entries are
            expected to follow the iteration order of ``g``'s nodes.

    NOTE(review): relies on ``plt`` being available in the enclosing scope
    (presumably ``matplotlib.pyplot``, imported elsewhere) -- confirm.
    """
    # Small markers and no arrowheads keep large graphs readable.
    draw_options = {
        "node_size": 30,
        "arrows": False,
        "node_color": y,
    }
    plt.figure(figsize=(9, 7))
    nx.draw_spring(g, **draw_options)
    plt.show()
# Driver: take a 1000-node random sample of `graph` and plot it, coloured
# by the values returned alongside the sampled graph.
# NOTE(review): `graph` is not defined in this snippet -- presumably a
# torch_geometric graph loaded earlier; confirm before running standalone.
g, y = convert_to_networkx(graph, n_sample=1000)
plot_graph(g, y)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment