Skip to content

Instantly share code, notes, and snippets.

@mandli
Created September 17, 2025 14:24
Show Gist options
  • Save mandli/cde3eb21d932688717ce0c9f20915b98 to your computer and use it in GitHub Desktop.
Save mandli/cde3eb21d932688717ce0c9f20915b98 to your computer and use it in GitHub Desktop.
Plot the number of grids and cells per refinement level over time in an [AMRClaw](https://github.com/clawpack/amrclaw) or [GeoClaw](https://github.com/clawpack/geoclaw) run.
#!/usr/bin/env python
from pathlib import Path
import sys
import os
import numpy as np
# Set to True to apply the larger line/marker/font sizes and 300-dpi
# output configured below.
article = False

# Plot customization
import matplotlib

if article:
    # Markers and line widths (markersize was previously set twice, to 6 and
    # then 8; only the effective value, 8, is kept)
    matplotlib.rcParams['lines.linewidth'] = 2.0
    matplotlib.rcParams['lines.markersize'] = 8

    # Font sizes
    matplotlib.rcParams['font.size'] = 16
    matplotlib.rcParams['axes.labelsize'] = 16
    matplotlib.rcParams['legend.fontsize'] = 12
    matplotlib.rcParams['xtick.labelsize'] = 16
    matplotlib.rcParams['ytick.labelsize'] = 16

    # DPI of output images
    matplotlib.rcParams['savefig.dpi'] = 300
import matplotlib.pyplot as plt
# days2seconds = lambda days: days * 60.0**2 * 24.0
# seconds2days = lambda seconds: seconds / (60.0**2 * 24.0)
def set_day_ticks(new_ticks=(0.0, 1.0, 2.0, 3.0, 4.0, 5.0)):
    """Place x-axis ticks of the current axes at ``new_ticks``.

    Each tick is labelled with ``str(tick)``, so the default yields the
    labels "0.0", "1.0", ..., "5.0".

    :param new_ticks: Iterable of tick positions. The default is an immutable
        tuple so it cannot be accidentally mutated and shared across calls.
    """
    # Materialize once so generators/iterators work for both uses below.
    ticks = list(new_ticks)
    plt.xticks(ticks, [str(tick) for tick in ticks])
def set_cell_ticks(scale: float = 1e6):
    """Relabel the current y-axis ticks in units of ``scale`` (default: millions).

    Tick positions are kept where matplotlib placed them; only the labels are
    replaced, e.g. with the default scale a tick at 2.5e6 is labelled "2.5".

    :param scale: Divisor applied to each tick location to form its label.
    """
    locs, _ = plt.yticks()
    # Format explicitly rather than handing raw numpy floats to matplotlib.
    labels = [f"{loc / scale:g}" for loc in locs]
    plt.yticks(locs, labels)
def get_grid_statistics(output_path: Path, landfall: float = 0.0,
                        num_levels: int = 7):
    """Count grids and cells per AMR level for every output frame.

    Scans ``output_path`` for ``fort.t*`` files (processed in sorted name
    order). From each one it reads the frame time (first line) and the total
    number of grids (third line). It then walks the matching ``fort.q*``
    file and tallies every ``grid_number`` header by its level, using the
    level, ``mx`` and ``my`` values on the three lines that follow.

    :param output_path: Directory containing the ``fort.t*``/``fort.q*``
        files. A ``str`` path is also accepted.
    :param landfall: Reference time subtracted from each frame time before
        dividing by 86400 (i.e. it is treated as being in seconds).
    :param num_levels: Number of AMR levels to allocate columns for.
    :returns: ``(time, num_grids, num_cells)``. ``time`` has shape
        ``(num_frames,)`` and is in days relative to ``landfall``.
        ``num_grids`` and ``num_cells`` are integer arrays of shape
        ``(num_frames, num_levels)``, where column ``k`` holds level ``k + 1``.
    :raises ValueError: If a grid's level lies outside ``1..num_levels``, or
        the grid counts in a ``fort.t``/``fort.q`` pair disagree.
    """
    # TODO: Detect number of levels
    seconds_per_day = 60.0**2 * 24.0
    output_path = Path(output_path)
    t_paths = sorted(output_path.glob("fort.t*[0-9]"))
    num_frames = len(t_paths)

    time = np.empty(num_frames, dtype=float)
    num_grids = np.zeros((num_frames, num_levels), dtype=int)
    num_cells = np.zeros((num_frames, num_levels), dtype=int)

    for n, t_path in enumerate(t_paths):
        # Header lines used: 1 = time, 2 = skipped, 3 = number of grids.
        with t_path.open('r') as t_file:
            time[n] = (float(t_file.readline().split()[0]) - landfall) / seconds_per_day
            t_file.readline()
            t_file_num_grids = int(t_file.readline().split()[0])

        frame = t_path.name[len("fort.t"):]
        q_path = output_path / f"fort.q{frame}"
        with q_path.open('r') as q_file:
            for line in q_file:
                if "grid_number" not in line:
                    continue
                # The three lines after a grid_number header hold level, mx, my.
                level = int(q_file.readline().split()[0])
                if not 1 <= level <= num_levels:
                    # Without this check level 0 would silently land in the
                    # last column (index -1) and larger levels raise IndexError.
                    raise ValueError(
                        f"{q_path.name} contains a grid on level {level}, "
                        f"outside the expected range 1..{num_levels}; "
                        f"increase num_levels."
                    )
                mx = int(q_file.readline().split()[0])
                my = int(q_file.readline().split()[0])
                num_grids[n, level - 1] += 1
                num_cells[n, level - 1] += mx * my

        # File checking: both files must agree on the total number of grids.
        q_file_num_grids = int(num_grids[n].sum())
        if q_file_num_grids != t_file_num_grids:
            raise ValueError(
                "Number of grids in fort.t* file and fort.q* file do not match "
                f"({t_path.name}: {t_file_num_grids}, "
                f"{q_path.name}: {q_file_num_grids})."
            )

    return time, num_grids, num_cells
def plot_grid_statistics(time, num_grids, num_cells, num_levels,
                         xlabel='Days from 2025-12-25 00:00 UTC'):
    """Plot cascading time histories per level.

    Creates two new figures. Each is a stacked-area plot on a logarithmic
    y-axis with one grey band per level: the first shows the number of grids
    and the second the number of cells, both versus ``time``.

    :param time: 1-D sequence of frame times, plotted on the x-axis.
    :param num_grids: Array of shape ``(len(time), num_levels)``.
    :param num_cells: Array of shape ``(len(time), num_levels)``.
    :param num_levels: Number of levels (columns) being plotted.
    :param xlabel: Label for the time axis.
    :returns: ``[grids_figure, cells_figure]``.
    """
    base_greys = [247, 217, 189, 150, 115, 82, 37]
    if num_levels <= len(base_greys):
        greys = base_greys[:num_levels]
    else:
        # More levels than the hand-picked palette: interpolate the same
        # light-to-dark ramp instead of indexing past the end of the list.
        greys = np.linspace(base_greys[0], base_greys[-1], num_levels)
    colors = [(grey / 256.0, grey / 256.0, grey / 256.0) for grey in greys]

    # Rectangles stand in for the stackplot bands in the legend.
    proxy_artists = [
        plt.Rectangle((0, 0), 1, 1, fc=color, label=f"Level {level + 1}")
        for level, color in enumerate(colors)
    ]

    figs = [
        _plot_level_stack(time, num_grids, colors, proxy_artists, xlabel,
                          'Number of Grids', "Number of Grids per Level in Time"),
        _plot_level_stack(time, num_cells, colors, proxy_artists, xlabel,
                          'Number of Cells', "Number of Cells per Level in Time"),
    ]
    return figs


def _plot_level_stack(time, counts, colors, proxy_artists, xlabel, ylabel, title):
    """Draw one per-level stacked time history in a new figure and return it."""
    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.set_yscale('log')
    ax.stackplot(time, counts.transpose(), colors=colors)
    ax.set_xlabel(xlabel)
    ax.set_xlim([0.0, 5.0])
    # set_day_ticks acts on the current axes, which is `ax` at this point.
    set_day_ticks()
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend(handles=proxy_artists, loc=2)
    return fig
if __name__ == "__main__":
    # Usage: <script> [output_path [num_levels]]
    num_levels = 7
    if len(sys.argv) > 1:
        output_path = Path(sys.argv[1])
    else:
        # Only fall back to DATA_PATH when no path was given, so the script
        # no longer requires DATA_PATH to be set just to start.
        data_path = os.environ.get('DATA_PATH')
        if data_path is None:
            sys.exit(f"usage: {Path(sys.argv[0]).name} [output_path [num_levels]] "
                     "(or set DATA_PATH to use the default output directory)")
        base_path = Path(data_path) / "surge" / "ETC_storms"
        output_path = base_path / "WS12_SL00_L7_output"
    if len(sys.argv) > 2:
        num_levels = int(sys.argv[2])

    landfall = 0.0
    time, num_grids, num_cells = get_grid_statistics(output_path,
                                                     landfall=landfall,
                                                     num_levels=num_levels)
    figs = plot_grid_statistics(time, num_grids, num_cells, num_levels)
    # Uncomment to save the figures instead of (or as well as) showing them.
    # figs[0].savefig("num_grids.png")
    # figs[1].savefig("num_cells.png")
    plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment