Created
April 11, 2025 07:27
-
-
Save a4lg/e6d5049118958afad8f646944f3e5676 to your computer and use it in GitHub Desktop.
Prover: verify that a specific number of loop iterations over certain extension groups is sufficient to prove convergence
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#! /bin/env python3 | |
# SPDX-License-Identifier: MIT | |
# SPDX-FileCopyrightText: Copyright (C) 2025 Tsukasa OI <[email protected]>. | |
import sys | |
import z3 | |
# Registry of z3 Boolean "versions" for every extension seen so far.
#   VARS[ext][p] is the list of successive versions of `ext` in pass p
#   (p = 0 .. N_ITERS).  VARS[ext][p][0] is the value entering pass p and
#   VARS[ext][p][-1] is the value after every rule of that pass has been
#   applied (each rule appends a new version, see get_extension_dst).
VARS: dict[str, list[list['z3.BoolRef']]] = {}
# There are N_ITERS + 1 passes in total; convergence is checked by comparing
# the final values of the last two passes (see consistency_passes).
N_ITERS = 1
def extensions(exts: list[str] | str) -> list[str]: | |
return exts if isinstance(exts, list) else [exts] | |
def prepare_extension(name: str):
    """Register extension `name` in VARS on first use.

    Creates, for every pass 0 .. N_ITERS, a one-element version list holding
    the initial Boolean ``ext_{name}_{pass}_0``.  Already-registered names are
    left untouched.
    """
    if name in VARS:
        return
    per_pass = []
    for pass_idx in range(N_ITERS + 1):
        per_pass.append([z3.Bool(f'ext_{name}_{pass_idx}_0')])
    VARS[name] = per_pass
def get_extension_src(name: str, idx: int) -> 'z3.BoolRef':
    """Return the current (latest) version of extension `name` in pass `idx`.

    The extension is registered via prepare_extension() on first use.
    """
    prepare_extension(name)
    versions = VARS[name][idx]
    return versions[-1]
def get_extension_dst(name: str, idx: int) -> tuple['z3.BoolRef', 'z3.BoolRef']:
    """Create the next version of extension `name` in pass `idx`.

    Returns ``(src, dst)``: `src` is the version that was current before this
    call, and `dst` is a fresh Boolean named ``ext_{name}_{idx}_{k}`` (k being
    the number of versions that already existed), which is appended and so
    becomes the current version.  Callers are expected to constrain `dst`.
    """
    src = get_extension_src(name, idx)
    versions = VARS[name][idx]
    dst = z3.Bool(f'ext_{name}_{idx}_{len(versions)}')
    versions.append(dst)
    return src, dst
def _imply_with(ext_from: list[str] | str, ext_to: list[str] | str,
                disjuncts) -> list['z3.BoolRef']:
    """Shared implementation of imply() and imply_and().

    For every pass i and every extension in `ext_to`, a new version `dst` is
    created and constrained as ``dst == Or(src, *disjuncts(sources_i))``,
    where `src` is that extension's previous version and `sources_i` are the
    current versions of all `ext_from` extensions in pass i.  Sources are
    captured before any new version is created.

    `disjuncts` maps the list of source versions to the extra Or operands.
    """
    ext_from = extensions(ext_from)
    ext_to = extensions(ext_to)
    # srcs[i]: current versions of every source extension in pass i.
    srcs = [[] for _ in range(N_ITERS + 1)]
    for name in ext_from:
        for i in range(N_ITERS + 1):
            srcs[i].append(get_extension_src(name, i))
    syns = []
    for name in ext_to:
        for i in range(N_ITERS + 1):
            src, dst = get_extension_dst(name, i)
            syns.append(dst == z3.Or(src, *disjuncts(srcs[i])))
    return syns


def imply(ext_from: list[str] | str, ext_to: list[str] | str) -> list['z3.BoolRef']:
    """Rule: enabling ANY extension in `ext_from` enables every `ext_to`.

    Returns the z3 constraints defining the new versions of `ext_to`.
    """
    return _imply_with(ext_from, ext_to, lambda sources: sources)


def imply_and(ext_from: list[str] | str, ext_to: list[str] | str) -> list['z3.BoolRef']:
    """Rule: enabling ALL extensions in `ext_from` enables every `ext_to`.

    Returns the z3 constraints defining the new versions of `ext_to`.
    """
    return _imply_with(ext_from, ext_to, lambda sources: [z3.And(*sources)])
def group(name: str, exts: list[str]) -> list['z3.BoolRef']:
    """Model `name` as a group extension equivalent to the set `exts`.

    Enabling the group enables every member (imply), and enabling every
    member enables the group (imply_and).  Returns the new constraints in
    that order.
    """
    return imply(name, exts) + imply_and(exts, name)
def rel_passes() -> list['z3.BoolRef']:
    """Chain consecutive passes so each one starts where the previous ended.

    For every registered extension and every i < N_ITERS, constrain the last
    version in pass i to equal the first version in pass i + 1.  Must be
    called after all rules have been added.
    """
    return [versions[i][-1] == versions[i + 1][0]
            for versions in VARS.values()
            for i in range(N_ITERS)]
def consistency_passes() -> list['z3.BoolRef']:
    """Per extension: the last two passes end with the same value.

    All returned equalities hold exactly when running pass N_ITERS after pass
    N_ITERS - 1 changes nothing, i.e. the rules reach a fixed point within
    N_ITERS passes.

    Raises:
        ValueError: if N_ITERS < 1; pass indices would then wrap around and
            compare a pass with itself, making any "proof" vacuous.
    """
    if N_ITERS < 1:
        raise ValueError(f'N_ITERS must be at least 1 (got {N_ITERS})')
    return [versions[N_ITERS - 1][-1] == versions[N_ITERS][-1]
            for versions in VARS.values()]
# Build the rule set.  Statement order is significant: each imply()/group()
# call appends a new version of its target extensions, and later rules read
# the newest version.
syns = []
syns += imply('zvbb', 'zvkb')
# Zvk* groups; repeated 3 times per pass (the count this script checks).
for _ in range(3):
    syns += group('zvkn', ['zvkned', 'zvknhb', 'zvkb', 'zvkt'])
    syns += group('zvknc', ['zvkn', 'zvbc'])
    syns += group('zvkng', ['zvkn', 'zvkg'])
    syns += group('zvks', ['zvksed', 'zvksh', 'zvkb', 'zvkt'])
    syns += group('zvksc', ['zvks', 'zvbc'])
    syns += group('zvksg', ['zvks', 'zvkg'])
syns += imply(['zvknhb', 'zvbc'], 'zve64x')
syns += imply(['zvbb', 'zvkb', 'zvkg', 'zvkned', 'zvknha', 'zvksed', 'zvksh'], 'zve32x')
# Zk* groups; repeated 2 times per pass.
for _ in range(2):
    syns += group('zkn', ['zbkb', 'zbkc', 'zbkx', 'zkne', 'zknd', 'zknh'])
    syns += group('zks', ['zbkb', 'zbkc', 'zbkx', 'zksed', 'zksh'])
    syns += group('zk', ['zkn', 'zkr', 'zkt'])
syns += imply('zacas', 'zaamo')
syns += group('a', ['zalrsc', 'zaamo'])
syns += group('b', ['zba', 'zbb', 'zbs'])
syns += imply('zcf', ['zca', 'f'])
syns += imply('zcd', ['zca', 'd'])
syns += imply(['zcmop', 'zcb'], 'zca')
syns += imply('zhinx', 'zhinxmin')
syns += imply(['zdinx', 'zhinxmin'], 'zfinx')
syns += imply('zvfh', 'zvfhmin')
syns += imply('zvfhmin', 'zve32f')
syns += imply('v', 'zve64d')
syns += imply('zve64d', ['zve64f', 'd'])
syns += imply('zve64f', ['zve64x', 'zve32f'])
syns += imply('zve64x', 'zve32x')
syns += imply('zve32f', ['zve32x', 'f'])
syns += imply('zfh', 'zfhmin')
syns += imply('q', 'd')
syns += imply(['d', 'zfhmin', 'zfa'], 'f')
syns += imply('c', 'zca')
syns += imply_and(['c', 'd'], 'zcd')
syns += imply_and(['c', 'f'], 'zcf')  # comment out this line to test RV64
syns += imply(['zicntr', 'zihpm', 'zkr', 'f', 'zfinx', 'zve32x'], 'zicsr')
# Link passes last, once every extension has been registered in VARS.
syns += rel_passes()
def DeMorganNot(and_clauses):
    """Negate a conjunction using De Morgan's law.

    Treating `and_clauses` as X && Y && ..., return !(X && Y && ...) written
    as (!X || !Y || ...).
    """
    negated = list(map(z3.Not, and_clauses))
    return z3.Or(*negated)
def _print_extension_summary(model):
    """Print the input assignment and the extensions whose final values differ
    between the last two passes (the same passes consistency_passes compares).
    """
    before, after = N_ITERS - 1, N_ITERS
    print('Inputs:')
    for ext in sorted(VARS):
        # None means the model leaves this input unconstrained.
        value = model[VARS[ext][0][0]]
        shown = 'Any' if value is None else str(value)
        print(f'\t{ext:9} = {shown}')
    print('Different extensions:')
    for ext in sorted(VARS):
        passes = VARS[ext]
        # model_completion avoids a lookup failure if z3 omitted a variable.
        old = model.eval(passes[before][-1], model_completion=True)
        new = model.eval(passes[after][-1], model_completion=True)
        if z3.is_true(old) != z3.is_true(new):
            print('\t' + ext)


def FindCounterexamples(name, constraints, *, dump_all=False):
    """Search for an assignment satisfying every constraint and report it.

    Progress goes to stderr.  Outcomes:
      * unsat: prints ``not found.`` to stderr and returns.
      * sat: prints the counterexample (to stdout) and exits with status 1.
      * unknown: prints the solver's reason to stderr and exits with status 2,
        so an inconclusive run is never mistaken for a proof.

    Args:
        name: Label of the property being checked (progress message only).
        constraints: z3 Boolean expressions to assert.
        dump_all: If true, print every declaration in the model instead of
            the per-extension summary.
    """
    print(f'Whether {name} has a counterexample... ', file=sys.stderr, end='')
    sys.stderr.flush()
    solver = z3.Solver()
    for constraint in constraints:
        solver.add(constraint)
    result = solver.check()
    if result == z3.unsat:
        print('not found.', file=sys.stderr)
        return
    if result != z3.sat:
        print(f'unknown ({solver.reason_unknown()}).', file=sys.stderr)
        sys.exit(2)
    print('found!\n\nCounterexample:', file=sys.stderr)
    model = solver.model()
    if dump_all:
        for d in sorted(model.decls(), key=str):
            print(f'\t{d} = {model[d]}')
    else:
        _print_extension_summary(model)
    sys.exit(1)
if __name__ == '__main__':
    # Satisfiable iff some input makes the last two passes disagree, i.e.
    # N_ITERS passes were not enough to converge.
    FindCounterexamples('Convergence', syns + [DeMorganNot(consistency_passes())])
    sys.exit(0)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment