Created
November 20, 2023 19:43
-
-
Save pkit/0d1ef78a93ef1bc16c4e88d8547b6d72 to your computer and use it in GitHub Desktop.
The most basic URL server for ClickHouse. Can be used with the `URL()` engine: `URL('http://localhost:8555/', JSONEachRow)`
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import contextlib | |
import io | |
import json | |
import logging | |
import os | |
import socket | |
import sys | |
from functools import partial | |
from http.server import SimpleHTTPRequestHandler, ThreadingHTTPServer | |
from typing import Any, Iterator | |
# TCP port the server listens on; the PORT environment variable overrides
# the default of 8555.
PORT = int(os.getenv("PORT", "8555"))
def _get_best_family(*address): | |
infos = socket.getaddrinfo( | |
*address, | |
type=socket.SOCK_STREAM, | |
flags=socket.AI_PASSIVE, | |
) | |
family, type, proto, canonname, sockaddr = next(iter(infos)) | |
return family, sockaddr | |
class ReqHandler(SimpleHTTPRequestHandler):
    """Request handler that prints rows POSTed as JSONEachRow.

    The POST body may be sent with a ``Content-Length`` header or with
    ``Transfer-Encoding: chunked``. Each non-empty line of the body is parsed
    as one JSON document and printed to stdout.

    NOTE(review): every other HTTP method is inherited from
    ``SimpleHTTPRequestHandler``, whose GET/HEAD serve files from the working
    directory -- confirm that this is intended before exposing the port.
    """

    def do_POST(self) -> None:
        """Print every row found in the request body, then reply 200."""
        if self.has_rows():
            for row in self.get_rows():
                print(row)

        self.response(b"")  # it's empty, ClickHouse ignores POST response data

    def has_content_length(self) -> bool:
        """Return True when the request declares a positive Content-Length."""
        return int(self.headers.get("Content-Length", 0)) > 0

    def _is_chunked(self) -> bool:
        """Return True when the body uses the chunked transfer coding."""
        # Transfer-coding names are case-insensitive, so compare in lowercase.
        return "chunked" in self.headers.get("Transfer-Encoding", "").lower()

    def has_rows(self) -> bool:
        """Return True when the request carries a body that may hold rows."""
        return self.has_content_length() or self._is_chunked()

    def gen_chunks(self) -> Iterator[bytes]:
        """Yield the raw request body piece by piece.

        Yields:
            The whole body in one piece when ``Content-Length`` is positive;
            otherwise one piece per chunk for a chunked body; otherwise a
            single empty ``bytes`` object.

        Raises:
            ValueError: If a header or a chunk-size line is not a valid number.
        """
        content_len = int(self.headers.get("Content-Length", 0))

        if content_len > 0:
            yield self.rfile.read(content_len)
        elif self._is_chunked():
            while True:
                size_line = self.rfile.readline()
                # A chunk-size line may carry extensions after ';' -- only the
                # hexadecimal size in front of them matters here.
                chunk_length = int(size_line.split(b";", 1)[0].strip(), 16)

                if chunk_length == 0:
                    # Consume optional trailer fields up to the terminating
                    # blank line so they do not leak into the next request on
                    # a keep-alive connection.
                    while self.rfile.readline().strip():
                        pass
                    return

                yield self.rfile.read(chunk_length)
                self.rfile.readline()  # CRLF that terminates the chunk data
        else:
            yield b""

    def get_rows(self, fmt: str = "JSONEachRow") -> Iterator[dict[str, Any]]:
        """Return an iterator over the rows of the request body.

        Args:
            fmt: Name of the body format. Only ``"JSONEachRow"`` is supported.

        Raises:
            ValueError: If ``fmt`` is not a supported format.
        """
        if fmt == "JSONEachRow":
            return self.get_rows_ndjson()
        else:
            raise ValueError(f"Unknown format: {fmt}")

    def get_rows_ndjson(self) -> Iterator[dict[str, Any]]:
        """Yield one parsed JSON document per non-empty line of the body.

        The body is buffered as bytes and split on newlines, so a row (or a
        multi-byte UTF-8 character) that straddles two chunks is reassembled
        before it is decoded. Blank lines are skipped.

        Raises:
            json.JSONDecodeError: If a line is not valid JSON.
            UnicodeDecodeError: If a line is not valid UTF-8.
        """
        pending = b""

        for chunk in self.gen_chunks():
            pending += chunk
            # Everything before the last newline is complete; the remainder
            # is kept until the next chunk (or the end of the body) arrives.
            *complete, pending = pending.split(b"\n")
            for raw in complete:
                raw = raw.strip()
                if raw:
                    yield json.loads(raw.decode())

        # The last row is not required to end with a newline.
        tail = pending.strip()
        if tail:
            yield json.loads(tail.decode())

    def response(self, body: bytes) -> None:
        """Send a 200 response with ``body`` as its JSON-typed payload."""
        self.send_response(200)
        self.send_header("Content-Type", "application/json")
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)
def main() -> None:
    """Start the threaded HTTP server on ``PORT`` and serve until interrupted.

    The listening address comes from ``_get_best_family(None, PORT)``.
    Requests are handled by ``ReqHandler`` speaking HTTP/1.1. Ctrl-C prints a
    message and exits with status 0.
    """
    logging.basicConfig(level=logging.INFO)

    family, bind_addr = _get_best_family(None, PORT)

    class Server(ThreadingHTTPServer):
        address_family = family

        def server_bind(self) -> None:
            # suppress exception when protocol is IPv4
            with contextlib.suppress(Exception):
                self.socket.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 0)
            return super().server_bind()

    ReqHandler.protocol_version = "HTTP/1.1"

    with Server(bind_addr, partial(ReqHandler)) as httpd:
        host, port = httpd.socket.getsockname()[:2]
        # IPv6 literals must be wrapped in brackets inside a URL.
        if ":" in host:
            url_host = f"[{host}]"
        else:
            url_host = host
        logging.info(f"Serving HTTP on {host} port {port} (http://{url_host}:{port}/) ...")

        try:
            httpd.serve_forever()
        except KeyboardInterrupt:
            print("\nKeyboard interrupt received, exiting.")
            sys.exit(0)


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment