Skip to content

Instantly share code, notes, and snippets.

@caryan
Created April 29, 2025 17:27
Show Gist options
  • Save caryan/6b941df93b0598e32eef153bd997e8dd to your computer and use it in GitHub Desktop.
Save caryan/6b941df93b0598e32eef153bd997e8dd to your computer and use it in GitHub Desktop.
Stim + decoding batch size optimisation
import time

import numpy as np
import pandas as pd
import plotly.colors as pc
import plotly.graph_objects as go
import sinter
import stim
from sinter._decoding._decoding_pymatching import PyMatchingDecoder
from sinter._decoding._stim_then_decode_sampler import StimThenDecodeSampler
from tqdm.auto import tqdm
# Benchmark workload: one generated memory experiment per (code family,
# distance) pair, all sharing a single physical error rate and using
# 3 * distance rounds.
error_rate = 1e-3
tasks = []
for family in ("surface_code:rotated_memory_x", "repetition_code:memory"):
    # The text before the first underscore serves as a short label
    # ("surface" / "repetition") in the task metadata.
    family_label = family.split("_")[0]
    for dist in (5, 15):
        num_rounds = 3 * dist
        memory_circuit = stim.Circuit.generated(
            family,
            distance=dist,
            rounds=num_rounds,
            after_clifford_depolarization=error_rate,
            before_measure_flip_probability=error_rate / 2,
        )
        new_task = sinter.Task(
            circuit=memory_circuit,
            decoder="pymatching",
            json_metadata={
                "code_type": family_label,
                "d": dist,
                "p": error_rate,
                "rounds": num_rounds,
            },
        )
        # Attach the detector error model derived from the task's own circuit,
        # requesting decomposed errors.
        new_task.detector_error_model = new_task.circuit.detector_error_model(
            decompose_errors=True
        )
        tasks.append(new_task)
# Throughput sweep: for every task, draw `max_shots` shots in batches of each
# power-of-two size, recording the accumulated stats and the elapsed time.
# `results` and `total_times` are parallel lists filled task-major (every batch
# size for the first task, then every batch size for the second, and so on).
results = []
total_times = []
max_shots = 1 << 20
# Integer powers of two 1, 2, 4, ... up to the largest one that is <= max_shots.
# Using bit_length() keeps the batch sizes integral and guarantees that every
# batch size fits into max_shots at least once, even if max_shots is later
# changed to a value that is not a power of two.
shots_per_sample_sweep = 2 ** np.arange(max_shots.bit_length())
for task in tqdm(tasks, desc="Tasks"):
    sampler = StimThenDecodeSampler(
        decoder=PyMatchingDecoder(),
        count_observable_error_combos=False,
        count_detection_events=False,
        tmp_dir=None,
    ).compiled_sampler_for_task(task)
    for shots_per_sample in tqdm(
        shots_per_sample_sweep, desc="Shots per sample", leave=False
    ):
        # Convert the NumPy scalar to a plain int once, outside the hot loop.
        batch_size = int(shots_per_sample)
        num_batches = max_shots // batch_size
        stats = sinter.AnonTaskStats()
        # perf_counter() is monotonic and high-resolution, so the measured
        # interval cannot be distorted by system clock adjustments.
        tic = time.perf_counter()
        for _ in range(num_batches):
            stats += sampler.sample(batch_size)
        total_times.append(time.perf_counter() - tic)
        results.append(stats)
# Flatten the benchmark output into one record per (task, batch size) pair.
# `results` and `total_times` are parallel lists, and a flat position in them
# decomposes into (task index, sweep index) via divmod by the sweep length.
sweep_len = len(shots_per_sample_sweep)
data = []
for position, (stats, wall_time) in enumerate(zip(results, total_times)):
    task_pos, sweep_pos = divmod(position, sweep_len)
    meta = tasks[task_pos].json_metadata
    data.append(
        {
            # Task identification, copied from the task's metadata.
            "code_type": meta["code_type"],
            "distance": meta["d"],
            "rounds": meta["rounds"],
            "error_rate": meta["p"],
            # Batch size used for this measurement.
            "shots_per_sample": shots_per_sample_sweep[sweep_pos],
            # Rate based on the seconds reported in the accumulated stats.
            "total_shots": stats.shots,
            "sampling_time": stats.seconds,
            "shots_per_second_sampling": stats.shots / stats.seconds,
            # Rate based on the elapsed time recorded in `total_times`.
            "total_time": wall_time,
            "shots_per_second": stats.shots / wall_time,
        }
    )
# One row per measurement; column order follows the dict key order above.
df = pd.DataFrame(data)
# Plot throughput against batch size on log-log axes. Each (code type,
# distance) pair gets one colour and two traces: a solid line for the
# `shots_per_second_sampling` column and a dashed line for the
# `shots_per_second` column.
fig = go.Figure()
palette = pc.qualitative.D3
# (DataFrame column, legend suffix, line dash style) for the traces of a pair.
trace_specs = (
    ("shots_per_second_sampling", "sampling", "solid"),
    ("shots_per_second", "total", "dash"),
)
color_idx = 0
for code_type in df["code_type"].unique():
    code_rows = df[df["code_type"] == code_type]
    for distance in code_rows["distance"].unique():
        subset = code_rows[code_rows["distance"] == distance]
        # Wrap around the palette so that having more pairs than colours
        # reuses colours instead of raising IndexError.
        color = palette[color_idx % len(palette)]
        label = f"{code_type} d={distance}"
        for column, suffix, dash in trace_specs:
            fig.add_trace(
                go.Scatter(
                    x=subset["shots_per_sample"],
                    y=subset[column],
                    name=f"{label} ({suffix})",
                    mode="lines+markers",
                    # Both traces of a pair share a legend group.
                    legendgroup=label,
                    line=dict(color=color, dash=dash),
                    marker=dict(color=color),
                )
            )
        # Advance to the next colour for the next code+distance pair.
        color_idx += 1
fig.update_layout(
    title="StimThenDecodeSampler with PyMatching Throughput (3d rounds)",
    xaxis_title="Batch Size (shots per sample)",
    yaxis_title="Shots per Second",
    xaxis_type="log",
    yaxis_type="log",
    legend_title="Code Patch",
    width=800,
    height=600,
)
fig.update_xaxes(showgrid=True, gridwidth=1, gridcolor="LightGray")
fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor="LightGray")
fig.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment