Skip to content

Instantly share code, notes, and snippets.

@j6k4m8
Created January 4, 2025 00:32
Show Gist options
  • Save j6k4m8/210287da1b7391a9ba3a79c8db3b93dc to your computer and use it in GitHub Desktop.
Save j6k4m8/210287da1b7391a9ba3a79c8db3b93dc to your computer and use it in GitHub Desktop.
from collections import Counter, defaultdict, deque
from datetime import date
from functools import cache

import matplotlib.pyplot as plt
import networkx as nx
from openalex import OpenAlex, Work
# Contact e-mail handed to the OpenAlex client as its ``mailto`` argument.
# The value below is a placeholder; replace it with a real address before running.
MAILTO_ADDR = "ADD HERE!"
@cache
def _get_work_memoized(work_slug: str) -> Work:
    """Fetch a work from OpenAlex, caching the result per ``work_slug``.

    A new ``OpenAlex`` client (configured with ``MAILTO_ADDR``) is created for
    each uncached call. ``work_slug`` is passed through
    ``OpenAlex.get_slug_from_uri`` before the lookup.
    """
    client = OpenAlex(mailto=MAILTO_ADDR)
    slug = OpenAlex.get_slug_from_uri(work_slug)
    return client.get_work(slug)
class PaperAncestryNetwork:
    """Citation-ancestry graph of a single OpenAlex work.

    Starting from ``work_id``, the works it references are followed
    breadth-first for up to ``generations`` hops. Edges point from the
    referenced work to the work that references it. The graph is built lazily
    on the first call to :meth:`to_nx`.
    """

    def __init__(self, work_id: str, generations: int = 1, mailto: str = MAILTO_ADDR):
        """Store the crawl parameters; nothing is fetched here.

        Args:
            work_id: Identifier of the work the traversal starts from.
            generations: Deepest generation whose works are fetched; the
                starting work is generation 0.
            mailto: Address passed to the ``OpenAlex`` client kept in
                ``self.oa``.
        """
        self.base_work_id = work_id
        self._initialized = False
        self._generations = generations
        self.oa = OpenAlex(mailto=mailto)

    def to_nx(self):
        """Return the ancestry graph, building it on the first call."""
        if not self._initialized:
            self._load_graph()
        return self.graph

    def _load_graph(self):
        """Populate ``self.graph`` by breadth-first traversal of references.

        Every fetched work becomes a node with ``title``, ``work`` and
        ``generation`` attributes. References of a fetched work are always
        added as edges, so works beyond the generation limit (or whose fetch
        failed) can appear as nodes without those attributes.

        Errors raised while processing a work are printed and that work is
        skipped; the traversal continues with the remaining queue.

        NOTE: works are fetched through the module-level
        ``_get_work_memoized``, not through ``self.oa``; ``self.oa`` is only
        used here to turn reference URIs into slugs.
        """
        self.graph = nx.DiGraph()
        # deque pops from the left in O(1); list.pop(0) shifts every element.
        work_queue = deque([(self.base_work_id, 0)])
        visited = set()
        self._initialized = True
        while work_queue:
            current_work_id, current_generation = work_queue.popleft()
            if current_work_id in visited or current_generation > self._generations:
                continue
            visited.add(current_work_id)
            next_generation = current_generation + 1
            try:
                work = _get_work_memoized(current_work_id)
                self.graph.add_node(
                    current_work_id,
                    title=work.title,
                    work=work,
                    generation=current_generation,
                )
                for reference in work.referenced_works:
                    reference = self.oa.get_slug_from_uri(reference)
                    if reference == current_work_id:
                        # Ignore self-references.
                        continue
                    self.graph.add_edge(reference, current_work_id)
                    # Entries that would be dropped on dequeue (already
                    # visited or past the generation limit) are not queued.
                    if next_generation <= self._generations and reference not in visited:
                        work_queue.append((reference, next_generation))
            except Exception as e:
                print(f"Error processing work {current_work_id}: {e}")
# Build the reference ancestry of work W3174465650, three generations deep.
g = PaperAncestryNetwork("W3174465650", generations=3)
gg = g.to_nx()
# Restrict the graph to nodes that carry a 'title' attribute.
gg = gg.subgraph(
    [node_id for node_id, attrs in gg.nodes(data=True) if 'title' in attrs]
)
def _as_date(value):
    """Return ``value`` unchanged unless it is a ``YYYY-MM-DD`` string, which is parsed."""
    if isinstance(value, str):
        return date.fromisoformat(value)
    return value


def layout(g: "nx.DiGraph"):
    """Compute date-based node positions for ``g``.

    Each node must carry a ``work`` attribute whose ``publication_date`` is
    either a date-like object (supporting ``replace(day=1)`` and
    ``toordinal()``) or an ISO ``YYYY-MM-DD`` string.

    Args:
        g: Graph whose ``nodes(data=True)`` yields ``(node_id, attrs)`` pairs.

    Returns:
        A dict mapping node id to ``(x, y)``, where ``x`` is the publication
        date ordinal and ``y`` is ``x`` plus the number of papers published
        in the same month.

    Raises:
        KeyError: If a node has no ``work`` attribute.

    NOTE(review): the original comment described ``y`` as just the per-month
    paper count, but the code has always added ``x`` to it; that computation
    is preserved here - confirm which was intended.
    """
    pub_dates = {
        work_id: _as_date(attrs["work"].publication_date)
        for work_id, attrs in g.nodes(data=True)
    }
    # Papers are grouped by the first day of their publication month.
    papers_per_month = Counter(d.replace(day=1) for d in pub_dates.values())

    layout_dict = {}
    for work_id, pub_date in pub_dates.items():
        x = pub_date.toordinal()
        y = papers_per_month[pub_date.replace(day=1)] + x
        layout_dict[work_id] = (x, y)
    return layout_dict
# Arrange nodes in layers keyed by their "generation" attribute.
# Alternative radial layout (needs graphviz):
#   nx.nx_agraph.graphviz_layout(gg, prog="twopi", args="")
pos = nx.multipartite_layout(gg, "generation", align="horizontal")

plt.figure(figsize=(12, 8))
nx.draw(
    gg,
    pos=pos,
    # Nodes are coloured by generation.
    node_color=[attrs["generation"] for _, attrs in gg.nodes(data=True)],
    node_size=50,
    width=0.5,
    # Labels are the first two words of each title; they are prepared here
    # but label drawing is switched off via with_labels=False.
    with_labels=False,
    labels={
        node_id: " ".join(attrs["title"].split()[:2]) + "..."
        for node_id, attrs in gg.nodes(data=True)
    },
    font_size=6,
)
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment