Created
April 10, 2019 16:17
-
-
Save davipatti/3cdfaf6094996efa7bbc6ec23986b5dc to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3
"""2019-04-10 David Pattinson

Plot the parameter samples from a MrBayes run.

- Run this script in a directory that contains *.p files, or that contains
  subdirectories (however deep) that contain *.p files.
- All .p files in the same directory are plotted together, one subplot per
  parameter. E.g. if HA.run1.p and HA.run2.p are in the same directory, their
  traces are drawn on top of each other, each with a dashed rolling mean.
- The figure is saved as mcmc-samples.pdf and mcmc-samples.png in the same
  directory that the *.p files are located in.
"""
from collections import defaultdict
import math
import os

import matplotlib.pyplot as plt
import pandas as pd

# matplotlib 3.6 renamed the bundled seaborn styles to "seaborn-v0_8-*" and
# later releases dropped the old names, so use whichever one is available.
# If neither exists, matplotlib's default style is kept.
for _style in ("seaborn-v0_8-whitegrid", "seaborn-whitegrid"):
    if _style in plt.style.available:
        plt.style.use(_style)
        break
def read_pfile(path):
    """Load one MrBayes .p (parameter sample) file as a DataFrame.

    The file is read as tab-separated text. Its first line is skipped
    (presumably MrBayes' "[ID: ...]" line -- confirm against your files).
    The following line provides the column names, and the first column is
    used as the row index.

    Args:
        path (str): Path to the .p file.

    Returns:
        pandas.DataFrame: One row per sample, indexed by the first column.
    """
    frame = pd.read_csv(path, index_col=0, skiprows=1, sep="\t")
    return frame
def _run_label(path):
    """Return the legend label for a .p file path.

    The path is normalised, which drops the leading "./" that os.walk
    produces, and an exact trailing ".p" extension is removed. Unlike
    str.rstrip/str.lstrip, this removes a whole prefix/suffix rather than a
    set of characters, so "./top.p" becomes "top" rather than "to".
    """
    label = os.path.normpath(path)
    if label.endswith(".p"):
        label = label[:-len(".p")]
    return label


def summarise_samples(pfiles, ncols=4, window=1000):
    """Draw trace plots for every parameter found in a set of .p files.

    One subplot is drawn per parameter column, in a grid with ``ncols``
    columns. In each subplot, every file is drawn as a semi-transparent
    trace plus a dashed rolling mean in the matching colour. Only the first
    subplot keeps a legend. The figure is created with plt.subplots, so it
    is the current pyplot figure afterwards (plt.savefig will save it).

    Args:
        pfiles (iterable): Contains paths to .p files to plot.
        ncols (int): Number of subplot columns. Defaults to 4.
        window (int): Number of samples in the rolling-mean window.
            Defaults to 1000.

    Returns:
        matplotlib.figure.Figure: The figure that was drawn.

    Raises:
        ValueError: If ``pfiles`` is empty.
    """
    # Sorted so that runs get the same colours on every invocation.
    runs = {_run_label(p): read_pfile(p) for p in sorted(pfiles)}
    if not runs:
        raise ValueError("summarise_samples() needs at least one .p file")
    labels = list(runs)

    # Columns form a (run, parameter) MultiIndex. Rows are the outer join of
    # every run's index, so runs of different lengths still line up.
    combined = pd.concat(runs, axis=1).sort_index()
    # Parameter names in first-seen order, without duplicates.
    parameters = list(dict.fromkeys(combined.columns.get_level_values(1)))

    nrows = max(1, math.ceil(len(parameters) / ncols))
    fig, _ = plt.subplots(
        nrows=nrows, ncols=ncols, figsize=(5 * ncols, nrows * 3), sharex=True)

    for i, (parameter, ax) in enumerate(zip(parameters, fig.axes)):
        # One column per run. A run lacking this parameter is all-NaN, so
        # colours stay consistent across subplots.
        df = combined.xs(parameter, axis=1, level=1).reindex(columns=labels)
        df.plot(alpha=0.5, ax=ax)
        ax.set_prop_cycle(None)  # Reset colours so each mean matches its trace.
        rolling_mean = df.rolling(window).mean()
        rolling_mean.columns = [
            "{} rolling {}".format(label, window)
            for label in rolling_mean.columns]
        rolling_mean.plot(ax=ax, ls="--", lw=2)
        if i > 0:
            # Every subplot shows the same runs, so one legend is enough.
            ax.legend().remove()
        ax.set_title(parameter)
    return fig
if __name__ == "__main__":
    # Find every *.p file below the current directory, grouped by the
    # directory containing it. Each such directory gets its own figure.
    pfiles_by_dir = defaultdict(list)
    for dirpath, _subdirs, filenames in os.walk("."):
        matches = [os.path.join(dirpath, name)
                   for name in filenames if name.endswith(".p")]
        if matches:
            pfiles_by_dir[dirpath].extend(matches)

    for directory, dir_pfiles in pfiles_by_dir.items():
        summarise_samples(dir_pfiles)
        # Both saves use the current pyplot figure, i.e. the one that
        # summarise_samples just created.
        plt.savefig("{}/mcmc-samples.pdf".format(directory), bbox_inches="tight")
        plt.savefig("{}/mcmc-samples.png".format(directory), bbox_inches="tight")
        plt.close()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment