Last active
March 29, 2024 17:46
-
-
Save generalmimon/a24d65d1890e73b4036fdd48e7d6772f to your computer and use it in GitHub Desktop.
Python script to generate Kaitai Struct expressions
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
# SPDX-FileCopyrightText: 2024 Petr Pucil <[email protected]> | |
# | |
# SPDX-License-Identifier: MIT | |
""" | |
expr_gen.py - Python script to generate a pair of files `[expr_gen.ksy, | |
expr_gen.kst]` with various Kaitai Struct expressions, designed to test that | |
there are no missing parentheses in the translated expressions in any target | |
language. | |
Usage: ./expr_gen.py | |
expr_gen.ksy will contain an exhaustive list of value instances with Kaitai | |
Struct expressions roughly in the form `(X left_op Y) right_op Z`. | |
`left_op` and `right_op` take on (almost) all operations supported by the KS | |
expression language. `X`, `Y` and `Z` take on various primitive data types, but | |
only if they make sense for the selected combination of `(left_op, right_op)`. | |
Values of selected data types are generated pseudo-randomly (based on the | |
`RAND_SEED` constant, which you can change). Efforts are made to avoid invalid | |
selections of values that would lead to runtime errors (e.g. out of range | |
string/array indices) for the given operators. The script may sometimes fail to | |
generate a selection of values considered valid (because it is not very robustly | |
written); in that case, the corresponding combination of selected operators & | |
data types will be skipped, and this incident will be logged in the console. | |
expr_gen.kst will contain assertions that each value instance is evaluated to | |
the correct (constant) value. | |
The `[expr_gen.ksy, expr_gen.kst]` pair can be dropped in the | |
https://github.com/kaitai-io/kaitai_struct_tests repo (`formats/` and `spec/ks/` | |
folders respectively) and run along with the existing tests. | |
""" | |
import itertools | |
import random | |
from collections import defaultdict | |
from enum import Enum, IntEnum, auto | |
import string | |
from dataclasses import dataclass | |
import sys | |
from typing import Any | |
import yaml | |
# Seed for every pseudo-random choice below; change it to get a different
# (but still reproducible) set of generated expressions.
RAND_SEED = 1784761548
# How many random operand selections are tried for one operation before giving up.
MAX_TRIES_TO_GENERATE_VALUES = 100
# Name of the generated .ksy/.kst files and their top-level `id`.
TOP_LEVEL_NAME = 'expr_gen'

random.seed(RAND_SEED)
# Two distinct integers in [0, 15], used as the values of the `fruit` enum.
FRUIT_ENUM_INTS = random.sample(range(16), 2)
class Fruit(IntEnum):
    """Python counterpart of the `fruit` enum written into the generated .ksy."""

    APPLE = FRUIT_ENUM_INTS[0]
    BANANA = FRUIT_ENUM_INTS[1]

    def __str__(self) -> str:
        # Rendered in KS enum-literal syntax, e.g. `fruit::apple`
        return 'fruit::' + self.name.lower()
class DT(Enum):
    """Data types that operands and results of generated KS expressions can have."""

    INT = auto()
    FLOAT = auto()
    BOOL = auto()
    ENUM = auto()
    STR = auto()
    BYTES = auto()
    # ARRAY = auto()
    # Element type of an array; only referenced by the commented-out array signatures.
    ARRAY_BASE_TYPE = auto()

    def __repr__(self):
        # Just the bare member name (e.g. `INT`) to keep debug output short
        return self.name

    def __str__(self):
        return self.__repr__()
# Maps each operator/method name to the signatures it supports, as
# `((operand types...), result type)` tuples. For methods (names starting with
# `.`) the first operand is the receiver. Insertion order matters: the position
# of an operator and of each of its signatures ends up in the generated
# instance names.
OPS = defaultdict(list)

# Arithmetic operators accepting any mix of integer and float operands.
# `%` is not here because the modulo apparently doesn't work with floats in some languages
NUM_OPS = ['+', '-', '*', '/']
NUM_OPS_SIGNATURES = [
    ((DT.INT, DT.INT), DT.INT),
    ((DT.INT, DT.FLOAT), DT.FLOAT),
    ((DT.FLOAT, DT.INT), DT.FLOAT),
    ((DT.FLOAT, DT.FLOAT), DT.FLOAT),
]
for op in NUM_OPS:
    OPS[op] += NUM_OPS_SIGNATURES

# Integer-only binary operators
INT_OPS = ['%', '|', '^', '&', '<<', '>>']
for op in INT_OPS:
    OPS[op] += [((DT.INT, DT.INT), DT.INT)]

# Individual method/operator signatures, registered in this exact order
for op, signature in (
    ('.to_s', ((DT.INT,), DT.STR)),
    ('.to_i', ((DT.FLOAT,), DT.INT)),
    ('.to_i', ((DT.STR,), DT.INT)),
    ('.to_i', ((DT.STR, DT.INT), DT.INT)),
    ('.to_i', ((DT.ENUM,), DT.INT)),
    ('.to_i', ((DT.BOOL,), DT.INT)),
    ('.length', ((DT.STR,), DT.INT)),
    ('.length', ((DT.BYTES,), DT.INT)),
    # ('.length', ((DT.ARRAY,), DT.INT)),
    ('.to_s', ((DT.BYTES, DT.STR), DT.STR)),
    ('+', ((DT.STR, DT.STR), DT.STR)),
    ('.substring', ((DT.STR, DT.INT, DT.INT), DT.STR)),
    ('.reverse', ((DT.STR,), DT.STR)),
    ('[]', ((DT.BYTES, DT.INT), DT.INT)),
):
    OPS[op].append(signature)

# Element accessors / aggregates of byte arrays (and, eventually, arrays)
SUB_OPS = ['.first', '.last', '.min', '.max']
SUB_OP_SIGNATURES = [
    # ((DT.ARRAY,), DT.ARRAY_BASE_TYPE),
    ((DT.BYTES,), DT.INT),
]
for op in SUB_OPS:
    OPS[op] += SUB_OP_SIGNATURES

# Ternary `cond ? a : b` for every concrete type, plus mixed int/float branches
for dt in DT:
    if dt != DT.ARRAY_BASE_TYPE:
        OPS['?:'].append(((DT.BOOL, dt, dt), dt))
OPS['?:'] += [
    ((DT.BOOL, DT.INT, DT.FLOAT), DT.FLOAT),
    ((DT.BOOL, DT.FLOAT, DT.INT), DT.FLOAT),
]
# for op, signatures in OPS.items(): | |
# print(op, signatures) | |
def generate_value(data_type): | |
match data_type: | |
case DT.INT: | |
return random.randint(0, 15) | |
case DT.FLOAT: | |
return random.randint(0, 15_0) / 10 | |
case DT.BOOL: | |
return random.choice((False, True)) | |
case DT.ENUM: | |
return random.choice(list(Fruit)) | |
case DT.STR: | |
length = random.randint(0, 8) | |
return ''.join(random.choices(string.digits, k=length)) | |
case DT.BYTES: | |
# Not generating empty byte arrays because that would require the `.as<bytes>` cast, which is not implemented in Go | |
length = random.randint(1, 8) | |
if length == 0: | |
return b'' | |
else: | |
return ''.join(random.choices(string.digits, k=length)).encode('ASCII') | |
raise NotImplementedError(data_type) | |
def generate_values(data_types):
    """Generate one value per entry of `data_types`, returned as a list in the same order."""
    values = []
    for data_type in data_types:
        values.append(generate_value(data_type))
    return values
def translate_value_for_ks(value, data_type, context=''):
    """Render a Python value as a literal of the KS expression language.

    Args:
        value: the Python value to render.
        data_type: the `DT` member describing `value`.
        context: prefix prepended to enum literals (e.g. ``'expr_gen::'``).

    Raises:
        NotImplementedError: for strings containing a single quote (which this
            function does not know how to quote) and for unsupported data types.
    """
    if data_type in (DT.INT, DT.FLOAT):
        return str(value)
    if data_type == DT.BOOL:
        return ('false', 'true')[value]
    if data_type == DT.ENUM:
        # `str()` rather than f-string formatting, so that `Fruit.__str__` applies
        return context + str(value)
    if data_type == DT.STR:
        if "'" not in value:
            return "'" + value + "'"
        raise NotImplementedError()
    if data_type == DT.BYTES:
        if len(value) == 0:
            return '[].as<bytes>'
        return '[' + ', '.join(str(byte) for byte in value) + ']'
    raise NotImplementedError(data_type)
def translate_values_for_ks(values, data_types): | |
return [translate_value_for_ks(value, data_type) for value, data_type in zip(values, data_types, strict=True)] | |
class WrongGeneratedValuesError(ValueError):
    """Randomly generated operands are unsuitable for the operation.

    Examples: division by zero, an out-of-range index, or a string that is not
    a valid number in the requested radix. The generation loops catch it and
    retry with fresh values.
    """
def evaluate_op(values, data_types, op): | |
match op: | |
case '.to_s': | |
if data_types[0] == DT.INT: | |
return str(values[0]) | |
if data_types[0] == DT.BYTES: | |
val, encoding = values | |
return val.decode(encoding) | |
raise NotImplementedError() | |
case '.to_i': | |
if tuple(data_types) == (DT.STR, DT.INT): | |
val, radix = values | |
if radix not in (2, 8, 10, 16): | |
raise WrongGeneratedValuesError() | |
try: | |
return int(val, base=radix) | |
except ValueError: | |
raise WrongGeneratedValuesError() | |
val, = values | |
try: | |
return int(val) | |
except ValueError: | |
raise WrongGeneratedValuesError() | |
case '.length': | |
val, = values | |
return len(val) | |
case '.substring': | |
val, start, end = values | |
length = len(val) | |
if not 0 <= start <= length: | |
raise WrongGeneratedValuesError() | |
if not 0 <= end <= length: | |
raise WrongGeneratedValuesError() | |
if not start <= end: | |
raise WrongGeneratedValuesError() | |
return val[start:end] | |
case '.reverse': | |
val, = values | |
return val[::-1] | |
case '.first': | |
arr, = values | |
try: | |
return arr[0] | |
except IndexError: | |
raise WrongGeneratedValuesError() | |
case '.last': | |
arr, = values | |
try: | |
return arr[-1] | |
except IndexError: | |
raise WrongGeneratedValuesError() | |
case '.min': | |
arr, = values | |
try: | |
return min(arr) | |
except ValueError: | |
raise WrongGeneratedValuesError() | |
case '.max': | |
arr, = values | |
try: | |
return max(arr) | |
except ValueError: | |
raise WrongGeneratedValuesError() | |
case '?:': | |
cond, if_true, if_false = values | |
return (if_true if cond else if_false) | |
try: | |
l, r = values | |
except ValueError: | |
print(values) | |
print(data_types) | |
print(op) | |
raise | |
match op: | |
case '+': | |
return l + r | |
case '-': | |
return l - r | |
case '*': | |
return l * r | |
case '/': | |
try: | |
if tuple(data_types) == (DT.INT, DT.INT): | |
res = l // r | |
# Let's require a non-negative result here, otherwise we'll run | |
# into https://github.com/kaitai-io/kaitai_struct/issues/746 in some languages | |
if res < 0: | |
raise WrongGeneratedValuesError() | |
return res | |
return l / r | |
except ZeroDivisionError: | |
raise WrongGeneratedValuesError() | |
case '%': | |
try: | |
return l % r | |
except ZeroDivisionError: | |
raise WrongGeneratedValuesError() | |
case '|': | |
return l | r | |
case '^': | |
return l ^ r | |
case '&': | |
return l & r | |
case '<<': | |
return l << r | |
case '>>': | |
return l >> r | |
case '[]': | |
arr, index = values | |
if index < 0: | |
raise WrongGeneratedValuesError() | |
try: | |
return arr[index] | |
except IndexError: | |
raise WrongGeneratedValuesError() | |
print(values) | |
print(data_types) | |
print(op) | |
raise NotImplementedError() | |
@dataclass
class ProducedExpr:
    """A generated KS expression together with the value it should evaluate to."""

    expr: str  # the KS expression combining a left and a right operation
    evald: Any  # Python value `expr` is expected to evaluate to
    data_type: DT  # type of `evald`, used to render it as a KS literal
    left_op_idx: int  # position of the left operator among the keys of `OPS`
    right_op_idx: int  # position of the right operator among the keys of `OPS`
    left_op_signature_idx: int  # index into `OPS[left_op]`
    right_op_signature_idx: int  # index into `OPS[right_op]`

    @property
    def name(self) -> str:
        """Instance name, unique per operator/signature combination."""
        return 'l{}_r{}_lsig{}_rsig{}'.format(
            self.left_op_idx,
            self.right_op_idx,
            self.left_op_signature_idx,
            self.right_op_signature_idx,
        )

    @property
    def expected_expr(self) -> str:
        """`evald` as a KS literal, with enum literals qualified by the top-level type."""
        return translate_value_for_ks(self.evald, self.data_type, context=TOP_LEVEL_NAME + '::')
def _format_method_args(args):
    """Return `(a, b, ...)` for a KS method call, or '' when there are no arguments."""
    return '(' + ', '.join(args) + ')' if args else ''


# Combine every ordered pair of operators (left, right) into expressions of
# roughly the form `(X left_op Y) right_op Z`, for every pair of signatures
# where the left result type fits the right operation's first operand.
produced_expressions = []
options = itertools.product(enumerate(OPS.keys()), repeat=2)
for (left_op_idx, left_op), (right_op_idx, right_op) in options:
    left_signatures = OPS[left_op]
    right_signatures = OPS[right_op]

    # Left grouping: the left group's result becomes the right operation's
    # first operand, so index the right signatures by their first parameter type
    possible_right_signatures = defaultdict(list)
    for signature_idx_val_pair in enumerate(right_signatures):
        param_types, _ = signature_idx_val_pair[1]
        possible_right_signatures[param_types[0]].append(signature_idx_val_pair)

    for left_op_signature_idx, (left_param_types, left_ret_type) in enumerate(left_signatures):
        if left_ret_type not in possible_right_signatures:
            # Return type of the left group isn't accepted by the right operator as the first operand => skipping
            continue

        for _ in range(MAX_TRIES_TO_GENERATE_VALUES):
            left_values_orig = generate_values(left_param_types)
            if left_op == '.to_s' and left_param_types == (DT.BYTES, DT.STR):
                # The encoding argument must be a real encoding name, not random digits
                left_values_orig[1] = 'ASCII'
            try:
                left_evald = evaluate_op(left_values_orig, left_param_types, left_op)
            except WrongGeneratedValuesError:
                continue
            break
        else:
            raise WrongGeneratedValuesError(f'{MAX_TRIES_TO_GENERATE_VALUES} attempts to generate a value have run out')

        left_values = translate_values_for_ks(left_values_orig, left_param_types)
        if left_op.startswith('.'):
            # Method call: the first value is the receiver, the rest are arguments
            left_group_expr = f'{left_values[0]}{left_op}{_format_method_args(left_values[1:])}'
        elif len(left_values) == 2 and left_op == '[]':
            left_group_expr = f'{left_values[0]}[{left_values[1]}]'
        elif len(left_values) == 2:
            left_group_expr = f'({left_values[0]} {left_op} {left_values[1]})'
        elif len(left_values) == 3 and left_op == '?:':
            left_group_expr = f'({left_values[0]} ? {left_values[1]} : {left_values[2]})'
        else:
            raise NotImplementedError(left_op)

        for right_op_signature_idx, (right_param_types, right_ret_type) in possible_right_signatures[left_ret_type]:
            for _ in range(MAX_TRIES_TO_GENERATE_VALUES):
                # The first operand is the left group's result; only the rest are generated
                right_values_orig = [left_evald] + generate_values(right_param_types[1:])
                if right_op == '.to_s' and right_param_types == (DT.BYTES, DT.STR):
                    right_values_orig[1] = 'ASCII'
                try:
                    expr_evald = evaluate_op(right_values_orig, right_param_types, right_op)
                except WrongGeneratedValuesError:
                    continue
                break
            else:
                print(left_group_expr, ' => ', repr(left_evald), file=sys.stderr)
                print(f'Warning: {MAX_TRIES_TO_GENERATE_VALUES} attempts to generate a value have run out for {(left_op, right_op)}', file=sys.stderr)
                continue

            right_values = translate_values_for_ks(right_values_orig[1:], right_param_types[1:])
            if right_op.startswith('.'):
                expr = f'{left_group_expr}{right_op}{_format_method_args(right_values)}'
            elif len(right_values) == 1 and right_op == '[]':
                expr = f'{left_group_expr}[{right_values[0]}]'
            elif len(right_values) == 1:
                expr = f'{left_group_expr} {right_op} {right_values[0]}'
            elif len(right_values) == 2 and right_op == '?:':
                expr = f'{left_group_expr} ? {right_values[0]} : {right_values[1]}'
            else:
                raise NotImplementedError((left_op, right_op), right_values)

            produced_expressions.append(
                ProducedExpr(
                    expr,
                    expr_evald,
                    right_ret_type,
                    left_op_idx,
                    right_op_idx,
                    left_op_signature_idx,
                    right_op_signature_idx,
                )
            )
# Write the generated format spec (.ksy) and the matching test spec (.kst).
with open(f'{TOP_LEVEL_NAME}.ksy', 'w', encoding='utf-8', newline='\n') as ksy_f, \
        open(f'{TOP_LEVEL_NAME}.kst', 'w', encoding='utf-8', newline='\n') as kst_f:
    ksy = {
        'meta': {'id': TOP_LEVEL_NAME},
        'enums': {
            'fruit': {int(member): member.name.lower() for member in Fruit},
        },
        'instances': {},
    }
    kst = {
        'id': TOP_LEVEL_NAME,
        # Presumably an existing binary from the tests repo: the generated
        # format has no `seq`, only constant value instances, so its content
        # should never be read.
        'data': 'enum_negative.bin',
        'asserts': [],
    }
    ksy_instances = ksy['instances']
    kst_asserts = kst['asserts']
    for produced in produced_expressions:
        instance_name = produced.name
        ksy_instances[instance_name] = {'value': produced.expr}
        kst_asserts.append({'actual': instance_name, 'expected': produced.expected_expr})
    yaml.safe_dump(ksy, ksy_f, sort_keys=False, default_flow_style=False)
    yaml.safe_dump(kst, kst_f, sort_keys=False, default_flow_style=False)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment