Skip to content

Instantly share code, notes, and snippets.

@seblin
Created December 14, 2023 16:38
Show Gist options
  • Save seblin/0a0825b9c7b70feb4ceb0f167f5f7dec to your computer and use it in GitHub Desktop.
Save seblin/0a0825b9c7b70feb4ceb0f167f5f7dec to your computer and use it in GitHub Desktop.
#!/usr/bin/env python3
import fnmatch
import os
import re
import sys
from collections.abc import Iterable, Iterator, Mapping, Sequence
from dataclasses import dataclass
from itertools import batched, zip_longest
from locale import getlocale, LC_COLLATE, setlocale, strxfrm
from math import ceil
from pathlib import Path
from pprint import pprint
from shutil import get_terminal_size
from typing import Any
@dataclass
class FrameConfig:
spacer: str = " "
line_width: int | None = None
adjuster: str = "<"
def __post_init__(self) -> None:
if not isinstance(self.spacer, str):
raise TypeError("spacer should be a string")
if self.line_width is not None and self.line_width <= 0:
raise ValueError("line_width should be > 0 or None")
if self.adjuster not in "<^>":
raise ValueError("adjuster should be '<', '^' or '>'")
@dataclass
class Column(Iterable):
    """A column of strings sharing one width, optionally capped by max_width.

    Iterating yields, for each string, the list of lines it wraps into.
    """

    strings: Sequence[str]
    # Upper bound for the column width; None means no cap.
    max_width: int | None = None
    # Alignment character placed into the format template ("<", "^" or ">").
    adjuster: str = FrameConfig.adjuster

    def __len__(self) -> int:
        """Return the column width: the longest string, capped by max_width.

        The result is never negative, as required by len().
        """
        longest_width = max(map(len, self.strings), default=0)
        if self.max_width is None:
            return longest_width
        # Clamp so a negative max_width cannot make len() raise.
        return max(min(longest_width, self.max_width), 0)

    def __iter__(self) -> Iterator[Sequence[str]]:
        """Yield each string split into chunks of at most the column width.

        Every string yields at least one line, so an empty string renders as
        a blank cell instead of vanishing. A zero-width column (for example,
        one holding only empty strings) no longer crashes on range(step=0).
        """
        width = len(self)
        for string in self.strings:
            if width > 0:
                chunks = [
                    string[start:start + width]
                    for start in range(0, len(string), width)
                ]
            else:
                chunks = []
            yield chunks or [""]

    def get_template(self) -> str:
        """Return a str.format field that pads and truncates to the column width."""
        # e.g. "{:<7.7}": the width (7) pads, the precision (.7) truncates.
        return "{{:{0}{1}.{1}}}".format(
            self.adjuster, len(self)
        )
@dataclass
class Frame(Iterable):
    """Columns rendered side by side, separated by ``spacer``.

    Iterating yields the rendered lines; str() joins them with newlines.
    """

    columns: Sequence[Column]
    spacer: str = FrameConfig.spacer

    def __len__(self) -> int:
        """Return the rendered width: column widths plus the spacers between them."""
        # n columns have n - 1 gaps. Clamp at 0 so an empty frame has width 0
        # rather than -len(spacer).
        gap_count = max(len(self.columns) - 1, 0)
        return sum(map(len, self.columns)) + len(self.spacer) * gap_count

    def __iter__(self) -> Iterator[str]:
        """Yield each rendered line with trailing whitespace stripped."""
        format_line = self.get_line_template().format
        for segments in self.get_line_segments():
            yield format_line(*segments).rstrip()

    def __str__(self) -> str:
        return "\n".join(self)

    def get_line_template(self) -> str:
        """Return a str.format template with one field per column, joined by spacer."""
        return self.spacer.join(col.get_template() for col in self.columns)

    def get_line_segments(self) -> Iterator[Sequence[str]]:
        """Yield, per output line, one text segment for each column.

        Items at the same index across columns form a row. A row spans as
        many lines as its most-wrapped item. Missing pieces, and items absent
        from shorter columns, are filled with "".
        """
        for wrapped_items in zip_longest(*self.columns, fillvalue=""):
            yield from zip_longest(*wrapped_items, fillvalue="")
class FrameFactory(FrameConfig):
    """Build Frame objects that fit within the configured line width."""

    def query_line_width(self) -> int:
        """Return line_width if it is set, otherwise the current terminal width."""
        if self.line_width is None:
            return get_terminal_size().columns
        return self.line_width

    def frame_for_nested(
        self, nested_strings: Iterable[Sequence[str]]
    ) -> Frame:
        """Create one Column per inner sequence and combine them into a Frame."""
        columns = [
            Column(strings, adjuster=self.adjuster)
            for strings in nested_strings
        ]
        return Frame(columns, self.spacer)

    def frame_for_mapping(
        self, mapping: Mapping[str, str]
    ) -> Frame:
        """Lay out a mapping as a key column next to a value column.

        When the frame is too wide, the value column is narrowed first and
        then the key column.
        """
        frame = self.frame_for_nested(
            [list(mapping.keys()), list(mapping.values())]
        )
        return self._get_shrinked(frame, reversed(frame.columns))

    def frame_for_sequence(
        self, strings: Sequence[str], max_columns: int
    ) -> Frame:
        """Split strings into at most max_columns consecutive, equally sized chunks.

        An empty sequence gives a frame without columns. A max_columns below
        1 is treated as 1.
        """
        if not strings:
            return self.frame_for_nested([])
        # Guard the divisor: max_columns == 0 would raise ZeroDivisionError.
        # A non-empty sequence then always gives chunk_size >= 1, which
        # batched() requires.
        chunk_size = ceil(len(strings) / max(max_columns, 1))
        return self.frame_for_nested(batched(strings, chunk_size))

    def find_fitting_frame(
        self, strings: Sequence[str]
    ) -> Frame:
        """Return the frame with the most columns that fits the line width.

        Falls back to a single column when nothing narrower fits.
        """
        if not strings:
            return self.frame_for_nested([])
        line_width = self.query_line_width()
        # At least 1 avoids dividing by zero when both the shortest string
        # and the spacer are empty.
        shortest_width = max(min(map(len, strings)) + len(self.spacer), 1)
        # Start from an upper bound on how many columns could fit, then drop
        # one column at a time. frame_for_sequence clamps this bound to >= 1
        # when even the shortest string is wider than the line.
        frame = self.frame_for_sequence(strings, line_width // shortest_width)
        while len(frame.columns) > 1 and len(frame) > line_width:
            frame = self.frame_for_sequence(strings, len(frame.columns) - 1)
        return frame

    def _get_shrinked(
        self, frame: Frame,
        allowed_columns: Iterable[Column],
        min_column_width: int = 1
    ) -> Frame:
        """Narrow columns until the frame fits the line width.

        Columns are visited in the order given by allowed_columns. Each one
        gives up as much width as is still needed, but never drops below
        min_column_width. The columns' max_width is modified in place, and
        the same frame object is returned.
        """
        offset = len(frame) - self.query_line_width()
        for column in allowed_columns:
            if offset <= 0:
                break
            old_width = len(column)
            column.max_width = max(old_width - offset, min_column_width)
            offset -= old_width - column.max_width
        return frame
def by_pattern(
strings: Mapping[str, str] | Iterable[str],
pattern: str,
ignore_case: bool = True
) -> Mapping[str, str] | Iterable[str]:
if isinstance(strings, Mapping):
keys = by_pattern(strings.keys(), pattern, ignore_case)
return {key: strings[key] for key in keys}
flags = re.IGNORECASE if ignore_case else 0
pattern = re.compile(fnmatch.translate(pattern), flags)
return filter(pattern.match, strings)
def get_sorted(
strings: Mapping[str, str] | Iterable[str]
) -> Mapping[str, str] | Iterable[str]:
if isinstance(strings, Mapping):
keys = get_sorted(strings.keys())
return {key: strings[key] for key in keys}
language, encoding = getlocale(LC_COLLATE)
if not language and not encoding:
setlocale(LC_COLLATE, "")
try:
sorted_strings = sorted(strings, key=strxfrm)
finally:
if not language and not encoding:
setlocale(LC_COLLATE, "C")
return sorted_strings
def preprocess(
values: Path | Mapping[Any, Any] | Iterable[Any],
pattern: str | None = None,
sort_strings: bool = False
) -> Mapping[str, str] | Sequence[str]:
if isinstance(values, Path):
values = values.iterdir()
if isinstance(values, Mapping):
strings = {str(k): str(v) for k, v in values.items()}
else:
strings = map(str, values)
if pattern is not None:
strings = by_pattern(strings, pattern)
if sort_strings:
strings = get_sorted(strings)
if not isinstance(strings, (Mapping, Sequence)):
strings = list(strings)
return strings
def columnize(
    values: Path | Mapping[Any, Any] | Iterable[Any],
    spacer: str = FrameConfig.spacer,
    line_width: int | None = FrameConfig.line_width,
    adjuster: str = FrameConfig.adjuster,
    pattern: str | None = None,
    sort_strings: bool = False
) -> Frame:
    """Lay out values as a Frame that fits the line width.

    Args:
        values: a directory Path (its entries are listed), a Mapping
            (rendered as a key column next to a value column), or any other
            iterable (its items are converted to strings and spread over as
            many columns as fit).
        spacer, line_width, adjuster: layout settings; see FrameConfig.
            line_width None means the terminal width.
        pattern: optional case-insensitive shell-style wildcard. Only
            matching strings (or mapping keys) are kept.
        sort_strings: sort the strings (or mapping keys) with locale-aware
            collation.

    Raises:
        TypeError, ValueError: from FrameConfig validation of the settings.
    """
    strings = preprocess(values, pattern, sort_strings)
    factory = FrameFactory(spacer, line_width, adjuster)
    if isinstance(strings, Mapping):
        return factory.frame_for_mapping(strings)
    return factory.find_fitting_frame(strings)
def cprint(
    values: Path | Mapping[Any, Any] | Iterable[Any],
    output_stream: Any = None,
    **params: Any
) -> None:
    """Columnize values and print the resulting frame.

    Args:
        values: anything accepted by columnize().
        output_stream: file-like object to write to. None (the default)
            means the sys.stdout that is current at call time.
        **params: keyword arguments forwarded to columnize().
    """
    # Passing file=None lets print() look up sys.stdout at call time.
    # A `sys.stdout` default would be frozen at import time, so later
    # redirections (e.g. contextlib.redirect_stdout) would be ignored.
    print(columnize(values, **params), file=output_stream)
def run_tests():
    """Demo: print the environment variables as a sorted two-column table.

    The frame is printed first, followed by its total width. Other examples
    worth trying:

        cprint(range(50), sys.stderr)
        cprint(Path(), line_width=60, adjuster=">")
        pprint(environment_frame.columns)   # inspect the column objects
    """
    environment_frame = columnize(os.environ, spacer=" | ", sort_strings=True)
    print(environment_frame)
    print("Frame width:", len(environment_frame))
if __name__ == "__main__":
    # Running the file as a script executes the demo in run_tests().
    run_tests()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment