Skip to content

Instantly share code, notes, and snippets.

@janosh
Created March 13, 2025 18:02
Show Gist options
  • Save janosh/b3da636c5ce68ea667b272e549e290c8 to your computer and use it in GitHub Desktop.
Save janosh/b3da636c5ce68ea667b272e549e290c8 to your computer and use it in GitHub Desktop.
# %%
from __future__ import annotations
import warnings
from collections import defaultdict
from time import perf_counter
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from matcalc.elasticity import ElasticityCalc
from matcalc.eos import EOSCalc
from matcalc.phonon import PhononCalc
from matcalc.relaxation import RelaxCalc
from matcalc.utils import get_universal_calculator
from pymatgen.io.ase import AseAtomsAdaptor
from tqdm import tqdm
from matbench_discovery.data import DataFiles, ase_atoms_from_zip
# Silence known-noisy warnings raised from the matgl and spglib modules.
warnings.filterwarnings(action="ignore", category=UserWarning, module="matgl")
warnings.filterwarnings(action="ignore", category=DeprecationWarning, module="spglib")

# %%
# Load a limited number of WBM initial structures and convert each ASE atoms
# object into a pymatgen structure.
n_structures_to_relax = 10
ase_init_atoms = ase_atoms_from_zip(
    DataFiles.wbm_initial_atoms.path,
    limit=n_structures_to_relax,
)
pmg_structs = list(map(AseAtomsAdaptor.get_structure, ase_init_atoms))

# Run on the GPU whenever torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# %%
# Each calculator receives the device through a differently named keyword,
# so the per-model keyword arguments are kept in a lookup keyed by model name.
model_kwargs = dict(
    CHGNet=dict(use_device=device),
    MACE=dict(device=device),
)
# List of (name, calculator) pairs, in the insertion order of model_kwargs.
models = [
    (name, get_universal_calculator(name, **kwargs))
    for name, kwargs in model_kwargs.items()
]
print(f"Loaded {len(models)} models: {[name for name, _ in models]}")
# %%
# Settings shared by every calculator created below.
fmax = 0.1  # passed as `fmax` to each property calculator
opt = "BFGSLineSearch"  # optimizer name passed to RelaxCalc and EOSCalc
n_jobs = -1  # Use all available CPU cores
# Number of structures to process. Derived from the structures that were
# actually loaded rather than hard-coded, so the per-structure time averages
# stay correct when the load limit changes or fewer structures are available.
sample_size = len(pmg_structs)
# %%
# Process all models in sequence, but parallelize structure processing
prop_preds = []  # one record per (model, structure) pair
timing_data = defaultdict(dict)  # model name -> {stage name: wall-clock seconds}
# Create a progress bar for the models
for model_name, model in tqdm(models, desc="Processing models"):
    print(f"\nProcessing {model_name}...")

    # Relax first: the remaining properties are computed on the relaxed
    # structures, which is why the other calculators skip their own relaxation.
    relax_calc = RelaxCalc(model, fmax=fmax, optimizer=opt)
    print(f" Running relaxation calculations on {len(pmg_structs)} structures...")
    start_time = perf_counter()
    # Process all structures in parallel
    relax_results = list(relax_calc.calc_many(pmg_structs, n_jobs=n_jobs))
    relax_time = perf_counter() - start_time
    timing_data[model_name]["relax"] = relax_time
    print(f" Relaxation completed in {relax_time:.2f} seconds")

    # Get the relaxed structures for subsequent calculations
    relaxed_structures = [result["final_structure"] for result in relax_results]

    # Create other property calculators
    elastic_calc = ElasticityCalc(model, fmax=fmax, relax_structure=False)
    eos_calc = EOSCalc(model, fmax=fmax, relax_structure=False, optimizer=opt)
    phonon_calc = PhononCalc(model, fmax=fmax, relax_structure=False)

    # Calculate elastic properties in parallel
    print(" Running elasticity calculations...")
    start_time = perf_counter()
    elastic_results = list(elastic_calc.calc_many(relaxed_structures, n_jobs=n_jobs))
    elastic_time = perf_counter() - start_time
    timing_data[model_name]["elastic"] = elastic_time
    print(f" Elasticity completed in {elastic_time:.2f} seconds")

    # Calculate EOS properties in parallel
    print(" Running equation of state calculations...")
    start_time = perf_counter()
    eos_results = list(eos_calc.calc_many(relaxed_structures, n_jobs=n_jobs))
    eos_time = perf_counter() - start_time
    timing_data[model_name]["eos"] = eos_time
    print(f" EOS completed in {eos_time:.2f} seconds")

    # Calculate phonon properties in parallel
    print(" Running phonon calculations...")
    start_time = perf_counter()
    phonon_results = list(phonon_calc.calc_many(relaxed_structures, n_jobs=n_jobs))
    phonon_time = perf_counter() - start_time
    timing_data[model_name]["phonon"] = phonon_time
    print(f" Phonon completed in {phonon_time:.2f} seconds")

    total_time = relax_time + elastic_time + eos_time + phonon_time
    timing_data[model_name]["total"] = total_time
    print(f"Total time for {model_name}: {total_time:.2f} seconds")

    # Store results for each structure
    stage_times = (
        ("relax", relax_time),
        ("elastic", elastic_time),
        ("eos", eos_time),
        ("phonon", phonon_time),
    )
    for i, struct in enumerate(pmg_structs):
        model_preds_entry = struct.properties.copy()
        # Keep the site count as a top-level field so it becomes a real
        # DataFrame column; the time-vs-size scatter plot reads "nsites"
        # directly and would fail if it only lived in the nested dict below.
        model_preds_entry["nsites"] = len(struct)
        # Add timing information. These are batch wall times split evenly
        # across structures (an average), not measured per-structure times.
        for stage, stage_time in stage_times:
            model_preds_entry[f"time_{stage}_{model_name}"] = stage_time / sample_size
        # Store all property results
        model_preds_entry[model_name] = {
            "relax": relax_results[i],
            "elastic": elastic_results[i],
            "eos": eos_results[i],
            "phonon": phonon_results[i],
            "nsites": len(struct),  # kept for backward compatibility
        }
        prop_preds.append(model_preds_entry)
# %%
# Turn the per-structure records into a table, then add one summed-time column
# per model. The stage columns are added left to right with `+`, so a missing
# value in any stage column makes that row's total missing as well.
df_preds = pd.DataFrame(prop_preds)
for model_name, _ in models:
    total = df_preds[f"time_relax_{model_name}"]
    for stage in ("elastic", "phonon", "eos"):
        total = total + df_preds[f"time_{stage}_{model_name}"]
    df_preds[f"time_total_{model_name}"] = total

# %%
print("Summary of results:")
df_preds
# %%
# Plot total calculation time vs number of sites
# One panel per model, at most two panels per row. The grid is sized from the
# model list instead of a fixed 2x2 layout, so any number of models fits and
# panels without a model are hidden rather than drawn empty.
n_models = len(models)
n_cols = max(1, min(n_models, 2))
n_rows = max(1, -(-n_models // n_cols))  # ceiling division
fig, axes = plt.subplots(
    n_rows,
    n_cols,
    figsize=(6 * n_cols, 5 * n_rows),
    squeeze=False,  # always get a 2D array so flatten() works for any grid
)
axes = axes.flatten()
for ax, (model_name, _) in zip(axes, models):
    # Reads the top-level "nsites" column of df_preds for the x axis.
    df_preds.plot(x="nsites", y=f"time_total_{model_name}", kind="scatter", ax=ax)
    ax.set_xlabel("Number of sites")
    ax.set_ylabel("Time (s)")
    ax.set_title(f"{model_name} Total Calculation Time")
for unused_ax in axes[n_models:]:
    unused_ax.set_visible(False)
plt.tight_layout()
plt.show()
# %%
# Overlay one histogram of total time per model on a shared axis.
fig, ax = plt.subplots(figsize=(10, 6))
for model_name, _ in models:
    total_times = df_preds[f"time_total_{model_name}"]
    # Semi-transparent bars keep overlapping distributions visible.
    total_times.hist(ax=ax, alpha=0.6, label=model_name)
ax.set(xlabel="Total Time (s)", ylabel="Count")
ax.legend()
ax.set_title("Distribution of Calculation Times")
plt.show()
# %%
# Bar chart of the mean total time per model.
avg_times = {}
for model_name, _ in models:
    avg_times[model_name] = df_preds[f"time_total_{model_name}"].mean()

plt.figure(figsize=(10, 6))
plt.bar(list(avg_times), list(avg_times.values()))
plt.xlabel("Model")
plt.ylabel("Average Time (s)")
plt.title("Average Calculation Time by Model")
plt.show()
# %%
# Compare different property calculation times across models
# Grouped bar chart: one group per property type, one bar per model.
property_types = ["relax", "elastic", "eos", "phonon"]
# model name -> wall-clock seconds per stage, ordered like property_types
property_times = {
    model_name: [timing_data[model_name][prop] for prop in property_types]
    for model_name, _ in models
}
fig, ax = plt.subplots(figsize=(12, 8))
x = np.arange(len(property_types))
n_groups = len(property_times)
# Keep the original 0.2 bar width for up to four models; shrink it beyond that
# so adjacent property groups do not overlap.
width = min(0.2, 0.8 / max(n_groups, 1))
for idx, (model_name, times) in enumerate(property_times.items()):
    rects = ax.bar(x + idx * width, times, width, label=model_name)
    ax.bar_label(rects, fmt="%.1f", padding=3, rotation=90, fontsize=8)
# Centre each tick under its group. A fixed `x + width` offset is only the
# centre when there are exactly three models; for two it sits under the
# second bar.
ax.set_xticks(x + width * max(n_groups - 1, 0) / 2, property_types)
ax.set_xlabel("Property Type")
ax.set_ylabel("Time (s)")
ax.set_title("Calculation Time by Property Type and Model")
ax.legend(loc="upper left")
plt.tight_layout()
plt.show()
# %%
# Create a stacked bar chart showing the breakdown of calculation time for each model
fig, ax = plt.subplots(figsize=(10, 6))
# The model names are the same for every stacked segment, so build them once.
model_names = [name for name, _ in models]
bottom = np.zeros(len(model_names))  # running height of each model's stack
for prop in property_types:
    times = [timing_data[name][prop] for name in model_names]
    ax.bar(model_names, times, bottom=bottom, label=prop)
    bottom += times  # the next segment starts on top of this one
ax.set_title("Breakdown of Calculation Time by Model")
ax.set_ylabel("Time (s)")
ax.set_xlabel("Model")
ax.legend()
plt.tight_layout()
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment