Skip to content

Instantly share code, notes, and snippets.

@bsidhom
Created December 6, 2022 08:06
Show Gist options
  • Save bsidhom/9d4c001d67b4df9cdb98b12a09992658 to your computer and use it in GitHub Desktop.
Save bsidhom/9d4c001d67b4df9cdb98b12a09992658 to your computer and use it in GitHub Desktop.
A toy CSV parser combinator from scratch
#!/usr/bin/env python3
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, Callable, ForwardRef, Generic, TypeAlias, TypeVar, cast
class RecordSeparator(Enum):
    """Sentinel value emitted by the ``crlf()`` parser when it consumes CR LF."""

    VALUE = "RecordSeparator"
class Field:
    """Small wrapper holding the text of one CSV field.

    NOTE(review): no visible caller; the record parsers in this module yield
    plain ``str`` values rather than ``Field`` instances.
    """

    def __init__(self, value: str) -> None:
        self._value = value

    def value(self):
        """Return the wrapped field text."""
        return self._value

    def __str__(self):
        wrapped = self._value
        return f"Field({wrapped})"
# Type parameters shared by the generic result classes and combinators below.
T, U = TypeVar("T"), TypeVar("U")
class Success(Generic[T]):
    """Outcome of a parser that matched: the produced value and leftover input."""

    def __init__(self, result: T, rest: str) -> None:
        self._result, self._rest = result, rest

    def result(self) -> T:
        """Return the value the parser produced."""
        return self._result

    def rest(self) -> str:
        """Return the input that was not consumed."""
        return self._rest

    def __str__(self) -> str:
        return f"Success({self._result}, {self._rest!r})"
class Failure:
    """Outcome of a parser that did not match, with a human-readable reason."""

    def __init__(self, reason: str) -> None:
        self._reason = reason

    def reason(self) -> str:
        """Return the explanation recorded when the parse failed."""
        return self._reason

    def __str__(self) -> str:
        return "parse failed: {}".format(self._reason)
# Outcome of running a parser: either a Success (produced value plus the
# unconsumed input) or a Failure (reason string). Generic in T, so it is
# subscripted as ParseResult[str], ParseResult[list[str]], etc.
ParseResult: TypeAlias = Success[T] | Failure
class Parser(ABC, Generic[T]):
    """Abstract base for all parsers: consume a prefix of a string, produce a T."""

    @abstractmethod
    def parse(self, s: str) -> ParseResult[T]:
        """Parse a prefix of ``s``, returning a Success or a Failure."""

    def map(self, f: Callable[[T], U]) -> "Parser[U]":
        """Return a parser that applies ``f`` to this parser's successful result."""
        return pmap(self, f)

    def and_then(self, p: "Parser[U]") -> "Parser[tuple[T, U]]":
        """Return a parser that runs this parser and then ``p``, pairing results."""
        return and_then(self, p)
class _FuncParser(Parser[T]):
    """Parser whose ``parse`` simply delegates to a wrapped function."""

    def __init__(self, f: Callable[[str], ParseResult[T]]) -> None:
        self._func = f

    def parse(self, s: str) -> ParseResult[T]:
        return self._func(s)


def functional_parser(f: Callable[[str], ParseResult[T]]) -> Parser[T]:
    """Lift a plain ``str -> ParseResult`` function into a Parser object.

    The adapter class is defined once at module level. Previously a fresh
    class was created on every call, and it was mis-parameterized as
    ``Parser[U]`` even though it produces ``T`` values.

    Args:
        f: function that parses a prefix of its argument.

    Returns:
        A Parser whose ``parse`` method calls ``f``.
    """
    return _FuncParser(f)
def literal(lit: str) -> Parser[str]:
    """Return a parser that matches the exact text ``lit`` at the start of input.

    On success the matched text is the result; otherwise a Failure quotes both
    the remaining input and ``lit``.
    """
    size = len(lit)

    def parse(s: str) -> ParseResult[str]:
        if not s.startswith(lit):
            return Failure(f"\"{s}\" did not match \"{lit}\"")
        return Success(s[:size], s[size:])

    return functional_parser(parse)
def pair(left: Parser[T], right: Parser[U]) -> Parser[tuple[T, U]]:
    """Sequence two parsers and return both results as a tuple.

    ``right`` runs on whatever ``left`` left unconsumed. The first Failure
    encountered is returned unchanged.
    """
    def parse(s: str) -> ParseResult[tuple[T, U]]:
        first = left.parse(s)
        if not isinstance(first, Success):
            return first
        second = right.parse(first.rest())
        if not isinstance(second, Success):
            return second
        return Success((first.result(), second.result()), second.rest())

    return functional_parser(parse)
def pmap(p: Parser[T], f: Callable[[T], U]) -> Parser[U]:
    """Return a parser that runs ``p`` and transforms its result with ``f``.

    ``f`` is only invoked on success; a Failure from ``p`` passes through as is.
    """
    def parse(s: str) -> ParseResult[U]:
        outcome = p.parse(s)
        if not isinstance(outcome, Success):
            return cast(ParseResult[U], outcome)
        return Success(f(outcome.result()), outcome.rest())

    return functional_parser(parse)
def and_then(a: Parser[T], b: Parser[U]) -> Parser[tuple[T, U]]:
    """Run ``a``, then ``b`` on the remainder, and pair their results.

    This is exactly the sequencing that ``pair`` implements. Delegate to it
    instead of maintaining a second, line-for-line copy of the same logic.
    The first Failure encountered is returned unchanged.
    """
    return pair(a, b)
def left(left: Parser[T], right: Parser[Any]) -> Parser[T]:
    """Run ``left`` then ``right``, keeping only ``left``'s result.

    Both parsers must succeed; ``right``'s input is still consumed.
    """
    return pair(left, right).map(lambda both: both[0])
def right(left: Parser[Any], right: Parser[T]) -> Parser[T]:
    """Run ``left`` then ``right``, keeping only ``right``'s result.

    Both parsers must succeed; ``left``'s input is still consumed.
    """
    return pmap(pair(left, right), lambda both: both[1])
def alt(parsers: list[Parser[T]]) -> Parser[T]:
    """Ordered choice: return the first successful result among ``parsers``.

    Each alternative is tried on the same input, left to right. If none
    succeeds (or the list is empty), a Failure quoting the input is returned.
    The individual alternatives' failure reasons are discarded.
    """
    def parse(s: str) -> ParseResult[T]:
        for candidate in parsers:
            outcome = candidate.parse(s)
            if isinstance(outcome, Success):
                break
        else:
            return Failure(f"did not match any alternatives: \"{s}\"")
        return outcome

    return functional_parser(parse)
def zero_or_more(p: Parser[T]) -> Parser[list[T]]:
    """Apply ``p`` repeatedly and collect its results until it stops matching.

    The returned parser always succeeds, possibly with an empty list, and
    leaves the input after the last consuming match as the remainder.

    Repetition also stops if ``p`` succeeds without consuming any input (for
    example ``take_while0``, or ``field()`` on a non-field character). Such a
    match is not collected. Continuing would loop forever, because the input
    would never shrink.
    """
    def parse(s: str) -> ParseResult[list[T]]:
        rest = s
        result: list[T] = []
        while True:
            r = p.parse(rest)
            if not isinstance(r, Success):
                return Success(result, rest)
            if len(r.rest()) >= len(rest):
                # No progress was made; stop instead of spinning forever.
                return Success(result, rest)
            result.append(r.result())
            rest = r.rest()

    return functional_parser(parse)
def take_while0(f: Callable[[str], bool]) -> Parser[str]:
    """Return a parser consuming the longest prefix whose characters satisfy ``f``.

    Never fails: if the first character is rejected (or the input is empty),
    the result is the empty string and nothing is consumed.
    """
    def parse(s: str) -> ParseResult[str]:
        end = len(s)
        for position, ch in enumerate(s):
            if not f(ch):
                end = position
                break
        return Success(s[:end], s[end:])

    return functional_parser(parse)
def take_while1(f: Callable[[str], bool]) -> Parser[str]:
    """Like ``take_while0``, but fail unless at least one character matches ``f``.

    Because a success always consumes input, this is safe to repeat inside
    ``zero_or_more``.
    """
    def parse(s: str) -> ParseResult[str]:
        count = 0
        for ch in s:
            if not f(ch):
                break
            count += 1
        if count:
            return Success(s[:count], s[count:])
        return Failure(f"did not match at least one character: \"{s}\"")

    return functional_parser(parse)
def eof() -> Parser[str]:
    """Return a parser that succeeds (with "") only when no input remains."""
    def parse(s: str) -> ParseResult[str]:
        if s:
            return Failure(f"expected EOF, got \"{s}\"")
        return Success("", "")

    return functional_parser(parse)
def csv() -> Parser[list[list[str]]]:
    """Return a parser for CSV text, producing a list of records (lists of fields).

    Grammar (RFC 4180): record *(CRLF record) [CRLF].

    The greedy ``zero_or_more(trailing_record())`` reads the optional final
    CRLF as one extra record holding a single empty field. When the input ends
    in CRLF and was fully consumed, that phantom record is dropped. A quoted
    empty field ("") at the very end is unaffected, because such input does
    not end in CRLF.

    Trailing input that does not fit the grammar is not an error; it is left
    in ``Success.rest()``.
    """
    records = pair(record(),
                   zero_or_more(trailing_record())).map(lambda x: [x[0]] + x[1])

    def parse(s: str) -> ParseResult[list[list[str]]]:
        outcome = records.parse(s)
        if not isinstance(outcome, Success):
            return outcome
        rows = outcome.result()
        if (len(rows) > 1 and rows[-1] == [""] and outcome.rest() == ""
                and s.endswith("\r\n")):
            # The last "record" is just the optional terminating CRLF.
            return Success(rows[:-1], "")
        return outcome

    return functional_parser(parse)
def record() -> Parser[list[str]]:
    """Return a parser for one record: a field, then any comma-prefixed fields."""
    def flatten(parts: tuple[str, list[str]]) -> list[str]:
        head, tail = parts
        return [head, *tail]

    return pair(field(), zero_or_more(trailing_field())).map(flatten)
def trailing_record() -> Parser[list[str]]:
    """Return a parser for CR LF followed by a record, keeping only the record."""
    separator = crlf()
    body = record()
    return right(separator, body)
def trailing_field() -> Parser[str]:
    """Return a parser for a comma followed by a field, keeping only the field."""
    separator = comma()
    value = field()
    return right(separator, value)
def field() -> Parser[str]:
    """Return a parser for one field, trying the quoted form before the bare form."""
    choices = [escaped(), non_escaped()]
    return alt(choices)
def escaped() -> Parser[str]:
    """Return a parser for a double-quoted field, yielding its unescaped contents.

    Between the surrounding quotes it accepts text data, commas, CR, LF and
    doubled quotes. Each doubled quote is collapsed to one quote character,
    and the surrounding quotes are not part of the result.
    """
    # Every alternative consumes at least one character (hence take_while1
    # rather than take_while0), so zero_or_more always makes progress.
    # Once the repetition stops, the next character is never text data, so no
    # extra take_while0 is needed to "finish off" the value.
    piece = alt([
        take_while1(is_valid_textdata),
        comma(),
        carriage_return(),
        line_feed(),
        quoted_dquote(),
    ])
    inner_value = zero_or_more(piece).map(lambda xs: "".join(xs))
    return left(right(dquote(), inner_value), dquote())
def non_escaped() -> Parser[str]:
    """Return a parser for an unquoted field: a possibly empty run of text data."""
    bare_text = take_while0(is_valid_textdata)
    return bare_text
def comma() -> Parser[str]:
    """Return a parser matching the field separator ``,``."""
    separator = ","
    return literal(separator)
def carriage_return() -> Parser[str]:
    """Return a parser matching a single carriage return (CR)."""
    cr = "\r"
    return literal(cr)
def dquote() -> Parser[str]:
    """Return a parser matching a single double-quote character."""
    quote = "\""
    return literal(quote)
def quoted_dquote() -> Parser[str]:
    """Return a parser for two consecutive quotes that yields one quote character.

    This is the escape for a literal quote inside a quoted field.
    """
    def collapse(_quotes: tuple[str, str]) -> str:
        return "\""

    return pair(dquote(), dquote()).map(collapse)
def line_feed() -> Parser[str]:
    """Return a parser matching a single line feed (LF)."""
    lf = "\n"
    return literal(lf)
def crlf() -> Parser[RecordSeparator]:
    """Return a parser for the CR LF record separator, yielding RecordSeparator.VALUE."""
    def to_separator(_matched: tuple[str, str]) -> RecordSeparator:
        return RecordSeparator.VALUE

    return pair(carriage_return(), line_feed()).map(to_separator)
def is_valid_textdata(c: str) -> bool:
    """Return whether character ``c`` is RFC 4180 TEXTDATA.

    TEXTDATA is printable ASCII (0x20-0x7E) minus the double quote (0x22) and
    the comma (0x2C). CR, LF, tabs and non-ASCII characters are rejected.
    """
    code = ord(c)
    return 0x20 <= code <= 0x7E and code not in (0x22, 0x2C)
def main(path: str = "/tmp/input.csv") -> None:
    """Parse the CSV file at ``path`` and print the parse result.

    The file is opened with ``newline=""`` so CR LF record separators reach
    the parser unchanged. In the default universal-newlines mode Python would
    translate them to a bare LF, ``crlf()`` would never match, and parsing
    would stop after the first record.

    Args:
        path: file to parse; defaults to the previously hard-coded location.
    """
    with open(path, newline="") as f:
        s = f.read()
    print(csv().parse(s))


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment