Skip to content

Instantly share code, notes, and snippets.

@generalmimon
Last active March 29, 2024 17:46
Show Gist options
  • Save generalmimon/a24d65d1890e73b4036fdd48e7d6772f to your computer and use it in GitHub Desktop.
Save generalmimon/a24d65d1890e73b4036fdd48e7d6772f to your computer and use it in GitHub Desktop.
Python script to generate Kaitai Struct expressions
#!/usr/bin/env python3
# SPDX-FileCopyrightText: 2024 Petr Pucil <[email protected]>
#
# SPDX-License-Identifier: MIT
"""
expr_gen.py - Python script to generate a pair of files `[expr_gen.ksy,
expr_gen.kst]` with various Kaitai Struct expressions, designed to test that
there are no missing parentheses in the translated expressions in any target
language.
Usage: ./expr_gen.py
expr_gen.ksy will contain an exhaustive list of value instances with Kaitai
Struct expressions roughly in the form `(X left_op Y) right_op Z`.
`left_op` and `right_op` take on (almost) all operations supported by the KS
expression language. `X`, `Y` and `Z` take on various primitive data types, but
only if they make sense for the selected combination of `(left_op, right_op)`.
Values of selected data types are generated pseudo-randomly (based on the
`RAND_SEED` constant, which you can change). Efforts are made to avoid invalid
selections of values that would lead to runtime errors (e.g. out of range
string/array indices) for the given operators. The script may sometimes fail to
generate a selection of values considered valid (because it is not very robustly
written); in that case, the corresponding combination of selected operators &
data types will be skipped, and this incident will be logged in the console.
expr_gen.kst will contain assertions that each value instance is evaluated to
the correct (constant) value.
The `[expr_gen.ksy, expr_gen.kst]` pair can be dropped in the
https://github.com/kaitai-io/kaitai_struct_tests repo (`formats/` and `spec/ks/`
folders respectively) and run along with the existing tests.
"""
import itertools
import random
from collections import defaultdict
from enum import Enum, IntEnum, auto
import string
from dataclasses import dataclass
import sys
from typing import Any
import yaml
# Seed of the pseudo-random generator; with a fixed seed the script picks the
# same values on every run. Change it to get a different selection of values.
RAND_SEED = 1784761548
# How many times a random selection of operand values is retried before the
# script gives up on the given combination of operators and signatures
MAX_TRIES_TO_GENERATE_VALUES = 100
# Base name of the generated `.ksy` / `.kst` files, also used as the `id` in
# both of them and as the prefix of enum paths in expected values
TOP_LEVEL_NAME = 'expr_gen'
random.seed(RAND_SEED)
# Two distinct integers from the range 0..15 that back the members of the
# `Fruit` enum (`random.sample` never picks the same element twice)
FRUIT_ENUM_INTS = random.sample(range(0, 15 + 1), 2)
class Fruit(IntEnum):
    """
    Enum used as the operand of the `DT.ENUM` data type.

    The integer values of the members are taken from `FRUIT_ENUM_INTS`.
    Converting a member to a string yields its Kaitai Struct enum literal,
    i.e. `fruit::apple` or `fruit::banana`.
    """
    APPLE, BANANA = FRUIT_ENUM_INTS

    def __str__(self) -> str:
        member_id = self.name.lower()
        return 'fruit::' + member_id
class DT(Enum):
    """
    Data types of operands and results of the generated expressions.

    Both `repr()` and `str()` of a member give just the bare member name
    (e.g. `INT`), which keeps printed signatures short.
    """
    INT = auto()
    FLOAT = auto()
    BOOL = auto()
    ENUM = auto()
    STR = auto()
    BYTES = auto()
    # ARRAY = auto()  (arrays are not generated yet)
    ARRAY_BASE_TYPE = auto()

    def __repr__(self):
        return self.name

    def __str__(self):
        # Same text as `repr()`
        return self.__repr__()
# Maps an operation (infix operator, method name, `[]` or `?:`) to the list of
# its signatures in the form `((parameter types...), return type)`.
#
# The order in which keys are inserted and the order of signatures within each
# list are significant: their indices end up in the names of generated value
# instances.
OPS = defaultdict(list)

# Arithmetic operators accepting any mix of integers and floats. `%` is not
# here because the modulo apparently doesn't work with floats in some languages.
NUM_OPS = ['+', '-', '*', '/']
NUM_OPS_SIGNATURES = [
    ((DT.INT, DT.INT), DT.INT),
    ((DT.INT, DT.FLOAT), DT.FLOAT),
    ((DT.FLOAT, DT.INT), DT.FLOAT),
    ((DT.FLOAT, DT.FLOAT), DT.FLOAT),
]
for op in NUM_OPS:
    OPS[op] += NUM_OPS_SIGNATURES

# Operators defined only for a pair of integers
INT_OPS = ['%', '|', '^', '&', '<<', '>>']
for op in INT_OPS:
    OPS[op] += [((DT.INT, DT.INT), DT.INT)]

# Conversions, string/byte array methods and indexing, registered one by one
# (the order of the entries must not change, see above)
for op, signature in [
    ('.to_s', ((DT.INT,), DT.STR)),
    ('.to_i', ((DT.FLOAT,), DT.INT)),
    ('.to_i', ((DT.STR,), DT.INT)),
    ('.to_i', ((DT.STR, DT.INT), DT.INT)),
    ('.to_i', ((DT.ENUM,), DT.INT)),
    ('.to_i', ((DT.BOOL,), DT.INT)),
    ('.length', ((DT.STR,), DT.INT)),
    ('.length', ((DT.BYTES,), DT.INT)),
    # ('.length', ((DT.ARRAY,), DT.INT)),
    ('.to_s', ((DT.BYTES, DT.STR), DT.STR)),
    ('+', ((DT.STR, DT.STR), DT.STR)),
    ('.substring', ((DT.STR, DT.INT, DT.INT), DT.STR)),
    ('.reverse', ((DT.STR,), DT.STR)),
    ('[]', ((DT.BYTES, DT.INT), DT.INT)),
]:
    OPS[op].append(signature)

# Methods selecting a single element of a sequence
SUB_OPS = ['.first', '.last', '.min', '.max']
SUB_OP_SIGNATURES = [
    # ((DT.ARRAY,), DT.ARRAY_BASE_TYPE),
    ((DT.BYTES,), DT.INT),
]
for op in SUB_OPS:
    OPS[op] += SUB_OP_SIGNATURES

# Ternary operator: both branches of the same type for every real data type...
for dt in DT:
    if dt != DT.ARRAY_BASE_TYPE:
        OPS['?:'].append(((DT.BOOL, dt, dt), dt))
# ...plus the two mixed numeric variants that result in a float
OPS['?:'] += [
    ((DT.BOOL, DT.INT, DT.FLOAT), DT.FLOAT),
    ((DT.BOOL, DT.FLOAT, DT.INT), DT.FLOAT),
]

# for op, signatures in OPS.items():
#     print(op, signatures)
def generate_value(data_type):
match data_type:
case DT.INT:
return random.randint(0, 15)
case DT.FLOAT:
return random.randint(0, 15_0) / 10
case DT.BOOL:
return random.choice((False, True))
case DT.ENUM:
return random.choice(list(Fruit))
case DT.STR:
length = random.randint(0, 8)
return ''.join(random.choices(string.digits, k=length))
case DT.BYTES:
# Not generating empty byte arrays because that would require the `.as<bytes>` cast, which is not implemented in Go
length = random.randint(1, 8)
if length == 0:
return b''
else:
return ''.join(random.choices(string.digits, k=length)).encode('ASCII')
raise NotImplementedError(data_type)
def generate_values(data_types):
    """Return a list with one pseudo-random value per entry of `data_types`, in the same order."""
    generated = []
    for data_type in data_types:
        generated.append(generate_value(data_type))
    return generated
def translate_value_for_ks(value, data_type, context=''):
    """
    Render a Python value as a literal of the Kaitai Struct expression language.

    `data_type` is the `DT` member describing `value`. `context` is a prefix
    put in front of enum literals only (for example `expr_gen::`); it is
    ignored for all other data types.

    Raises `NotImplementedError` for a string containing an apostrophe (no
    escaping is implemented) and for an unsupported data type.
    """
    if data_type in (DT.INT, DT.FLOAT):
        return str(value)

    if data_type == DT.BOOL:
        bool_literals = ('false', 'true')
        return bool_literals[value]

    if data_type == DT.ENUM:
        return context + str(value)

    if data_type == DT.STR:
        if "'" in value:
            raise NotImplementedError()
        return "'" + value + "'"

    if data_type == DT.BYTES:
        if not value:
            # An empty array literal has to be cast to a byte array explicitly
            return '[].as<bytes>'
        items = ', '.join(str(byte) for byte in value)
        return f'[{items}]'

    raise NotImplementedError(data_type)
def translate_values_for_ks(values, data_types):
return [translate_value_for_ks(value, data_type) for value, data_type in zip(values, data_types, strict=True)]
class WrongGeneratedValuesError(ValueError):
    """Signals that the generated operand values cannot be used with the selected operation."""
def evaluate_op(values, data_types, op):
    """
    Compute in Python the result of applying operation `op` to `values`.

    `values` are the operands (for method-like operations the receiver comes
    first, followed by the arguments), `data_types` are their `DT` types and
    `op` is the operation as spelled in the keys of `OPS`.

    Raises `WrongGeneratedValuesError` if the operands are not usable with the
    operation (unparsable number, unsupported radix, index or range out of
    bounds, empty sequence, division by zero or a negative result of integer
    division). Raises `NotImplementedError` for an unsupported operation.
    """
    # ---- Conversions and method-like operations ----
    if op == '.to_s':
        receiver_type = data_types[0]
        if receiver_type == DT.INT:
            return str(values[0])
        if receiver_type == DT.BYTES:
            raw, encoding = values
            return raw.decode(encoding)
        raise NotImplementedError()

    if op == '.to_i':
        if tuple(data_types) == (DT.STR, DT.INT):
            text, radix = values
            if radix not in (2, 8, 10, 16):
                raise WrongGeneratedValuesError()
            try:
                return int(text, base=radix)
            except ValueError:
                raise WrongGeneratedValuesError()
        (operand,) = values
        try:
            return int(operand)
        except ValueError:
            raise WrongGeneratedValuesError()

    if op == '.length':
        (operand,) = values
        return len(operand)

    if op == '.substring':
        text, start, end = values
        # Both bounds must lie within the string and must not be swapped
        if not (0 <= start <= end <= len(text)):
            raise WrongGeneratedValuesError()
        return text[start:end]

    if op == '.reverse':
        (operand,) = values
        return operand[::-1]

    if op in ('.first', '.last'):
        (seq,) = values
        position = 0 if op == '.first' else -1
        try:
            return seq[position]
        except IndexError:
            raise WrongGeneratedValuesError()

    if op in ('.min', '.max'):
        (seq,) = values
        select = min if op == '.min' else max
        try:
            return select(seq)
        except ValueError:
            raise WrongGeneratedValuesError()

    if op == '?:':
        cond, if_true, if_false = values
        if cond:
            return if_true
        return if_false

    # ---- Everything else takes exactly two operands ----
    try:
        lhs, rhs = values
    except ValueError:
        print(values)
        print(data_types)
        print(op)
        raise

    # Operators that map directly to the Python operator of the same spelling
    plain_operators = {
        '+': lambda: lhs + rhs,
        '-': lambda: lhs - rhs,
        '*': lambda: lhs * rhs,
        '|': lambda: lhs | rhs,
        '^': lambda: lhs ^ rhs,
        '&': lambda: lhs & rhs,
        '<<': lambda: lhs << rhs,
        '>>': lambda: lhs >> rhs,
    }
    if op in plain_operators:
        return plain_operators[op]()

    if op == '/':
        try:
            if tuple(data_types) == (DT.INT, DT.INT):
                quotient = lhs // rhs
                # Let's require a non-negative result here, otherwise we'll run
                # into https://github.com/kaitai-io/kaitai_struct/issues/746 in some languages
                if quotient < 0:
                    raise WrongGeneratedValuesError()
                return quotient
            return lhs / rhs
        except ZeroDivisionError:
            raise WrongGeneratedValuesError()

    if op == '%':
        try:
            return lhs % rhs
        except ZeroDivisionError:
            raise WrongGeneratedValuesError()

    if op == '[]':
        if rhs < 0:
            raise WrongGeneratedValuesError()
        try:
            return lhs[rhs]
        except IndexError:
            raise WrongGeneratedValuesError()

    print(values)
    print(data_types)
    print(op)
    raise NotImplementedError()
@dataclass
class ProducedExpr:
    """One generated expression together with everything needed to emit it and its assertion."""
    # Kaitai Struct expression text
    expr: str
    # Value the expression is expected to evaluate to (computed in Python)
    evald: Any
    # Data type of `evald`
    data_type: DT
    # Positions of the left and right operation among the keys of `OPS`
    left_op_idx: int
    right_op_idx: int
    # Positions of the used signatures within the respective `OPS` entries
    left_op_signature_idx: int
    right_op_signature_idx: int

    @property
    def name(self) -> str:
        """Name of the value instance, composed of the operation and signature indices."""
        parts = (
            f'l{self.left_op_idx}',
            f'r{self.right_op_idx}',
            f'lsig{self.left_op_signature_idx}',
            f'rsig{self.right_op_signature_idx}',
        )
        return '_'.join(parts)

    @property
    def expected_expr(self) -> str:
        """Expected value as a Kaitai Struct literal, with enum literals prefixed by the top-level type name."""
        enum_scope = f'{TOP_LEVEL_NAME}::'
        return translate_value_for_ks(self.evald, self.data_type, context=enum_scope)
def _format_call_args(args):
    """Render already translated call arguments as `(a, b, ...)`, or as an empty string if there are none."""
    # Deliberately not an f-string with `', '` nested in the replacement field:
    # reusing the enclosing quote character there is a syntax error on Python
    # versions older than 3.12.
    return ('(' + ', '.join(args) + ')') if args else ''


# Every generated expression, in the order of generation. Expressions have the
# form `(X left_op Y) right_op Z`: a "left group" is built first and then used
# as the first operand of the right operation.
produced_expressions = []
options = itertools.product(enumerate(OPS.keys()), repeat=2)
for (left_op_idx, left_op), (right_op_idx, right_op) in options:
    # print((left_op, right_op))
    left_signatures = OPS[left_op]
    right_signatures = OPS[right_op]

    # Left grouping
    # Index the signatures of the right operation by the type of their first
    # parameter, i.e. by the type the left group has to evaluate to. The
    # original signature index is kept alongside each signature.
    possible_right_signatures = defaultdict(list)
    for signature_idx_val_pair in enumerate(right_signatures):
        param_types, _ = signature_idx_val_pair[1]
        possible_right_signatures[param_types[0]].append(signature_idx_val_pair)
    # print(possible_right_signatures)

    for left_op_signature_idx, (left_param_types, left_ret_type) in enumerate(left_signatures):
        if left_ret_type not in possible_right_signatures:
            # Return type of the left group isn't accepted by the right operator as the first operand => skipping
            continue

        # Keep generating operands for the left group until they are valid
        for _ in range(MAX_TRIES_TO_GENERATE_VALUES):
            left_values_orig = generate_values(left_param_types)
            if left_op == '.to_s' and left_param_types == (DT.BYTES, DT.STR):
                # The second operand is an encoding name, not an arbitrary string
                left_values_orig[1] = 'ASCII'
            try:
                left_evald = evaluate_op(left_values_orig, left_param_types, left_op)
            except WrongGeneratedValuesError:
                continue
            break
        else:
            # The loop ended without `break`, so no attempt succeeded
            raise WrongGeneratedValuesError(f'{MAX_TRIES_TO_GENERATE_VALUES} attempts to generate a value have run out')

        left_values = translate_values_for_ks(left_values_orig, left_param_types)
        if left_op.startswith('.'):
            # Method call on the first operand; the rest are its arguments
            args_expr = _format_call_args(left_values[1:])
            left_group_expr = f'{left_values[0]}{left_op}{args_expr}'
        else:
            if len(left_values) == 2:
                if left_op == '[]':
                    left_group_expr = f'{left_values[0]}[{left_values[1]}]'
                else:
                    left_group_expr = f'({left_values[0]} {left_op} {left_values[1]})'
            elif len(left_values) == 3 and left_op == '?:':
                left_group_expr = f'({left_values[0]} ? {left_values[1]} : {left_values[2]})'
            else:
                raise NotImplementedError()

        for right_op_signature_idx, (right_param_types, right_ret_type) in possible_right_signatures[left_ret_type]:
            # The first operand is fixed (the evaluated left group), only the
            # remaining operands are generated
            for _ in range(MAX_TRIES_TO_GENERATE_VALUES):
                right_values_orig = [left_evald] + generate_values(right_param_types[1:])
                if right_op == '.to_s' and right_param_types == (DT.BYTES, DT.STR):
                    right_values_orig[1] = 'ASCII'
                try:
                    expr_evald = evaluate_op(right_values_orig, right_param_types, right_op)
                except WrongGeneratedValuesError:
                    continue
                break
            else:
                # No valid operands found: log the incident and skip this
                # signature of the right operation
                print(left_group_expr, ' => ', repr(left_evald), file=sys.stderr)
                print(f'Warning: {MAX_TRIES_TO_GENERATE_VALUES} attempts to generate a value have run out for {(left_op, right_op)}', file=sys.stderr)
                continue

            right_values = translate_values_for_ks(right_values_orig[1:], right_param_types[1:])
            if right_op.startswith('.'):
                args_expr = _format_call_args(right_values)
                expr = f'{left_group_expr}{right_op}{args_expr}'
            else:
                if len(right_values) == 1:
                    if right_op == '[]':
                        expr = f'{left_group_expr}[{right_values[0]}]'
                    else:
                        expr = f'{left_group_expr} {right_op} {right_values[0]}'
                elif len(right_values) == 2 and right_op == '?:':
                    expr = f'{left_group_expr} ? {right_values[0]} : {right_values[1]}'
                else:
                    print((left_op, right_op))
                    print(right_values)
                    raise NotImplementedError((left_op, right_op))

            produced_expressions.append(
                ProducedExpr(
                    expr,
                    expr_evald,
                    right_ret_type,
                    left_op_idx,
                    right_op_idx,
                    left_op_signature_idx,
                    right_op_signature_idx,
                )
            )
            # print(expr, ' => ', repr(expr_evald), file=sys.stderr)
# Write the format specification (.ksy) and the matching test specification
# (.kst) side by side. Keys are dumped in insertion order (`sort_keys=False`).
ksy_path = f'{TOP_LEVEL_NAME}.ksy'
kst_path = f'{TOP_LEVEL_NAME}.kst'
with open(ksy_path, 'w', encoding='utf-8', newline='\n') as ksy_f, \
        open(kst_path, 'w', encoding='utf-8', newline='\n') as kst_f:
    ksy = {
        'meta': {
            'id': TOP_LEVEL_NAME,
        },
        'enums': {
            # integer value -> lowercase member name
            'fruit': {int(member): member.name.lower() for member in Fruit},
        },
        # One value instance per generated expression
        'instances': {
            produced.name: {'value': produced.expr}
            for produced in produced_expressions
        },
    }
    kst = {
        'id': TOP_LEVEL_NAME,
        'data': 'enum_negative.bin',
        # One assertion per value instance, in the same order
        'asserts': [
            {'actual': produced.name, 'expected': produced.expected_expr}
            for produced in produced_expressions
        ],
    }
    yaml.safe_dump(ksy, ksy_f, sort_keys=False, default_flow_style=False)
    yaml.safe_dump(kst, kst_f, sort_keys=False, default_flow_style=False)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment