Last active: September 13, 2023 12:04
Save agoose77/0e9b12a1c04afe61bce1c1c96ec5e3a1 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import annotations | |
import uproot.reading | |
import asyncio | |
import struct | |
import urllib.request | |
from collections.abc import Callable | |
from typing import Protocol, Final, NamedTuple, TypeVar, TypeAlias, Any | |
import aiohttp | |
# Number of bytes requested by the first read when opening a file. It must be
# at least FILE_HEADER_FIELDS_HEAD.size + FILE_HEADER_FIELDS_TAIL_BIG.size
# (12 + 63 bytes) so the whole fixed header fits in that first chunk.
BEGIN_CHUNK_SIZE = 403
# ROOT file headers are big-endian. Every layout below uses the ">" prefix;
# without it struct would use native byte order AND insert native alignment
# padding (e.g. after the single-byte "B" field), decoding wrong values.
# Head: magic (4 bytes), fVersion, fBEGIN.
FILE_HEADER_FIELDS_HEAD = struct.Struct(">4sii")
# Tail for fVersion < 1000000 (32-bit seek offsets):
# fEND, fSeekFree, fNbytesFree, nfree, fNbytesName, fUnits, fCompress,
# fSeekInfo, fNbytesInfo, fUUID_version, fUUID.
FILE_HEADER_FIELDS_TAIL_SMALL = struct.Struct(">iiiiiBiiiH16s")
# Same fields for fVersion >= 1000000, with 64-bit fEND/fSeekFree/fSeekInfo.
FILE_HEADER_FIELDS_TAIL_BIG = struct.Struct(">qqiiiBiqiH16s")
################### State machines ############################################ | |
Event: TypeAlias = Any | |
TransitionEvent = TypeVar("TransitionEvent", bound=Event) | |
State: TypeAlias = """Callable[[TransitionEvent], Event]""" | |
TransitionDict: TypeAlias = ( | |
"dict[tuple[State[Any], type[TransitionEvent]], State[TransitionEvent]]" | |
) | |
class StateMachine:
    """Minimal table-driven state machine.

    Subclasses provide ``transitions``, a mapping from
    ``(current_state, event_type)`` to the next state. States are plain
    functions (stored unbound in the table) that are bound to the instance
    when invoked; each receives the triggering event and returns the next
    outgoing event.
    """

    # Transition table supplied by subclasses; the key ``None`` stands for
    # "no state entered yet".
    transitions: TransitionDict
    # State most recently entered; ``None`` before the first event.
    current_state: State | None = None

    def handle_incoming_event(self, event: Event) -> Event:
        """Move to the state selected by ``event`` and run it.

        :param event: incoming event; its exact type (``type(event)``, not
            its base classes) selects the transition.
        :returns: the event returned by the newly entered state.
        :raises KeyError: if no transition exists for the current state and
            the event's type; ``current_state`` is left unchanged.
        """
        key = (self.current_state, type(event))
        try:
            next_state = self.transitions[key]
        except KeyError:
            state_name = getattr(self.current_state, "__name__", self.current_state)
            raise KeyError(
                f"{type(self).__name__}: no transition from state "
                f"{state_name!r} on event {type(event).__name__}"
            ) from None
        self.current_state = next_state
        # The table holds plain functions; bind to ``self`` like a method call.
        return next_state.__get__(self)(event)
################### Requests and Responses #################################### | |
class ReadHandleRequest(NamedTuple): | |
handle: Any | |
start: int | |
stop: int | |
class ReadURLRequest(NamedTuple): | |
url: str | |
start: int | |
stop: int | |
class ReadResponse(NamedTuple): | |
payload: bytes | |
class TerminateRequest(NamedTuple): | |
result: Any | |
class InitResponse:
    """Marker event fed to a state machine to start it; carries no data."""
class ByteRangeSource(Protocol):
    """Something that can describe a read of a half-open byte range.

    The implementations in this file perform no I/O themselves; they return
    a request object for an event loop to service.
    """

    def read(self, start: int, stop: int) -> Any:
        """Return a request for bytes ``[start, stop)``."""
################### IO Sources ################################################ | |
class HTTPSource(ByteRangeSource):
    """Byte-range source backed by a URL; reads become ``ReadURLRequest`` objects."""

    def __init__(self, url):
        # Remembered so every request targets the same resource.
        self._url = url

    def read(self, start: int, stop: int) -> ReadURLRequest:
        """Describe (without performing) a fetch of bytes ``[start, stop)``."""
        request = ReadURLRequest(url=self._url, start=start, stop=stop)
        return request
class FileSource(ByteRangeSource):
    """Byte-range source backed by an open file-like object."""

    def __init__(self, handle):
        # The event loop, not this class, performs the seek/read on it.
        self._handle = handle

    def read(self, start: int, stop: int) -> ReadHandleRequest:
        """Describe (without performing) a read of bytes ``[start, stop)``."""
        request = ReadHandleRequest(handle=self._handle, start=start, stop=stop)
        return request
################### TFile reader ############################################## | |
class TFileReader(StateMachine):
    """State machine that reads and parses the fixed header of a ROOT file.

    It asks ``source`` for the first chunk of the file and finishes with a
    ``TerminateRequest`` carrying an ``uproot.reading.ReadOnlyFile`` created
    without ``__init__`` whose header attributes have been filled in.
    """

    def __init__(self, source: ByteRangeSource):
        self._source = source

    def state_open_phase_1_init(self, _) -> Event:
        """Request the first ``BEGIN_CHUNK_SIZE`` bytes of the file."""
        return self._source.read(0, BEGIN_CHUNK_SIZE)

    def state_open_phase_2_parse_header(self, response: ReadResponse) -> Event:
        """Decode the file header from the first chunk.

        :param response: payload answering the phase-1 read.
        :returns: ``TerminateRequest`` wrapping the populated ``ReadOnlyFile``.
        :raises ValueError: if the payload does not start with ``b"root"`` or
            is too short to contain the complete header.
        """
        data = memoryview(response.payload)
        head_size = FILE_HEADER_FIELDS_HEAD.size

        # Validate the magic before decoding anything else, so a non-ROOT or
        # truncated payload fails with a clear ValueError rather than a
        # struct.error or a header full of garbage.
        if len(data) < head_size:
            raise ValueError(
                f"not a ROOT file: only {len(data)} bytes available, "
                f"need at least {head_size} for the header"
            )
        magic, version, begin = FILE_HEADER_FIELDS_HEAD.unpack_from(data)
        if magic != b"root":
            raise ValueError(f"""not a ROOT file: first four bytes are {magic!r}""")

        # fVersion >= 1000000 selects the layout with 64-bit offsets.
        is_64_bit = version >= 1000000
        tail_fields = (
            FILE_HEADER_FIELDS_TAIL_BIG if is_64_bit else FILE_HEADER_FIELDS_TAIL_SMALL
        )
        if len(data) < head_size + tail_fields.size:
            raise ValueError(
                f"truncated ROOT file header: got {len(data)} bytes, "
                f"need {head_size + tail_fields.size}"
            )

        # Build the file object without running its __init__; only the header
        # attributes decoded here are set.
        tfile = uproot.reading.ReadOnlyFile.__new__(uproot.reading.ReadOnlyFile)
        tfile._file_path = "<BOGUS PATH>"
        tfile._fVersion = version
        tfile._fBEGIN = begin
        (
            tfile._fEND,
            tfile._fSeekFree,
            tfile._fNbytesFree,
            tfile._nfree,
            tfile._fNbytesName,
            tfile._fUnits,
            tfile._fCompress,
            tfile._fSeekInfo,
            tfile._fNbytesInfo,
            tfile._fUUID_version,
            tfile._fUUID,
        ) = tail_fields.unpack_from(data, head_size)
        return TerminateRequest(tfile)

    # InitResponse kicks off the first read; its ReadResponse is then parsed.
    transitions: Final[TransitionDict] = {
        (None, InitResponse): state_open_phase_1_init,
        (state_open_phase_1_init, ReadResponse): state_open_phase_2_parse_header,
    }
################### Event loops ############################################### | |
def sync_loop(fsm: StateMachine, response): | |
while True: | |
request = fsm.handle_incoming_event(response) | |
match request: | |
# File-like request | |
case ReadHandleRequest(handle, start, stop): | |
handle.seek(start) | |
data = handle.read(stop - start) | |
response = ReadResponse(data) | |
# URL request | |
case ReadURLRequest(url, start, stop): | |
url_request = urllib.request.Request(url) | |
url_request.add_header("Range", f"bytes={start}-{stop-1}") | |
url_response = urllib.request.urlopen(url_request) | |
response = ReadResponse(url_response.read()) | |
# Exit request | |
case TerminateRequest(result): | |
return result | |
case _: | |
raise TypeError | |
def _read_handle_blocking(handle, start, stop): | |
handle.seek(start) | |
return handle.read(stop - start) | |
async def async_loop(fsm: StateMachine, response): | |
async with aiohttp.ClientSession() as session: | |
while True: | |
request = fsm.handle_incoming_event(response) | |
match request: | |
# File-like request | |
case ReadHandleRequest(handle, start, stop): | |
data = await asyncio.to_thread( | |
_read_handle_blocking, handle, start, stop | |
) | |
response = ReadResponse(data) | |
# URL request | |
case ReadURLRequest(url, start, stop): | |
async with session.get( | |
url, headers={"Range": f"bytes={start}-{stop-1}"} | |
) as url_response: | |
response = ReadResponse(await url_response.read()) | |
# Exit request | |
case TerminateRequest(result): | |
return result | |
case _: | |
raise TypeError | |
def test_sync():
    """Open a ROOT file from local disk and over HTTP using ``sync_loop``.

    Manual smoke test: needs a file at a developer-specific path and network
    access; prints the two resulting objects.
    """
    # Close the local file once its header has been read.
    with open("/home/angus/Downloads/SMHiggsToZZTo4L(1).root", "rb") as f:
        reader = TFileReader(FileSource(f))
        response = sync_loop(reader, InitResponse())
    reader = TFileReader(
        HTTPSource(
            "https://github.com/jpivarski-talks/2022-09-12-pyhep22-awkward-combinatorics/raw/main/data/SMHiggsToZZTo4L.root"
        )
    )
    response_2 = sync_loop(reader, InitResponse())
    print(response, response_2)
def test_async():
    """Open a ROOT file from disk and over HTTP concurrently with ``async_loop``.

    Manual smoke test: needs a file at a developer-specific path and network
    access; prints the two resulting objects.
    """

    async def loop():
        # The file must stay open until gather() finishes, because the file
        # reader's requests are serviced during the await below.
        with open("/home/angus/Downloads/SMHiggsToZZTo4L(1).root", "rb") as f:
            reader = TFileReader(FileSource(f))
            response_coro = async_loop(reader, InitResponse())
            reader = TFileReader(
                HTTPSource(
                    "https://github.com/jpivarski-talks/2022-09-12-pyhep22-awkward-combinatorics/raw/main/data/SMHiggsToZZTo4L.root"
                )
            )
            response_2_coro = async_loop(reader, InitResponse())
            response, response_2 = await asyncio.gather(response_coro, response_2_coro)
        print(response, response_2)

    asyncio.run(loop())
if __name__ == "__main__":
    # Run the blocking smoke test ("1"), then the asyncio one ("2").
    for step in ("1", "2"):
        print(step)
        (test_sync if step == "1" else test_async)()
    print("Done")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.