Created
January 4, 2025 00:32
-
-
Save j6k4m8/210287da1b7391a9ba3a79c8db3b93dc to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from collections import defaultdict, deque
from datetime import date
from functools import cache

import matplotlib.pyplot as plt
import networkx as nx
from openalex import OpenAlex, Work
# Contact address passed as the `mailto` argument of OpenAlex clients.
# The value below is a placeholder; replace it before running.
MAILTO_ADDR = "ADD HERE!"


@cache
def _get_work_memoized(work_slug: str) -> Work:
    """Fetch one work from OpenAlex, caching the result per ``work_slug``.

    Args:
        work_slug: Work identifier; it is passed through
            ``OpenAlex.get_slug_from_uri`` before the lookup.

    Returns:
        Whatever ``OpenAlex.get_work`` returns for that slug (annotated as
        ``Work``).
    """
    # A fresh client is built on every cache miss, always with MAILTO_ADDR.
    client = OpenAlex(mailto=MAILTO_ADDR)
    slug = OpenAlex.get_slug_from_uri(work_slug)
    return client.get_work(slug)
class PaperAncestryNetwork:
    """Citation-ancestry graph rooted at a single OpenAlex work.

    Starting from ``work_id``, referenced works are followed breadth-first
    for up to ``generations`` hops. Edges point from the referenced work to
    the work that references it.

    NOTE: works are fetched through the module-level ``_get_work_memoized``,
    which uses ``MAILTO_ADDR``; the ``mailto`` given here only configures
    ``self.oa``, which is used to turn reference URIs into slugs.
    """

    def __init__(self, work_id: str, generations: int = 1, mailto: str = MAILTO_ADDR):
        """Record the traversal parameters; nothing is fetched until ``to_nx``.

        Args:
            work_id: Identifier of the root work.
            generations: Maximum number of reference hops to expand.
            mailto: Address handed to the ``OpenAlex`` client kept on ``self.oa``.
        """
        self.base_work_id = work_id
        self._initialized = False
        self._generations = generations
        self.oa = OpenAlex(mailto=mailto)

    def to_nx(self):
        """Return the ancestry graph (``self.graph``), building it on first use."""
        if not self._initialized:
            self._load_graph()
        return self.graph

    def _load_graph(self):
        """Populate ``self.graph`` with a breadth-first walk over references.

        Every successfully fetched work becomes a node with ``title``,
        ``work`` and ``generation`` attributes. Each reference of a fetched
        work is linked by an edge even when the reference itself is past the
        generation limit, so such works appear as attribute-less nodes;
        callers that need the attributes must filter them out.

        Fetch or processing errors for a work are printed and the walk
        continues with the remaining queue (best effort).
        """
        self.graph = nx.DiGraph()
        # deque pops from the left in O(1); list.pop(0) shifts the whole list.
        work_queue = deque([(self.base_work_id, 0)])
        visited = set()
        while work_queue:
            current_work_id, current_generation = work_queue.popleft()
            if current_work_id in visited or current_generation > self._generations:
                continue
            visited.add(current_work_id)
            next_generation = current_generation + 1
            # References past the limit would be discarded when popped, so
            # they are never queued in the first place.
            expand = next_generation <= self._generations
            try:
                work = _get_work_memoized(current_work_id)
                self.graph.add_node(
                    current_work_id,
                    title=work.title,
                    work=work,
                    generation=current_generation,
                )
                for reference in work.referenced_works:
                    reference = self.oa.get_slug_from_uri(reference)
                    if reference == current_work_id:
                        # Ignore self-references.
                        continue
                    self.graph.add_edge(reference, current_work_id)
                    if expand and reference not in visited:
                        work_queue.append((reference, next_generation))
            except Exception as e:
                print(f"Error processing work {current_work_id}: {e}")
        # Mark as built only once the walk has finished, so an interrupted
        # load is retried by the next to_nx() call instead of being served
        # as a partial graph.
        self._initialized = True
# Build the ancestry graph of one work, following references for 3 generations.
g = PaperAncestryNetwork("W3174465650", generations=3)
gg = g.to_nx()
# Keep only the nodes that carry a 'title' attribute; nodes created
# implicitly by add_edge have no attributes and are dropped here.
gg = gg.subgraph(
    [node_id for node_id, attrs in gg.nodes(data=True) if 'title' in attrs]
)
def _publication_date(work):
    """Return ``work.publication_date`` as a date, parsing ISO strings."""
    pub_date = work.publication_date
    if isinstance(pub_date, str):
        # Only the leading YYYY-MM-DD part matters; drop any time suffix.
        return date.fromisoformat(pub_date[:10])
    return pub_date


def layout(g: nx.DiGraph) -> dict:
    """Compute a date-based position for every node of ``g``.

    Each node must carry a ``work`` attribute whose ``publication_date`` is
    either a date/datetime object or a ``YYYY-MM-DD`` string.

    Args:
        g: Graph whose ``nodes(data=True)`` yields ``(node_id, attrs)`` pairs.

    Returns:
        Mapping ``node_id -> (x, y)`` where ``x`` is the ordinal of the
        publication date and ``y`` is ``x`` plus the number of nodes
        published in the same month. An empty graph gives an empty dict.
    """
    # Resolve every publication date once; both passes below reuse it.
    pub_dates = {
        node_id: _publication_date(attrs["work"])
        for node_id, attrs in g.nodes(data=True)
    }

    # Number of papers per month, keyed by the date moved to day 1.
    papers_per_month = defaultdict(int)
    for pub_date in pub_dates.values():
        papers_per_month[pub_date.replace(day=1)] += 1

    layout_dict = {}
    for node_id, pub_date in pub_dates.items():
        x = pub_date.toordinal()
        # NOTE(review): y includes x itself, not just the monthly count;
        # kept as in the original computation.
        y = papers_per_month[pub_date.replace(day=1)] + x
        layout_dict[node_id] = (x, y)
    return layout_dict
# Position nodes with a multipartite layout keyed on the 'generation' attribute.
pos = nx.multipartite_layout(gg, "generation", align="horizontal")
# Alternative layout through Graphviz, left disabled:
# pos = nx.nx_agraph.graphviz_layout(gg, prog="twopi", args="")
plt.figure(figsize=(12, 8))
# Short label per node: the first two words of the title plus an ellipsis.
# NOTE(review): with_labels=False is passed below, so these labels are
# presumably not rendered -- confirm whether that is intended.
node_labels = {
    node_id: " ".join(attrs["title"].split()[:2]) + "..."
    for node_id, attrs in gg.nodes(data=True)
}
# Colour each node by its generation number.
node_colors = [attrs["generation"] for _, attrs in gg.nodes(data=True)]
nx.draw(
    gg,
    pos=pos,
    with_labels=False,
    node_size=50,
    width=0.5,
    font_size=6,
    labels=node_labels,
    node_color=node_colors,
)
plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment