Created
March 13, 2025 18:02
-
-
Save janosh/b3da636c5ce68ea667b272e549e290c8 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# %% | |
from __future__ import annotations | |
import warnings | |
from collections import defaultdict | |
from time import perf_counter | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import pandas as pd | |
import torch | |
from matcalc.elasticity import ElasticityCalc | |
from matcalc.eos import EOSCalc | |
from matcalc.phonon import PhononCalc | |
from matcalc.relaxation import RelaxCalc | |
from matcalc.utils import get_universal_calculator | |
from pymatgen.io.ase import AseAtomsAdaptor | |
from tqdm import tqdm | |
from matbench_discovery.data import DataFiles, ase_atoms_from_zip | |
# Silence noisy third-party warnings: UserWarnings raised from matgl modules
# and DeprecationWarnings raised from spglib modules.
warnings.filterwarnings(action="ignore", category=UserWarning, module="matgl")
warnings.filterwarnings(
    action="ignore", category=DeprecationWarning, module="spglib"
)
# %%
# Upper bound on how many initial structures are read from the zip archive
n_structures_to_relax = 10

wbm_atoms_zip = DataFiles.wbm_initial_atoms.path
ase_init_atoms = ase_atoms_from_zip(wbm_atoms_zip, limit=n_structures_to_relax)

# Convert every ASE Atoms object into a pymatgen Structure
pmg_structs = list(map(AseAtomsAdaptor.get_structure, ase_init_atoms))

# Use the GPU when torch reports one, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# %%
# The keyword that selects the device differs per model, so each model gets
# its own set of constructor keyword arguments.
model_kwargs = dict(
    CHGNet=dict(use_device=device),
    MACE=dict(device=device),
)

# List of (model name, calculator) pairs, in the order they will be benchmarked
models = []
for name in ("CHGNet", "MACE"):
    calculator = get_universal_calculator(name, **model_kwargs[name])
    models.append((name, calculator))

print(f"Loaded {len(models)} models: {[name for name, _ in models]}")
# %%
# Settings shared by all calculators below
fmax = 0.1  # force threshold handed to every calculator (presumably eV/Å — confirm)
opt = "BFGSLineSearch"  # optimizer name handed to RelaxCalc and EOSCalc
n_jobs = -1  # Use all available CPU cores
# Number of structures to process. Derived from the structures that were
# actually loaded rather than hard-coded, so the per-structure averages stay
# correct if the load limit changes or fewer structures are returned.
sample_size = len(pmg_structs)
# %%
# Run every model over the same structures. Models are handled one after the
# other; within a model each calculator processes all structures through
# calc_many, to which n_jobs is forwarded.
n_structs = len(pmg_structs)
timing_data = defaultdict(dict)  # model name -> {stage name or "total": seconds}

# One record per structure, shared by all models, so each DataFrame row carries
# the columns of every model instead of being NaN for all models but one.
# "nsites" is stored at the top level so it is available as a plain column.
prop_preds = []
for struct in pmg_structs:
    entry = struct.properties.copy()
    entry["nsites"] = len(struct)
    prop_preds.append(entry)


def _run_timed(calc, structures):
    """Run ``calc.calc_many`` on ``structures``; return ``(results, seconds)``."""
    start_time = perf_counter()
    results = list(calc.calc_many(structures, n_jobs=n_jobs))
    return results, perf_counter() - start_time


# Create a progress bar for the models
for model_name, model in tqdm(models, desc="Processing models"):
    print(f"\nProcessing {model_name}...")

    # Relax the initial structures first; later stages reuse the relaxed ones
    relax_calc = RelaxCalc(model, fmax=fmax, optimizer=opt)
    print(f" Running relaxation calculations on {n_structs} structures...")
    relax_results, relax_time = _run_timed(relax_calc, pmg_structs)
    timing_data[model_name]["relax"] = relax_time
    print(f" Relaxation completed in {relax_time:.2f} seconds")

    # Get the relaxed structures for subsequent calculations
    relaxed_structures = [result["final_structure"] for result in relax_results]

    # The remaining calculators are told not to relax again
    elastic_calc = ElasticityCalc(model, fmax=fmax, relax_structure=False)
    eos_calc = EOSCalc(model, fmax=fmax, relax_structure=False, optimizer=opt)
    phonon_calc = PhononCalc(model, fmax=fmax, relax_structure=False)

    print(" Running elasticity calculations...")
    elastic_results, elastic_time = _run_timed(elastic_calc, relaxed_structures)
    timing_data[model_name]["elastic"] = elastic_time
    print(f" Elasticity completed in {elastic_time:.2f} seconds")

    print(" Running equation of state calculations...")
    eos_results, eos_time = _run_timed(eos_calc, relaxed_structures)
    timing_data[model_name]["eos"] = eos_time
    print(f" EOS completed in {eos_time:.2f} seconds")

    print(" Running phonon calculations...")
    phonon_results, phonon_time = _run_timed(phonon_calc, relaxed_structures)
    timing_data[model_name]["phonon"] = phonon_time
    print(f" Phonon completed in {phonon_time:.2f} seconds")

    total_time = relax_time + elastic_time + eos_time + phonon_time
    timing_data[model_name]["total"] = total_time
    print(f"Total time for {model_name}: {total_time:.2f} seconds")

    # Store results for each structure. Only batch timings are measured, so
    # the per-structure time is the batch time averaged over the structures
    # that were actually processed.
    stage_times = {
        "relax": relax_time,
        "elastic": elastic_time,
        "eos": eos_time,
        "phonon": phonon_time,
    }
    for i, struct in enumerate(pmg_structs):
        entry = prop_preds[i]
        for stage, seconds in stage_times.items():
            entry[f"time_{stage}_{model_name}"] = seconds / n_structs
        # Store all property results
        entry[model_name] = {
            "relax": relax_results[i],
            "elastic": elastic_results[i],
            "eos": eos_results[i],
            "phonon": phonon_results[i],
            "nsites": len(struct),
        }
# %%
# Tabulate the collected records and add one total-time column per model,
# formed by adding the four per-stage time columns of that model.
df_preds = pd.DataFrame(prop_preds)
time_parts = ("relax", "elastic", "phonon", "eos")
for model_name, _ in models:
    part_cols = [f"time_{part}_{model_name}" for part in time_parts]
    total = df_preds[part_cols[0]]
    for col in part_cols[1:]:
        total = total + df_preds[col]
    df_preds[f"time_total_{model_name}"] = total
# %%
print("Summary of results:")
# Bare expression: rendered as cell output when run interactively
df_preds
# %%
# Plot total calculation time vs number of sites, one panel per model.
# The grid is sized from the number of models so there are neither empty
# panels nor an index error when the model list grows.
n_panels = max(len(models), 1)
fig, axes = plt.subplots(1, n_panels, figsize=(6 * n_panels, 5), squeeze=False)
for ax, (model_name, _) in zip(axes.flat, models):
    # The site count is read from the per-model result dict, which is where
    # the benchmark loop records it; rows without results for this model
    # are skipped.
    model_rows = df_preds[model_name].dropna()
    n_sites = model_rows.map(lambda preds: preds["nsites"])
    total_times = df_preds.loc[model_rows.index, f"time_total_{model_name}"]
    ax.scatter(n_sites, total_times)
    ax.set_xlabel("Number of sites")
    ax.set_ylabel("Time (s)")
    ax.set_title(f"{model_name} Total Calculation Time")
plt.tight_layout()
plt.show()
# %%
# Overlay one semi-transparent histogram of total time per model
fig, ax = plt.subplots(figsize=(10, 6))
for model_name, _ in models:
    total_times = df_preds[f"time_total_{model_name}"]
    total_times.hist(ax=ax, alpha=0.6, label=model_name)
ax.set(xlabel="Total Time (s)", ylabel="Count")
ax.legend()
ax.set_title("Distribution of Calculation Times")
plt.show()
# %%
# Bar chart of the mean total-time column of each model
avg_times = {}
for model_name, _ in models:
    avg_times[model_name] = df_preds[f"time_total_{model_name}"].mean()

plt.figure(figsize=(10, 6))
plt.bar(list(avg_times), list(avg_times.values()))
plt.xlabel("Model")
plt.ylabel("Average Time (s)")
plt.title("Average Calculation Time by Model")
plt.show()
# %%
# Compare different property calculation times across models:
# one group of bars per property type, one bar per model within a group.
property_types = ["relax", "elastic", "eos", "phonon"]
property_times = {
    model_name: [timing_data[model_name][prop] for prop in property_types]
    for model_name, _ in models
}
fig, ax = plt.subplots(figsize=(12, 8))
x = np.arange(len(property_types))
width = 0.2
for idx, (model_name, times) in enumerate(property_times.items()):
    # Shift each model's bars right by one bar width per preceding model
    rects = ax.bar(x + width * idx, times, width, label=model_name)
    ax.bar_label(rects, fmt="%.1f", padding=3, rotation=90, fontsize=8)
# Put each tick under the middle of its group: the bars of a group span
# x .. x + width * (n_models - 1), so the midpoint depends on the model count.
group_center = x + width * (len(property_times) - 1) / 2
ax.set_xticks(group_center, property_types)
ax.set_xlabel("Property Type")
ax.set_ylabel("Time (s)")
ax.set_title("Calculation Time by Property Type and Model")
ax.legend(loc="upper left")
plt.tight_layout()
plt.show()
# %%
# Create a stacked bar chart showing the breakdown of calculation time for each model
fig, ax = plt.subplots(figsize=(10, 6))
model_names = [name for name, _ in models]  # same for every layer, so built once
bottom = np.zeros(len(model_names))  # running height of each model's stack
for prop in property_types:
    times = np.array(
        [timing_data[name][prop] for name in model_names], dtype=float
    )
    ax.bar(model_names, times, bottom=bottom, label=prop)
    # Rebind instead of mutating in place so the array handed to the previous
    # bar call is never modified afterwards.
    bottom = bottom + times
ax.set_title("Breakdown of Calculation Time by Model")
ax.set_ylabel("Time (s)")
ax.set_xlabel("Model")
ax.legend()
plt.tight_layout()
plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment