Skip to content

Instantly share code, notes, and snippets.

@tommylees112
Last active August 25, 2021 11:18
Show Gist options
  • Save tommylees112/ae1a1d26992ab74053b4e07b7f49a338 to your computer and use it in GitHub Desktop.
Save tommylees112/ae1a1d26992ab74053b4e07b7f49a338 to your computer and use it in GitHub Desktop.
Plot the initial results for predictions of an LSTM
from pathlib import Path
from typing import Dict, Optional

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import xarray as xr
def scatter_plot(
    obs: np.ndarray,
    sim: np.ndarray,
    ax=None,
    scatter_kwargs: Optional[Dict] = None,
):
    """Scatter observed vs. simulated values with a dashed 1:1 reference line.

    Parameters
    ----------
    obs, sim : np.ndarray
        Observed and simulated values, paired element-wise. NaNs are ignored
        when computing the shared axis limits.
    ax : matplotlib Axes, optional
        Axes to draw on. If None, a new 6x6 figure is created.
    scatter_kwargs : dict, optional
        Keyword arguments forwarded to ``ax.scatter``. Defaults to
        ``{"marker": "x", "color": "C0", "alpha": 0.3}``. A supplied dict
        replaces these defaults entirely (it is not merged with them).

    Returns
    -------
    The axes that were drawn on.
    """
    if scatter_kwargs is None:
        # Build a fresh dict per call instead of sharing one mutable default object.
        scatter_kwargs = {"marker": "x", "color": "C0", "alpha": 0.3}
    if ax is None:
        _, ax = plt.subplots(figsize=(6, 6))

    # Shared limits for both axes so the 1:1 line spans both series.
    # fmin/fmax ignore a NaN operand, so one all-NaN series does not poison
    # (or order-dependently drop) the other's range.
    lo = np.fmin(np.nanmin(obs), np.nanmin(sim))
    hi = np.fmax(np.nanmax(obs), np.nanmax(sim))

    ax.scatter(obs, sim, **scatter_kwargs)
    # With no finite data the limits would be NaN, which set_xlim rejects;
    # in that case skip the reference line and leave autoscaling alone.
    if np.isfinite(lo) and np.isfinite(hi):
        lim = (lo, hi)
        ax.plot(lim, lim, ls="--", color="k")
        ax.set_xlim(lim)
        ax.set_ylim(lim)
    ax.set_xlabel("obs")
    ax.set_ylabel("sim")
    return ax
# Load the evaluation outputs.
# NOTE(review): placeholder path — point this at the directory holding
# errors.nc and preds.nc before running.
data_dir = Path("/path/to/data/")
errors = xr.open_dataset(data_dir / "errors.nc")
preds = xr.open_dataset(data_dir / "preds.nc")

# Histogram of one error metric, shown over the window [0, 1].
metric = "KGE"
f, ax = plt.subplots(figsize=(12, 4))
# Build the bins over the displayed window. With an integer ``bins`` matplotlib
# spreads the bins over the full data range, so outliers outside [0, 1] would
# squeeze the visible part into one or two bars. Values outside [0, 1] (and
# NaNs) are not counted, so the density is normalised over the shown range only.
ax.hist(errors[metric], alpha=0.6, bins=np.linspace(0, 1, 11), density=True)
ax.set_xlim(0, 1)
ax.set_xlabel(metric)
sns.despine()
# Pool every (obs, sim) pair across all stations and dates into one scatter.
x = preds["stage_value_obs"].values.flatten()
y = preds["stage_value_sim"].values.flatten()
scatter_plot(obs=x, sim=y)

# scatter_plot was called without an ``ax``, so it drew on a newly created
# figure; fetch those axes to add the reference lines.
ax = plt.gca()
# Horizontal reference lines at simulated values 60 and 95.
# NOTE(review): the meaning of these two levels is not stated here — confirm.
for level in (60, 95):
    ax.axhline(level, ls="-", alpha=0.6, color="k")
# Time series of observed (dashed black) vs. simulated (solid) stage for a
# random sample of stations.
n_stations = 5
station_ids = preds.station_id.values
# Sample without replacement so no station is plotted twice, and cap the
# sample size so this still works when fewer stations are available.
sids = np.random.choice(
    station_ids, size=min(n_stations, len(station_ids)), replace=False
)
for sid in sids:
    f, ax = plt.subplots(figsize=(12, 4))
    d = preds.sel(station_id=sid)
    ax.plot(d["date"], d["stage_value_obs"], ls="--", color="k", alpha=0.6, label="obs")
    ax.plot(d["date"], d["stage_value_sim"], color="C0", alpha=0.6, label="sim")
    # Elements of ``sids`` come from ``.values`` (plain NumPy scalars, never
    # DataArrays), so str() of the id is enough for the title.
    ax.set_title(str(sid))
    ax.legend()
    sns.despine()
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment