Last active
February 3, 2023 04:02
-
-
Save avivajpeyi/1d69bb2c9044ee861bd472d1e0ba11e5 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import bilby | |
| import numpy as np | |
| from bilby.gw.result import CompactBinaryCoalescenceResult | |
| import matplotlib.pyplot as plt | |
| import time | |
| import os | |
| from tqdm.auto import tqdm | |
| import pandas as pd | |
# File formats to benchmark. Each entry is passed to ``save_to_file`` as
# ``extension=`` in ``time_save`` and selects the reader in ``time_load``.
VALID_EXTENSIONS = [
    "hdf5",
    "json",
    "pkl",
]
def generate_fake_result(n=10000):
    """Build a synthetic ``CompactBinaryCoalescenceResult`` for benchmarking.

    Parameters
    ----------
    n : int
        Number of posterior rows, drawn from a default ``BBHPriorDict``.

    Returns
    -------
    CompactBinaryCoalescenceResult
        A result whose posterior is ``n`` prior draws and whose injection
        parameters are a single prior draw.
    """
    # Switch off bilby's test-mode flag before building anything.
    bilby.utils.command_line_args.bilby_test_mode = False

    bbh_priors = bilby.gw.prior.BBHPriorDict()
    # Replace the geocent_time prior with the plain value 2.
    bbh_priors["geocent_time"] = 2

    # Draw the injection first, then the posterior (same order as before).
    injection = bbh_priors.sample()
    posterior_samples = pd.DataFrame(bbh_priors.sample(n))

    result_kwargs = dict(
        label="label",
        outdir="outdir",
        sampler="nestle",
        search_parameter_keys=list(bbh_priors.keys()),
        fixed_parameter_keys=list(),
        priors=bbh_priors,
        # Includes a callable, presumably to exercise serialisation of
        # non-trivial sampler kwargs -- TODO confirm intent.
        sampler_kwargs=dict(test="test", func=lambda x: x),
        injection_parameters=injection,
        meta_data={},
        posterior=posterior_samples,
    )
    return CompactBinaryCoalescenceResult(**result_kwargs)
def get_filesize_in_mb(filepath):
    """Return the on-disk size of ``filepath`` in megabytes (1 MB = 10**6 bytes)."""
    size_in_bytes = os.stat(filepath).st_size
    return size_in_bytes / 1e6
def time_save(r: "CompactBinaryCoalescenceResult", extension: str) -> float:
    """Time how long ``r.save_to_file`` takes to write ``test.<extension>``.

    Parameters
    ----------
    r : CompactBinaryCoalescenceResult
        The result to serialise.
    extension : str
        Output format. It is forwarded to ``save_to_file`` as ``extension=``
        and is also used as the file suffix.

    Returns
    -------
    float
        Elapsed seconds spent inside ``save_to_file``.
    """
    filename = f"test.{extension}"
    # perf_counter is monotonic and high-resolution; time.time() can jump
    # (clock adjustments) and is coarse on some platforms, which skews
    # short benchmark timings.
    t0 = time.perf_counter()
    r.save_to_file(filename=filename, outdir=".", extension=extension)
    return time.perf_counter() - t0
def time_load(extension: str) -> float:
    """Time how long it takes to read ``test.<extension>`` back into a result.

    Parameters
    ----------
    extension : str
        One of ``"hdf5"``, ``"json"`` or ``"pkl"``. It selects the matching
        ``CompactBinaryCoalescenceResult.from_*`` reader.

    Returns
    -------
    float
        Elapsed seconds spent inside the reader.

    Raises
    ------
    ValueError
        If ``extension`` has no matching reader. This prevents the benchmark
        from reporting a load time for a file it never read.
    """
    readers = {"hdf5": "from_hdf5", "json": "from_json", "pkl": "from_pickle"}
    try:
        reader_name = readers[extension]
    except KeyError:
        raise ValueError(
            f"Unsupported extension {extension!r}; expected one of {sorted(readers)}"
        ) from None

    # Resolve the reader and filename before starting the clock so only the
    # read itself is timed.
    reader = getattr(CompactBinaryCoalescenceResult, reader_name)
    filename = f"test.{extension}"
    t0 = time.perf_counter()
    reader(filename)
    return time.perf_counter() - t0
def collect_timing_and_mem_data(r: CompactBinaryCoalescenceResult, extension):
    """Benchmark one file format: save ``r``, load it back, and measure the file.

    Despite the name, the "mem" figure is the on-disk size of
    ``test.<extension>``, not RAM usage.

    Returns
    -------
    tuple of float
        ``(t_save, t_load, filesize)``: save seconds, load seconds and file
        size in MB.
    """
    path = f"test.{extension}"
    # Tuple items evaluate left to right, so the load and size measurements
    # see the file the save just wrote.
    return (
        time_save(r, extension),
        time_load(extension),
        get_filesize_in_mb(path),
    )
def plot_timing_and_mem_data(df):
    """Plot save time, load time and file size against sample count.

    Each file format gets one line per panel. All axes use log-log scales.

    Parameters
    ----------
    df : pandas.DataFrame
        Long-format benchmark table with columns ``n``, ``extension``,
        ``t_save``, ``t_load`` and ``filesize``.

    Returns
    -------
    matplotlib.figure.Figure
        The figure, so callers can save or further adjust it.
    """
    panels = [
        ("t_save", "Save Time (s)"),
        ("t_load", "Load Time (s)"),
        ("filesize", "Filesize (MB)"),
    ]
    fig, axes = plt.subplots(len(panels), 1, figsize=(6, 8))

    # Split once per format instead of re-filtering per panel. sort=False
    # keeps formats in first-appearance order, i.e. the benchmark order.
    for extension, group in df.groupby("extension", sort=False):
        for ax, (column, _) in zip(axes, panels):
            group.plot(x="n", y=column, ax=ax, label=extension)

    # Labels are applied *after* plotting. DataFrame.plot(x="n") labels the
    # x axis with the column name, which replaced any label set beforehand.
    for ax, (_, ylabel) in zip(axes, panels):
        ax.set_xlabel("Number of samples")
        ax.set_ylabel(ylabel)
        ax.legend(frameon=False)
        ax.set_xscale("log")
        ax.set_yscale("log")

    # Lay out last, so spacing accounts for the final labels and log ticks.
    fig.tight_layout()
    plt.show()
    return fig
if __name__ == "__main__":
    # Ten sample counts from 1e4 to 1e6, evenly spaced in log space.
    ns = np.geomspace(1e4, 1e6, 10)
    data = []
    for n_float in tqdm(ns, desc="sample sizes"):
        # Record the integer actually used for the posterior, not the raw
        # float from geomspace.
        n = int(n_float)
        r = generate_fake_result(n=n)
        for extension in VALID_EXTENSIONS:
            t_save, t_load, filesize = collect_timing_and_mem_data(r, extension)
            data.append(
                dict(
                    n=n,
                    extension=extension,
                    t_save=t_save,
                    t_load=t_load,
                    filesize=filesize,
                )
            )
    df = pd.DataFrame(data)
    plot_timing_and_mem_data(df)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment