Created
December 6, 2022 08:06
-
-
Save bsidhom/9d4c001d67b4df9cdb98b12a09992658 to your computer and use it in GitHub Desktop.
A toy CSV parser combinator from scratch
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
from abc import ABC, abstractmethod
from collections.abc import Iterable
from enum import Enum
from typing import Any, Callable, ForwardRef, Generic, TypeAlias, TypeVar, cast
class RecordSeparator(Enum):
    """Marker value that crlf() yields instead of the matched CRLF text."""

    VALUE = "RecordSeparator"
class Field:
    """Wrapper around a single field's text.

    NOTE(review): nothing in the code visible here constructs a Field; the
    parsers below return plain strings.
    """

    def __init__(self, value: str) -> None:
        self._value = value

    def value(self) -> str:
        """Return the wrapped field text."""
        return self._value

    def __str__(self) -> str:
        return "Field({})".format(self._value)
T = TypeVar("T") | |
U = TypeVar("U") | |
class Success(Generic[T]): | |
def __init__(self, result: T, rest: str) -> None: | |
self._result = result | |
self._rest = rest | |
def result(self) -> T: | |
return self._result | |
def rest(self) -> str: | |
return self._rest | |
def __str__(self) -> str: | |
return f"Success({self._result}, {repr(self._rest)})" | |
class Failure(): | |
def __init__(self, reason: str) -> None: | |
self._reason = reason | |
def reason(self) -> str: | |
return self._reason | |
def __str__(self) -> str: | |
return f"parse failed: {self._reason}" | |
ParseResult: TypeAlias = Success[T] | Failure | |
class Parser(ABC, Generic[T]):
    """Abstract base class for every parser in this module.

    A parser inspects the start of its input string and reports either a
    Success (produced value plus unconsumed rest) or a Failure.
    """

    @abstractmethod
    def parse(self, s: str) -> ParseResult[T]:
        """Parse a prefix of ``s``; concrete subclasses supply the logic."""

    def map(self, f: Callable[[T], U]) -> "Parser[U]":
        """Return a parser whose successful result is ``f`` applied to ours."""
        return pmap(self, f)

    def and_then(self, p: "Parser[U]") -> "Parser[tuple[T, U]]":
        """Return a parser that runs this one, then ``p``, pairing both results."""
        return and_then(self, p)
class _FuncParser(Parser[T]):
    """Parser adapter that delegates parse() to a plain ``str -> ParseResult`` callable."""

    def __init__(self, f: Callable[[str], ParseResult[T]]) -> None:
        self._func = f

    def parse(self, s: str) -> ParseResult[T]:
        return self._func(s)


def functional_parser(f: Callable[[str], ParseResult[T]]) -> Parser[T]:
    """Wrap the parsing function ``f`` in a Parser object.

    The returned parser's parse(s) simply returns f(s).  The adapter class is
    defined once at module level rather than per call: it is parameterized
    by the same ``T`` as ``f`` (the previous local class wrongly subclassed
    ``Parser[U]``), and a new class object is no longer created every time a
    parser is built.
    """
    return _FuncParser(f)
def literal(lit: str) -> Parser[str]:
    """Build a parser that matches the exact string ``lit`` at the input start.

    On a match the result is ``lit`` and the remaining input follows it;
    otherwise a Failure quotes both the input and the expected literal.
    """
    width = len(lit)

    def parse(s: str) -> ParseResult[str]:
        if not s.startswith(lit):
            return Failure(f"\"{s}\" did not match \"{lit}\"")
        return Success(s[:width], s[width:])
    return functional_parser(parse)
def pair(left: Parser[T], right: Parser[U]) -> Parser[tuple[T, U]]:
    """Sequence two parsers and return both results as a tuple.

    ``right`` runs on whatever ``left`` left unconsumed.  The first Failure
    encountered is returned unchanged.
    """
    def parse(s: str) -> ParseResult[tuple[T, U]]:
        first = left.parse(s)
        if not isinstance(first, Success):
            return first
        second = right.parse(first.rest())
        if not isinstance(second, Success):
            return second
        return Success((first.result(), second.result()), second.rest())
    return functional_parser(parse)
def pmap(p: Parser[T], f: Callable[[T], U]) -> Parser[U]:
    """Transform the value produced by ``p`` with ``f``.

    On success the value becomes ``f(value)`` and the unconsumed input is
    passed through untouched; a Failure from ``p`` is returned as-is.
    """
    def mapped(s: str) -> ParseResult[U]:
        outcome = p.parse(s)
        if not isinstance(outcome, Success):
            return cast(ParseResult[U], outcome)
        return Success(f(outcome.result()), outcome.rest())
    return functional_parser(mapped)
def and_then(a: Parser[T], b: Parser[U]) -> Parser[tuple[T, U]]:
    """Run ``a`` then ``b`` on the remaining input, pairing their results.

    This is exactly the ``pair`` combinator.  It previously carried a
    line-for-line copy of ``pair``'s logic; delegating keeps the two from
    drifting apart.  The first Failure encountered is returned unchanged.
    """
    return pair(a, b)
def left(left: Parser[T], right: Parser[Any]) -> Parser[T]:
    """Run ``left`` then ``right``, keeping only ``left``'s result.

    Both parsers must succeed; ``right``'s input is still consumed.
    """
    return pair(left, right).map(lambda both: both[0])
def right(left: Parser[Any], right: Parser[T]) -> Parser[T]:
    """Run ``left`` then ``right``, keeping only ``right``'s result.

    Both parsers must succeed; ``left``'s input is still consumed.
    """
    return pmap(pair(left, right), lambda both: both[1])
def alt(parsers: Iterable[Parser[T]]) -> Parser[T]:
    """Try each parser in order on the same input; return the first success.

    Args:
        parsers: the alternatives, in priority order.  Any iterable is
            accepted; it is read once, when the combinator is built.

    Returns:
        A parser yielding the first Success, or a Failure naming the input
        when no alternative matches (including when ``parsers`` is empty).
    """
    # Snapshot the alternatives.  Holding the caller's list by reference
    # meant later mutations silently changed this parser.  A one-shot
    # iterator (e.g. a generator) would be exhausted by the first parse()
    # call, so every later call would fail.
    choices = tuple(parsers)

    def parse(s: str) -> ParseResult[T]:
        for p in choices:
            result = p.parse(s)
            if isinstance(result, Success):
                return result
        return Failure(f"did not match any alternatives: \"{s}\"")
    return functional_parser(parse)
def zero_or_more(p: Parser[T]) -> Parser[list[T]]:
    """Apply ``p`` repeatedly, collecting each result into a list.

    Always succeeds: zero matches yields ``[]`` with the input untouched.
    Repetition stops at the first Failure from ``p``, or as soon as ``p``
    succeeds without consuming any input.  Without the second condition, a
    parser that can match the empty string (e.g. ``take_while0``) would loop
    forever.  That non-consuming match is not added to the result.
    """
    def parse(s: str) -> ParseResult[list[T]]:
        rest = s
        results: list[T] = []
        while True:
            r = p.parse(rest)
            if not isinstance(r, Success):
                break
            if len(r.rest()) >= len(rest):
                # No progress was made; re-running p would spin forever.
                break
            results.append(r.result())
            rest = r.rest()
        return Success(results, rest)
    return functional_parser(parse)
def take_while0(f: Callable[[str], bool]) -> Parser[str]:
    """Consume the longest prefix whose characters all satisfy ``f``.

    Never fails.  If the input is empty or its first character does not
    satisfy ``f``, the result is "" and the input is left untouched.
    """
    def parse(s: str) -> ParseResult[str]:
        end = len(s)
        for i, ch in enumerate(s):
            if not f(ch):
                end = i
                break
        return Success(s[:end], s[end:])
    return functional_parser(parse)
def take_while1(f: Callable[[str], bool]) -> Parser[str]:
    """Like take_while0, but fail unless at least one character satisfies ``f``."""
    def parse(s: str) -> ParseResult[str]:
        matched = 0
        for ch in s:
            if not f(ch):
                break
            matched += 1
        if matched == 0:
            return Failure(f"did not match at least one character: \"{s}\"")
        return Success(s[:matched], s[matched:])
    return functional_parser(parse)
def eof() -> Parser[str]:
    """Build a parser that succeeds with "" only when no input remains."""
    def parse(s: str) -> ParseResult[str]:
        if s:
            return Failure(f"expected EOF, got \"{s}\"")
        return Success("", "")
    return functional_parser(parse)
def csv() -> Parser[list[list[str]]]:
    """Build a parser for a whole CSV document: ``record *(CRLF record) [CRLF]``.

    The result is a list of records, each a list of field strings.  Because
    a record always matches (a field may be empty), this parser never fails.
    Input it cannot consume is left in ``Success.rest()``.
    """
    next_record = trailing_record()

    def parse_trailing_record(s: str) -> ParseResult[list[str]]:
        # A CRLF that ends the input is the optional file terminator, not a
        # separator.  Since a record may be a single empty field, it would
        # otherwise be read as the start of a phantom [""] record.
        if s == "\r\n":
            return Failure("final CRLF terminates the file")
        return next_record.parse(s)

    def parse_optional_final_crlf(s: str) -> ParseResult[str]:
        # Consume the terminator only when it is the very end of the input.
        if s == "\r\n":
            return Success(s, "")
        return Success("", s)

    records = pair(
        record(),
        zero_or_more(functional_parser(parse_trailing_record)),
    ).map(lambda x: [x[0]] + x[1])
    return left(records, functional_parser(parse_optional_final_crlf))
def record() -> Parser[list[str]]:
    """Parse one record: a field followed by any number of ``,field`` pieces."""
    def flatten(parts: tuple[str, list[str]]) -> list[str]:
        head, tail = parts
        return [head, *tail]
    return pmap(pair(field(), zero_or_more(trailing_field())), flatten)
def trailing_record() -> Parser[list[str]]:
    """Parse a CRLF followed by a record, returning only the record."""
    return pair(crlf(), record()).map(lambda sep_and_record: sep_and_record[1])
def trailing_field() -> Parser[str]:
    """Parse a comma followed by a field, returning only the field."""
    return pair(comma(), field()).map(lambda sep_and_field: sep_and_field[1])
def field() -> Parser[str]:
    """Parse one field, trying the quoted form before the unquoted one.

    Order matters.  non_escaped() can match the empty string, so trying it
    first would succeed with "" before a leading '"' ever reached escaped().
    """
    candidates = [escaped(), non_escaped()]
    return alt(candidates)
def escaped() -> Parser[str]:
    """Parse a double-quoted field and return its contents without the quotes.

    Between the quotes, any mix of TEXTDATA, commas, CR, LF and doubled
    quotes is accepted; each doubled quote ("") becomes a single ".
    """
    # Every alternative consumes at least one character, which is why the
    # text runs use take_while1 rather than take_while0.  So zero_or_more
    # always makes progress and cannot loop forever.  When the repetition
    # stops, the next character is not TEXTDATA, so a trailing take_while0
    # pass (formerly here) could only ever match "" and has been removed.
    inner_value = zero_or_more(
        alt([
            take_while1(is_valid_textdata),
            comma(),
            carriage_return(),
            line_feed(),
            quoted_dquote(),
        ])).map(lambda chunks: "".join(chunks))
    return left(right(dquote(), inner_value), dquote())
def non_escaped() -> Parser[str]:
    """Parse an unquoted field: the longest (possibly empty) run of TEXTDATA."""
    return take_while0(lambda ch: is_valid_textdata(ch))
def comma() -> Parser[str]:
    """Match a single comma, the separator between fields of a record."""
    separator = ","
    return literal(separator)
def carriage_return() -> Parser[str]:
    """Match a single carriage-return character."""
    cr = "\r"
    return literal(cr)
def dquote() -> Parser[str]:
    """Match a single double-quote character."""
    quote = '"'
    return literal(quote)
def quoted_dquote() -> Parser[str]:
    """Match an escaped quote ("") and yield the single character it stands for."""
    return pmap(pair(dquote(), dquote()), lambda _both: '"')
def line_feed() -> Parser[str]:
    """Match a single line-feed character."""
    lf = "\n"
    return literal(lf)
def crlf() -> Parser[RecordSeparator]:
    """Match the CR LF record separator and yield RecordSeparator.VALUE."""
    return pmap(pair(carriage_return(), line_feed()),
                lambda _both: RecordSeparator.VALUE)
def is_valid_textdata(c: str) -> bool:
    """Return True if the single character ``c`` is RFC 4180 TEXTDATA.

    TEXTDATA is printable ASCII (0x20-0x7E) except the double quote (0x22)
    and the comma (0x2C), which carry structural meaning in CSV.  Control
    characters and non-ASCII characters are rejected.
    """
    code = ord(c)
    return 0x20 <= code <= 0x7E and code not in (0x22, 0x2C)
def main(path: str = "/tmp/input.csv") -> None:
    """Read the CSV file at ``path`` and print the result of parsing it.

    ``path`` defaults to the previously hard-coded location, so existing
    ``main()`` callers are unaffected.
    """
    # newline="" disables universal-newline translation.  Without it, text
    # mode rewrites every CRLF to LF, crlf() can never match, and everything
    # after the first record is left unparsed in rest().  An explicit
    # encoding keeps decoding independent of the platform locale.
    with open(path, newline="", encoding="utf-8") as f:
        s = f.read()
    print(csv().parse(s))


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment