Created
April 29, 2025 17:27
-
-
Save caryan/6b941df93b0598e32eef153bd997e8dd to your computer and use it in GitHub Desktop.
Stim + decoding batch size optimisation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import time

import numpy as np
import pandas as pd
import plotly.colors as pc
import plotly.graph_objects as go
import sinter
import stim
from sinter._decoding._decoding_pymatching import PyMatchingDecoder
from sinter._decoding._stim_then_decode_sampler import StimThenDecodeSampler
from tqdm.auto import tqdm
# Build one sinter task per (code family, distance) pair, all at the same
# physical error rate and with 3*d rounds.
error_rate = 1e-3
tasks = []
for patch_type in ("surface_code:rotated_memory_x", "repetition_code:memory"):
    # Text before the first underscore of the generator name, e.g. "surface".
    code_type = patch_type.split("_")[0]
    for d in (5, 15):
        rounds = 3 * d
        circuit = stim.Circuit.generated(
            patch_type,
            distance=d,
            rounds=rounds,
            after_clifford_depolarization=error_rate,
            before_measure_flip_probability=error_rate / 2,
        )
        metadata = {
            "code_type": code_type,
            "d": d,
            "p": error_rate,
            "rounds": rounds,
        }
        tasks.append(
            sinter.Task(
                circuit=circuit,
                decoder="pymatching",
                json_metadata=metadata,
            )
        )

# Attach a detector error model (with decomposed errors) to every task.
# NOTE(review): presumably this lets the compiled sampler reuse the model
# instead of deriving it itself -- confirm against sinter.
for t in tasks:
    t.detector_error_model = t.circuit.detector_error_model(decompose_errors=True)
# Benchmark: for every task, draw (about) max_shots shots in batches of each
# size in the sweep, recording the accumulated sampler stats and the elapsed
# time of the whole batch loop.
#
# `results` and `total_times` are flat, parallel lists ordered task-major:
# entry i belongs to task i // len(sweep) and batch size i % len(sweep).
results = []
total_times = []
max_shots = 1 << 20
# Powers of two from 1 up to max_shots inclusive. np.log2 returns a float, so
# the sweep is a float array; values are converted with int() before use.
shots_per_sample_sweep = 2 ** np.arange(np.log2(max_shots) + 1)
for task in tqdm(tasks, desc="Tasks"):
    # Compile the sampler once per task so setup cost is not part of the
    # timed region below.
    sampler = StimThenDecodeSampler(
        decoder=PyMatchingDecoder(),
        count_observable_error_combos=False,
        count_detection_events=False,
        tmp_dir=None,
    ).compiled_sampler_for_task(task)
    for shots_per_sample in tqdm(shots_per_sample_sweep, desc="Shots per sample", leave=False):
        # Loop-invariant conversions, hoisted out of the timed loop.
        batch_size = int(shots_per_sample)
        num_batches = int(max_shots // shots_per_sample)
        stats = sinter.AnonTaskStats()
        # perf_counter is monotonic and high resolution; time.time() can jump
        # when the system clock is adjusted, which would corrupt an interval.
        start = time.perf_counter()
        for _ in range(num_batches):
            stats += sampler.sample(batch_size)
        elapsed = time.perf_counter() - start
        # Append both together so the two lists can never get out of step.
        results.append(stats)
        total_times.append(elapsed)
# Flatten the benchmark output into one record per (task, batch size) pair.
data = []
for flat_idx, (stats, elapsed) in enumerate(zip(results, total_times)):
    # `results` is ordered task-major, so divmod by the sweep length recovers
    # the task index and the position within the batch-size sweep.
    task_idx, shots_idx = divmod(flat_idx, len(shots_per_sample_sweep))
    meta = tasks[task_idx].json_metadata
    data.append(
        {
            "code_type": meta["code_type"],
            "distance": meta["d"],
            "rounds": meta["rounds"],
            "error_rate": meta["p"],
            "shots_per_sample": shots_per_sample_sweep[shots_idx],
            "total_shots": stats.shots,
            # Time recorded in the accumulated stats object.
            "sampling_time": stats.seconds,
            "shots_per_second_sampling": stats.shots / stats.seconds,
            # Time measured around the whole batch loop by the benchmark.
            "total_time": elapsed,
            "shots_per_second": stats.shots / elapsed,
        }
    )

# One row per record; columns follow the key order of the dicts above.
df = pd.DataFrame(data)
# Plot throughput against batch size on log-log axes: two curves per
# (code type, distance) pair, sharing a colour and a legend group.
fig = go.Figure()
palette = pc.qualitative.D3
# (DataFrame column, legend suffix, line dash) for each curve of a pair.
series = (
    ("shots_per_second_sampling", "sampling", "solid"),
    ("shots_per_second", "total", "dash"),
)
# Index of the colour assigned to the current code+distance pair.
color_idx = 0
for code_type in df["code_type"].unique():
    for distance in df[df["code_type"] == code_type]["distance"].unique():
        # Rows belonging to this code type and distance.
        subset = df[(df["code_type"] == code_type) & (df["distance"] == distance)]
        # Wrap around the palette so more pairs than colours reuse colours
        # instead of raising IndexError.
        color = palette[color_idx % len(palette)]
        label = f"{code_type} d={distance}"
        for column, kind, dash in series:
            fig.add_trace(
                go.Scatter(
                    x=subset["shots_per_sample"],
                    y=subset[column],
                    name=f"{label} ({kind})",
                    mode="lines+markers",
                    legendgroup=label,
                    line=dict(color=color, dash=dash),
                    marker=dict(color=color),
                )
            )
        # Next code+distance pair gets the next colour.
        color_idx += 1
fig.update_layout(
    title="StimThenDecodeSampler with PyMatching Throughput (3d rounds)",
    xaxis_title="Batch Size (shots per sample)",
    yaxis_title="Shots per Second",
    xaxis_type="log",
    yaxis_type="log",
    legend_title="Code Patch",
    width=800,
    height=600,
)
fig.update_xaxes(showgrid=True, gridwidth=1, gridcolor="LightGray")
fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor="LightGray")
fig.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment