Created
March 11, 2024 20:51
-
-
Save hwayne/dd5c33e41a94d7f3d242f9728bf7e47d to your computer and use it in GitHub Desktop.
expand_version.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import xml.etree.ElementTree as ET | |
from xml.etree.ElementTree import Element | |
from copy import deepcopy | |
from argparse import ArgumentParser | |
from dataclasses import dataclass | |
from string import Template | |
from pathlib import Path | |
import typing as t | |
# I often need several slightly different versions of the same spec; this script is a helper for generating them.
# TODO rename spec to 'file' | |
# TODO let 'files' define strings for specific numbers for easy control | |
def parse_args():
    """Define the command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace with attributes ``file``, ``spec``, ``version``
        and ``dryrun``.
    """
    option_table = [
        (("file",), dict(help="xml file to convert")),
        (("--spec",), dict(required=False, help="spec in the file to convert. Default is all.")),
        (("--version",), dict(required=False, help="which version of the spec. Should only be used if --spec is also used. Default is all. TODO find how argparse works for better documentation of flag limitations")),
        (("-d", "--dryrun"), dict(action="store_true", help="print the expansion to STDOUT instead of writing files.")),
    ]
    cli = ArgumentParser()
    for flags, settings in option_table:
        cli.add_argument(*flags, **settings)
    # TODO: arguments to control whether we update just the file or also the state spaces.
    return cli.parse_args()
@dataclass | |
class VersionRange: | |
start: int | |
finish: t.Optional[int] | |
def __init__(self, start: str | int, finish): | |
if start == '': # These are actually strings rn, not Nones, of form '' | |
start = 1 | |
if finish: | |
finish = int(finish) | |
self.start = int(start) | |
self.finish = finish | |
if self.start and self.finish: | |
assert self.start <= self.finish | |
def max_version(self) -> int: | |
if not self.finish: | |
return self.start # guaranteed only maximal for n- switches | |
return self.finish | |
def contains(self, i: int) -> bool: | |
if self.start == 0: # Only happens intentionally | |
return False | |
if not self.finish: # Missing! | |
return i >= self.start | |
return self.start <= i <= self.finish | |
def expand_on_attrib(on_str: str) -> VersionRange:
    """Parse a switch's ``on`` attribute string into a VersionRange.

    ``"N"`` selects exactly version N. ``"A-B"`` selects A through B, where
    either side may be left empty (an empty start becomes 1; an empty finish
    leaves the range open-ended). More than one ``-`` raises ValueError.
    """
    if "-" not in on_str:
        exact = int(on_str)
        return VersionRange(start=exact, finish=exact)
    lower, upper = on_str.split("-")
    return VersionRange(lower, upper)
def get_on(s: Element) -> VersionRange:
    """Return the version range in which switch element *s* is active.

    Raises KeyError if *s* has no ``on`` attribute.
    """
    on_value = s.attrib["on"]  # NOTE(review): original remark "could also be _" -- intent unclear
    return expand_on_attrib(on_value)
def tree_to_text(tree) -> str:
    """Concatenate all text and tail fragments under element *tree*, in document order.

    Tags themselves are dropped; only their character data survives.
    """
    pieces = list(tree.itertext())
    return "".join(pieces)
@dataclass
class SpecVersion:
    """One rendered version of a spec, ready to print or write to a file."""
    name: str  # the spec's name attribute
    version: int  # version number used in the output file name
    text: str  # expanded spec body; may contain a $name placeholder
    ext: str  # output file extension, without the leading dot

    def filename(self):
        """Return the output basename, e.g. ``foo__03`` (version zero-padded to width 2)."""
        return f"{self.name}__{self.version:0=2}"

    def __str__(self):
        """Return the text with ``$name`` / ``${name}`` replaced by :meth:`filename`.

        ``safe_substitute`` is used because ``name`` is the only placeholder
        supplied: any other ``$`` in the spec body (e.g. ``$5``, ``$foo``) is
        left as-is instead of raising. ``$$`` still collapses to ``$``.
        """
        # For TLA stuff -- presumably so a MODULE header can match the generated file name.
        return Template(self.text).safe_substitute({"name": self.filename()})
class Metafile(Element):
    """Placeholder for a future metadata element type.

    TODO: make this totally its own thing instead of subclassing Element.
    It currently adds nothing to Element and is not referenced elsewhere in
    this file.
    """
def create_spec_version(spec_root: Element, version: int) -> SpecVersion:
    """Render version *version* of the spec rooted at *spec_root*.

    Works on a deep copy, so *spec_root* is never modified. Every ``<s>``
    switch whose ``on`` range excludes *version* has all of its inner text
    removed; text that follows a switch's closing tag is kept.

    The returned version number is shifted by the spec's optional
    ``start-from`` attribute (default 1), and ``ext`` defaults to "tla".
    """
    working = deepcopy(spec_root)
    inactive = [sw for sw in working.findall('.//s') if not get_on(sw).contains(version)]
    for sw in inactive:
        # An element's .text is the text before its first child; .tail is the
        # text after its closing tag. Everything inside sw is therefore
        # sw.text plus each descendant's text and tail. sw's own tail lies
        # outside the switch and must survive.
        for node in sw.iter():
            node.text = ""
            if node is not sw:
                node.tail = ""
    return SpecVersion(
        name=working.attrib["name"],
        version=version + int(working.get("start-from", 1)) - 1,
        text=tree_to_text(working),
        ext=working.attrib.get("ext", "tla"),
    )
def create_all_spec_versions(spec_root: Element) -> list[SpecVersion]:
    """Render versions 1..N of *spec_root*.

    N is the largest ``max_version()`` among the spec's ``<s>`` switches
    (0 if there are none, in which case the result is empty).
    """
    # Open question from the original: is this where the name replacement should happen?
    candidates = [0] + [get_on(sw).max_version() for sw in spec_root.findall('.//s')]
    highest = max(candidates)
    return [create_spec_version(spec_root, n) for n in range(1, highest + 1)]
def expand_version(args):
    """Expand the spec versions selected on the command line.

    Args:
        args: parsed CLI namespace with ``file``, ``spec``, ``version`` and
            ``dryrun`` (see parse_args). ``version`` only takes effect
            together with ``spec``.

    Returns:
        When ``args.dryrun`` is set, the rendered text of every selected
        version. Otherwise an empty list, after writing each version into the
        folder named by the root element's ``folder`` attribute.

    Raises:
        ValueError: if ``args.spec`` does not name a ``<spec>`` in the file.
    """
    root = ET.parse(args.file).getroot()
    folder = root.attrib["folder"]
    versions = _collect_versions(root, args)
    if args.dryrun:
        return [str(spec) for spec in versions]
    _write_versions(versions, Path(folder), root.get("ext"))
    return []


def _collect_versions(root: Element, args) -> "list[SpecVersion]":
    """Build the SpecVersions selected by ``args.spec`` / ``args.version``."""
    specs = root.findall("spec")
    root_ext = root.get("ext")
    if root_ext:
        # Backwards compatibility: a file-wide ext applies to every spec that
        # doesn't set its own. Without this, create_spec_version's "tla"
        # default would always shadow the root-level value.
        for spec_el in specs:
            spec_el.attrib.setdefault("ext", root_ext)
    if not args.spec:
        return [version for spec_el in specs for version in create_all_spec_versions(spec_el)]
    # Compare names directly rather than interpolating args.spec into an
    # ElementPath expression, which breaks if the name contains a quote.
    matches = [spec_el for spec_el in specs if spec_el.get("name") == args.spec]
    if not matches:
        raise ValueError(f"no <spec name={args.spec!r}> found in {args.file}")
    spec_root = matches[0]
    if args.version:
        return [create_spec_version(spec_root, int(args.version))]
    return create_all_spec_versions(spec_root)


def _write_versions(versions, folder: Path, fallback_ext: t.Optional[str]) -> None:
    """Write each version to ``folder/<filename>.<ext>``, keeping any metadata header."""
    if not versions:
        return
    folder.mkdir(exist_ok=True, parents=True)
    for spec in versions:
        to_write = str(spec)
        ext = spec.ext or fallback_ext or "tla"
        out_path = folder / f"{spec.filename()}.{ext}"
        if out_path.exists():
            # Preserve metadata at the top of the file: everything before the
            # last "!!!" separator is kept and only the final section is
            # replaced. A file without "!!!" is overwritten entirely.
            parts = out_path.read_text().split("!!!")
            parts[-1] = to_write
            to_write = "!!!".join(parts)
        out_path.write_text(to_write)
def main():
    """CLI entry point: expand the requested versions and, in dry-run mode, print them."""
    args = parse_args()
    rendered = expand_version(args)
    if not args.dryrun:
        return  # expand_version has already written the files
    for text in rendered:
        print(text)
# Run the command-line entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment