Last active
March 7, 2023 05:16
-
-
Save iscgar/b77caf9a8b4982a1002111ba46f0e701 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import struct | |
import itertools | |
from base64 import b64encode | |
from retrie.trie import Trie | |
def _render_group(common, matching, longest):
    """Render one group of reversed patterns as a single regex fragment.

    ``common`` holds the reversed patterns of the group, ``matching`` is the
    number of leading characters they all share (i.e. the length of the shared
    suffix of the un-reversed patterns) and ``longest`` is the length of the
    longest pattern in the group.
    """
    first = common[0][::-1]
    if len(common) == 1:
        return first
    # Slice from an explicit start index: ``first[-matching:]`` would return
    # the whole word when ``matching`` is 0.
    suffix = first[len(first) - matching:]
    if matching == longest - 1:
        # Every member differs from the others in exactly one leading
        # character, so a character class is enough.
        return '[{}]{}'.format(''.join(w[matching:] for w in common), suffix)
    return '({}){}'.format('|'.join(w[matching:][::-1] for w in common), suffix)


def commonise_group(pat):
    """Merge alternatives that share a long enough suffix.

    The patterns are reversed and sorted so that patterns with a common suffix
    become neighbours.  Neighbours are collected into one group for as long as
    the index of their first differing (reversed) character is at least half
    the length of the longest pattern seen in the group; each group is then
    rendered as ``[xy]suffix`` or ``(aa|bb)suffix``.

    Returns the rendered fragments in the order in which the first member of
    each group (in reversed sort order) appears in ``pat``.  An empty ``pat``
    yields an empty list.

    Raises ValueError when one reversed pattern is a prefix of (or equal to)
    its neighbour, because no differing character can be found.
    """
    if not pat:
        return []
    patr = sorted(s[::-1] for s in pat)
    common = [patr[0]]
    pat_map = {}
    matching = longest = len(patr[0])
    for w in patr[1:]:
        for i, (c, prev_c) in enumerate(zip(w, common[-1])):
            if c != prev_c:
                break
        else:
            raise ValueError("cannot have identical patterns in group")
        if i >= max(longest, len(w)) // 2:
            common.append(w)
            longest = max(longest, len(w))
            # The list is sorted, so the prefix shared by the whole group is
            # the minimum over adjacent pairs; taking the latest value alone
            # could overstate it and emit a suffix the first word doesn't share.
            matching = min(matching, i)
        else:
            pat_map[common[0][::-1]] = _render_group(common, matching, longest)
            common = [w]
            matching = longest = len(w)
    pat_map[common[0][::-1]] = _render_group(common, matching, longest)
    return [pat_map[w] for w in pat if w in pat_map]
def _char_category(c):
    """Return 0 for a digit, 1 for A-Z, 2 for a-z and -1 for anything else."""
    if '0' <= c <= '9':
        return 0
    if 'A' <= c <= 'Z':
        return 1
    if 'a' <= c <= 'z':
        return 2
    return -1


def _compress_class_run(chars):
    """Render a run of same-category characters from a character class.

    More than three characters that form one unbroken code-point sequence
    collapse to ``first-last``; anything else is emitted sorted, verbatim.
    """
    ordered = sorted(chars)
    if len(ordered) > 3 and ord(ordered[-1]) - ord(ordered[0]) == len(ordered) - 1:
        return '{}-{}'.format(ordered[0], ordered[-1])
    return ''.join(ordered)


def shorten_pat(pattern):
    """Rewrite a regex pattern string into a shorter equivalent form.

    Two rewrites are applied while scanning ``pattern`` character by character:

    * inside ``[...]`` runs of digits / upper-case / lower-case characters are
      collapsed into ranges where possible;
    * the alternatives of every ``(...)`` group, and of the top level, are
      passed through ``commonise_group`` to factor out shared suffixes.

    A character following a backslash is copied through untouched.  A leading
    ``?:`` inside a group is preserved as the group prefix.

    Returns the rewritten pattern.  Malformed input (nested ``[``, a stray
    ``]``, an unterminated class, escape or group) trips an assertion.
    """
    pat = ['']            # alternatives of the group currently being read
    group_stack = []      # (group_prefix, alternatives) of enclosing groups
    in_alt = False        # inside a [...] character class
    in_escape = False     # previous character was an unescaped backslash
    last_cat = -1
    cats = []             # pending same-category characters inside a class
    group_prefix = ''
    for c in pattern:
        if in_escape:
            in_escape = False
            pat[-1] += c
            continue
        if in_alt:
            cat = _char_category(c)
            if cat != last_cat:
                # Category changed (or the class is ending): emit the run.
                pat[-1] += _compress_class_run(cats)
                del cats[:]
            if cat != -1:
                cats.append(c)
                last_cat = cat
                continue
        if c == '\\':
            in_escape = True
        elif c == '[':
            assert not in_alt, "nested '[' inside a character class"
            in_alt = True
        elif c == ']':
            assert in_alt, "']' without a matching '['"
            in_alt = False
        elif not in_alt:
            if c == '(':
                group_stack.append((group_prefix, pat))
                group_prefix, pat = '', ['']
                continue
            if c == ')':
                gp, opt = group_prefix, commonise_group(pat)
                group_prefix, pat = group_stack.pop(-1)
                pat[-1] += '({}{})'.format(gp, '|'.join(opt))
                continue
            if c == '|':
                pat.append('')
                continue
            if c == ':' and group_stack and pat[-1] == '?':
                assert not group_prefix, "group prefix given twice"
                pat[-1] = ''
                group_prefix = '?:'
                continue
        pat[-1] += c
    assert not in_alt and not in_escape and not group_stack, \
        "unterminated character class, escape or group"
    assert len(pat) == 1 or pat[-1], "pattern ends with an empty alternative"
    return '|'.join(commonise_group(pat))
def regex_for_prefix(prefix, tail_len):
    """Build a regex matching the base64 encoding of data containing ``prefix``.

    ``prefix`` is a bytes object.  It is encoded at each of the three possible
    offsets inside a 3-byte base64 block, with every neighbouring byte value
    tried on either side, and all resulting encoded forms are collected in a
    trie whose pattern is then shortened with ``shorten_pat``.

    ``tail_len`` is used to append a pattern for the characters that must
    follow the encoded prefix; nothing is appended when ``tail_len <= 2``.

    Returns the pattern as a string.
    """
    # A sequence of all possible byte values
    # (used to pad the prefix on either side to account for encoding alignment)
    padding = bytearray(range(256))
    # We build a trie in order to try to get the most compressed form of the
    # resulting pattern
    t = Trie()
    # A base64 encoding block is 3 bytes long, so we need to account for the
    # position of the beginning of the prefix in any of an encoding block's slots
    for i in range(3):
        lead = b'A' * max(0, i - 1)
        # If the length of the prefix plus the current encoding block offset
        # isn't divisible by the length of an encoding block, we need to pad it
        # in order to get all of the values that could appear after the prefix
        # in the encoded form
        pad_len = (3 - (len(prefix) + i) % 3) % 3
        pads = b'A' * max(0, pad_len - 1)
        # Iterate over all of the permutations of padding values for this slot
        # NOTE(review): permutations() never yields the same byte for both pad
        # positions; itertools.product would also cover equal pairs - confirm
        # that skipping them cannot drop an encoded form.
        pad_count = int(bool(i)) + int(bool(pad_len))
        for r in itertools.permutations(padding, pad_count):
            source = lead
            if i:
                source += struct.pack('<B', r[0])
            source += prefix
            if pad_len:
                source += struct.pack('<B', r[-1]) + pads
            # We get the encoded value of the prefix (offset by the current slot
            # index and padded to the next encoding block boundary)
            encoded = b64encode(source)
            # However, if the prefix isn't at the beginning of an encoding
            # block, we only care about the way it affects the encoded prefix
            # itself, and we don't really care about the value of the bytes
            # that come before it, so strip the leading characters (note that
            # since the encoded length is strictly bigger than the source
            # length for base64, stripping an amount equal to the slot index is
            # guaranteed to only strip the leading padding, but not the
            # encoded prefix).
            encoded = encoded[i:]
            # Similarly, if we added padding, we only care about the way it
            # affects the prefix, but not about the encoded padding byte
            # values, so strip them as well (again, this is guaranteed to not
            # touch the encoded prefix, because the encoded size is strictly
            # bigger than the source size for base64).
            if pad_len > 0:
                encoded = encoded[:-pad_len]
            # Add it to the trie
            t.add(encoded.decode('ascii'))
    # Extract a pattern that describes this trie and optimise it a bit
    pat = shorten_pat(t.pattern())
    # Add a pattern for the tail (because we need to at least see this many
    # bytes as well).  This equals (len(prefix) + tail_len) - (len(prefix) + 2).
    # NOTE(review): ``left`` is derived from un-encoded lengths but is used
    # below as a count of encoded characters - confirm the intended bound.
    left = tail_len - 2
    if left > 0:
        groups = (left + 3) // 4
        tail_alternatives = (
            r'[\+\/A-Za-z0-9]{{{}}}{}'.format(groups * 4 - n_eq, '=' * n_eq)
            for n_eq in range(3)
        )
        pat += '(?:{})'.format('|'.join(tail_alternatives))
    return pat
Also, here:
# we only care about the way it affects the encoded prefix itself,
# and we don't really care about the value of the bytes that come
# before it, so strip the leading bytes (note that since the encoded
# length is stricktly bigger than the source length for base64,
# stripping an amount equal to the slot index is guaranteed to only
# strip the leading padding bytes, but not the encoded prefix).
encoded = encoded[i:]
Indeed, this catches most of it, but makes it much harder to handle the result (you need to generate your own prefix after the fact).
If we simply remove this line, we do get some extra characters in the pattern. Ideally that would be just [b64-char-pattern] once or twice; in practice, with the current implementation, it is a somewhat odd alternation that is longer, although still reasonable.
By fixing it, one can simply do the following for a complete solution:
echo "$input" | grep -oE "$exp" | base64 -d | grep -oE $original_pattern
p.s. I'm using the following script to generate test vectors (obviously, not a perfect one, but helpful enough):
#!/usr/bin/env python3
import random
def range_chrs(a, z):
    """Return the characters from ``a`` to ``z`` inclusive, by code point.

    The result is an empty list when ``a`` sorts after ``z``.
    """
    return [chr(code) for code in range(ord(a), ord(z) + 1)]
def pat_opts():
    """Return the characters used for the tail: a-z, A-Z, then 0-9."""
    return range_chrs('a', 'z') + range_chrs('A', 'Z') + range_chrs('0', '9')
def ascii_opts():
    """Return every character from '0' to 'z' inclusive.

    Besides letters and digits this includes the punctuation that lies
    between them in ASCII (for example ':', '=', '@', '[' and '_').
    """
    return range_chrs('0', 'z')
def gen_one(pref, tail):
    """Return one random test vector containing ``pref``.

    The result is 0-5 random characters from ``ascii_opts()``, then ``pref``,
    then ``tail`` random characters from ``pat_opts()``, then another 0-5
    random characters from ``ascii_opts()``.
    """
    parts = [
        ''.join(random.choices(ascii_opts(), k=random.randint(0, 5))),
        pref,
        ''.join(random.choices(pat_opts(), k=tail)),
        ''.join(random.choices(ascii_opts(), k=random.randint(0, 5))),
    ]
    return ''.join(parts)
import sys

# Only run when executed as a script, so the helpers above can be imported
# without consuming sys.argv.
if __name__ == '__main__':
    if len(sys.argv) < 3:
        # Fail with a usage line instead of a bare IndexError traceback.
        sys.exit('usage: {} PREFIX TAIL_LENGTH'.format(sys.argv[0]))
    pref = sys.argv[1]
    tail = int(sys.argv[2])
    print(gen_one(pref, tail))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Apparently, you can't pull-request to a gist, so see: https://gist.github.com/gofri/2ad0e25430bf89ea70614891bca5d35a/revisions
(account for the length difference between the ascii string and the encoded version)