Skip to content

Instantly share code, notes, and snippets.

@a4lg
Created April 11, 2025 07:27
Show Gist options
  • Save a4lg/e6d5049118958afad8f646944f3e5676 to your computer and use it in GitHub Desktop.
Save a4lg/e6d5049118958afad8f646944f3e5676 to your computer and use it in GitHub Desktop.
Prover: ensure that repeating the loops over certain groups a specific number of times is sufficient to prove convergence
#! /bin/env python3
# SPDX-License-Identifier: MIT
# SPDX-FileCopyrightText: Copyright (C) 2025 Tsukasa OI <[email protected]>.
import sys
from typing import Any

import z3
# Per-extension z3 variables: VARS[ext][i] is the chain of boolean versions of
# extension `ext` within pass `i`. Element 0 is the value entering the pass;
# every rule that may set `ext` appends a fresh version, so element [-1] is
# the value after the pass.
VARS: dict[str, list[list[Any]]] = {}
# Passes are indexed 0..N_ITERS (i.e. N_ITERS + 1 passes are modelled).
N_ITERS = 1
def extensions(exts: list[str] | str) -> list[str]:
return exts if isinstance(exts, list) else [exts]
def prepare_extension(name: str):
    """Lazily allocate the per-pass variable chains for extension *name*.

    On first use, creates N_ITERS + 1 chains (one per pass), each seeded with
    the input variable ``ext_<name>_<pass>_0``. Subsequent calls do nothing.
    """
    if name in VARS:
        return
    chains = []
    for pass_idx in range(N_ITERS + 1):
        chains.append([z3.Bool(f'ext_{name}_{pass_idx}_0')])
    VARS[name] = chains
def get_extension_src(name: str, idx: int) -> Any:
    """Return the latest z3 variable of extension *name* within pass *idx*.

    The extension's variable chains are allocated on first use.
    """
    prepare_extension(name)
    return VARS[name][idx][-1]
def get_extension_dst(name: str, idx: int) -> tuple[Any, Any]:
    """Create a new version of extension *name* within pass *idx*.

    Returns:
        ``(src, dst)``: the previously latest variable and the freshly created
        one, named ``ext_<name>_<idx>_<n>`` where n is its position in the
        chain. ``dst`` becomes the latest version, so later rules read it.
        No constraint between ``src`` and ``dst`` is created here.
    """
    src = get_extension_src(name, idx)
    dst = z3.Bool(f'ext_{name}_{idx}_{len(VARS[name][idx])}')
    VARS[name][idx].append(dst)
    return (src, dst)
def imply(ext_from: list[str] | str, ext_to: list[str] | str) -> list[Any]:
    """Encode "any of *ext_from* implies every extension in *ext_to*".

    For every pass i in 0..N_ITERS and every target extension, a new version
    of the target is created and constrained to equal the OR of the target's
    previous version and the current versions of all source extensions.

    Args:
        ext_from: Source extension name(s).
        ext_to: Target extension name(s).

    Returns:
        The z3 equality constraints, ordered target-major, then by pass.
    """
    ext_from = extensions(ext_from)
    ext_to = extensions(ext_to)
    # Snapshot the source versions before creating any target version, so a
    # name listed on both sides refers to its value before this rule.
    srcs = [[] for _ in range(N_ITERS + 1)]
    syns = []
    for name in ext_from:
        for i in range(N_ITERS + 1):
            src = get_extension_src(name, i)
            srcs[i].append(src)
    for name in ext_to:
        for i in range(N_ITERS + 1):
            (src, dst) = get_extension_dst(name, i)
            syns.append(dst == z3.Or(src, *srcs[i]))
    return syns
def imply_and(ext_from: list[str] | str, ext_to: list[str] | str) -> list[Any]:
    """Encode "all of *ext_from* together imply every extension in *ext_to*".

    For every pass i in 0..N_ITERS and every target extension, a new version
    of the target is created and constrained to equal the target's previous
    version OR the AND of the current versions of all source extensions.

    Args:
        ext_from: Source extension name(s); all must hold to trigger the rule.
        ext_to: Target extension name(s).

    Returns:
        The z3 equality constraints, ordered target-major, then by pass.
    """
    ext_from = extensions(ext_from)
    ext_to = extensions(ext_to)
    # Snapshot the source versions before creating any target version, so a
    # name listed on both sides refers to its value before this rule.
    srcs = [[] for _ in range(N_ITERS + 1)]
    syns = []
    for name in ext_from:
        for i in range(N_ITERS + 1):
            src = get_extension_src(name, i)
            srcs[i].append(src)
    for name in ext_to:
        for i in range(N_ITERS + 1):
            (src, dst) = get_extension_dst(name, i)
            syns.append(dst == z3.Or(src, z3.And(*srcs[i])))
    return syns
def group(name: str, exts: list[str]) -> list[Any]:
    """Encode a group extension *name* made up of the members *exts*.

    Emits "name implies every member" followed by "all members together
    imply name".

    Returns:
        The constraints of imply(name, exts) followed by those of
        imply_and(exts, name).
    """
    syns = []
    syns.extend(imply(name, exts))
    syns.extend(imply_and(exts, name))
    return syns
def rel_passes() -> list[Any]:
    """Chain consecutive passes for every extension.

    For each extension and each i in 0..N_ITERS-1, the value after pass i
    (last element of its chain) is constrained to equal the input of pass
    i + 1 (first element). Only extensions already present in VARS, with
    their chains as they are at call time, are covered — so call this after
    all rules have been emitted.
    """
    syns = []
    for v in VARS.values():
        for i in range(N_ITERS):
            syns.append(v[i][-1] == v[i + 1][0])
    return syns
def consistency_passes() -> list[Any]:
    """Return, per extension, "value after pass N_ITERS-1 == value after pass N_ITERS".

    The conjunction of these constraints states that one more pass changes
    nothing, i.e. the rules have converged.

    Raises:
        ValueError: if N_ITERS < 1, because there are then no two passes to
            compare.
    """
    if N_ITERS < 1:
        # With N_ITERS == 0, index N_ITERS - 1 == -1 wraps around to pass 0
        # itself, so every value would be compared with itself and the
        # convergence check would pass vacuously.
        raise ValueError('N_ITERS must be at least 1 to compare two passes')
    syns = []
    for v in VARS.values():
        syns.append(v[N_ITERS - 1][-1] == v[N_ITERS][-1])
    return syns
# Build the constraint list. Each call appends the constraints of one rule;
# within every pass the rules take effect in the order they are listed here,
# because each rule reads the latest variable versions created so far.
syns = []
syns.extend(imply('zvbb', 'zvkb'))
# Vector cryptography groups, repeated within a single pass. Presumably this
# lets implications through nested groups (e.g. zvknc -> zvkn -> zvkned), which
# are listed after the groups they feed, settle in one pass; the prover checks
# whether this repetition count suffices.
# NOTE(review): loop-body extent was reconstructed from flattened source; the
# body is assumed to be exactly the six group() calls — confirm upstream.
for _ in range(3):
    syns.extend(group('zvkn', ['zvkned', 'zvknhb', 'zvkb', 'zvkt']))
    syns.extend(group('zvknc', ['zvkn', 'zvbc']))
    syns.extend(group('zvkng', ['zvkn', 'zvkg']))
    syns.extend(group('zvks', ['zvksed', 'zvksh', 'zvkb', 'zvkt']))
    syns.extend(group('zvksc', ['zvks', 'zvbc']))
    syns.extend(group('zvksg', ['zvks', 'zvkg']))
syns.extend(imply(['zvknhb', 'zvbc'], 'zve64x'))
syns.extend(imply(['zvbb', 'zvkb', 'zvkg', 'zvkned', 'zvknha', 'zvksed', 'zvksh'], 'zve32x'))
# Scalar cryptography groups, repeated for the same reason (zk contains zkn,
# which is listed before it).
# NOTE(review): loop-body extent reconstructed as the three group() calls.
for _ in range(2):
    syns.extend(group('zkn', ['zbkb', 'zbkc', 'zbkx', 'zkne', 'zknd', 'zknh']))
    syns.extend(group('zks', ['zbkb', 'zbkc', 'zbkx', 'zksed', 'zksh']))
    syns.extend(group('zk', ['zkn', 'zkr', 'zkt']))
syns.extend(imply('zacas', 'zaamo'))
syns.extend(group('a', ['zalrsc', 'zaamo']))
syns.extend(group('b', ['zba', 'zbb', 'zbs']))
syns.extend(imply('zcf', ['zca', 'f']))
syns.extend(imply('zcd', ['zca', 'd']))
syns.extend(imply(['zcmop', 'zcb'], 'zca'))
syns.extend(imply('zhinx', 'zhinxmin'))
syns.extend(imply(['zdinx', 'zhinxmin'], 'zfinx'))
syns.extend(imply('zvfh', 'zvfhmin'))
syns.extend(imply('zvfhmin', 'zve32f'))
syns.extend(imply('v', 'zve64d'))
syns.extend(imply('zve64d', ['zve64f', 'd']))
syns.extend(imply('zve64f', ['zve64x', 'zve32f']))
syns.extend(imply('zve64x', 'zve32x'))
syns.extend(imply('zve32f', ['zve32x', 'f']))
syns.extend(imply('zfh', 'zfhmin'))
syns.extend(imply('q', 'd'))
syns.extend(imply(['d', 'zfhmin', 'zfa'], 'f'))
# "C" extension rules: C alone implies Zca; C together with D (or F) implies
# Zcd (or Zcf).
syns.extend(imply('c', 'zca'))
syns.extend(imply_and(['c', 'd'], 'zcd'))
syns.extend(imply_and(['c', 'f'], 'zcf')) # comment out this line to test RV64
syns.extend(imply(['zicntr', 'zihpm', 'zkr', 'f', 'zfinx', 'zve32x'], 'zicsr'))
# Must come last: rel_passes() only links extensions (and chain versions)
# that the rules above have already registered in VARS.
syns.extend(rel_passes())
def DeMorganNot(and_clauses):
    """Return the negation of the conjunction of *and_clauses*.

    Expressed as the disjunction of the negated clauses, i.e. !(X && Y) is
    built as (!X || !Y).
    """
    negated = list(map(z3.Not, and_clauses))
    return z3.Or(*negated)
def _print_counterexample_summary(model):
    """Print the counterexample's inputs and the extensions that did not converge.

    Inputs are the pass-0 entry variables (``VARS[ext][0][0]``); an input the
    model leaves unassigned is shown as ``Any``. An extension is listed as
    different when its value after pass N_ITERS - 1 differs from its value
    after pass N_ITERS. Model completion is used for these outputs so that a
    variable missing from the (partial) model cannot raise KeyError.
    """
    print('Inputs:')
    for ext in sorted(VARS):
        value = model[VARS[ext][0][0]]
        # Quotes differ from the enclosing f-string so this parses before 3.12.
        print(f'\t{ext:9} = {"Any" if value is None else str(value)}')
    print('Different extensions:')
    for ext in sorted(VARS):
        chains = VARS[ext]
        before = model.eval(chains[N_ITERS - 1][-1], model_completion=True)
        after = model.eval(chains[N_ITERS][-1], model_completion=True)
        if str(before) != str(after):
            print('\t' + ext)


def FindCounterexamples(name, constraints, *, dump_all_decls=False):
    """Check whether *constraints* are satisfiable and report a counterexample.

    The caller encodes "all rules hold AND the property fails" in
    *constraints*, so a satisfying model is a counterexample to property
    *name*.

    Args:
        name: Property name used in the progress message.
        constraints: Iterable of z3 boolean expressions to assert.
        dump_all_decls: If True, print every declaration of the model instead
            of the per-extension summary (the summary is the default).

    Exits the process with status 1 when a counterexample is found and with
    status 2 when z3 answers ``unknown``; returns normally when the
    constraints are unsatisfiable (no counterexample exists).
    """
    print(f'Whether {name} has a counterexample... ', file=sys.stderr, end='')
    sys.stderr.flush()
    solver = z3.Solver()
    for constraint in constraints:
        solver.add(constraint)
    result = solver.check()
    if result == z3.unsat:
        print('not found.', file=sys.stderr)
        return
    if result != z3.sat:
        # z3 gave up (e.g. a resource limit). Reporting "not found." here
        # would claim a proof that was never obtained.
        print(f'unknown ({solver.reason_unknown()}).', file=sys.stderr)
        sys.exit(2)
    print('found!\n\nCounterexample:', file=sys.stderr)
    model = solver.model()
    if dump_all_decls:
        for decl in sorted(model.decls(), key=str):
            decl_name = str(decl)
            print(f'\t{decl_name} = {model[decl]}')
    else:
        _print_counterexample_summary(model)
    sys.exit(1)
# Satisfiable only if every rule holds while at least one extension still
# changes between the last two passes, i.e. a counterexample to convergence.
FindCounterexamples('Convergence', [*syns, DeMorganNot(consistency_passes())])
raise SystemExit(0)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment