Forked from stashlukj/bluesky_datagram_protocol.py
Last active June 10, 2020 12:06
Save danielballan/078afb1243dc0b895c163e05f337f26c to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from asyncio import DatagramProtocol, gather, get_event_loop, sleep | |
from collections import deque | |
from functools import partial | |
from json import loads, dumps | |
from random import choices, random | |
from ophyd import Device | |
from ophyd.status import DeviceStatus, Status | |
# UDP ports on localhost: the client endpoint (DeviceManager.run) binds
# CLIENT_PORT, and the mock server endpoint binds SERVER_PORT.
CLIENT_PORT, SERVER_PORT = 9870, 9871
class ServerProtocol(DatagramProtocol):
    """Mock UDP server simulating ``N`` devices that appear on the network.

    Once a second it sends one JSON datagram per device to ``CLIENT_PORT`` on
    localhost: ``{"dev": <id>}`` while the device is idle, or
    ``{"dev": <id>, "value": <float>}`` while it is streaming. Clients toggle
    streaming by sending ``{"dev": <id>, "state": <bool>}``.
    """

    def __init__(self, N):
        """Create ``N`` idle devices with distinct random ids in ``1..99``.

        Raises ValueError if ``N`` is larger than the 99 available ids.
        """
        # Local import: only needed here.
        from random import sample

        self.transport = None
        self._emit_task = None
        super().__init__()
        # ``sample`` draws without replacement, so exactly N distinct devices
        # exist (``choices`` may repeat an id and silently yield fewer).
        self.streaming = {key: False for key in sample(range(1, 100), k=N)}

    async def emit(self):
        """Send every device's current state once per second, forever."""
        while True:
            for dev, on in self.streaming.items():
                m = {"dev": dev}
                if on:
                    m["value"] = random()
                self.transport.sendto(dumps(m).encode(), ("localhost", CLIENT_PORT))
            await sleep(1)

    def connection_made(self, transport):
        """Store the transport and start the periodic broadcast task."""
        self.transport = transport
        # Keep a reference: the event loop only holds weak references to
        # tasks, so an unreferenced task may be garbage-collected mid-run.
        self._emit_task = get_event_loop().create_task(self.emit())

    def connection_lost(self, exc):
        """Stop broadcasting once the transport has been closed."""
        if self._emit_task is not None:
            self._emit_task.cancel()
            self._emit_task = None

    def datagram_received(self, data, addr):
        """Apply a ``{"dev": <id>, "state": <bool>}`` request from a client.

        Requests for unknown device ids are ignored. Datagrams that are not
        JSON objects carrying both ``dev`` and ``state`` are reported and
        dropped rather than raising inside the event loop.
        """
        try:
            m = loads(data.decode())
            dev, state = m["dev"], m["state"]
            known = dev in self.streaming
        except (ValueError, KeyError, TypeError):
            # ValueError covers UnicodeDecodeError and JSONDecodeError;
            # TypeError covers non-dict payloads and unhashable ids.
            print("ignoring malformed datagram from", addr)
            return
        if known:
            self.streaming[dev] = bool(state)
class ClientProtocol(DatagramProtocol):
    """UDP client that routes server datagrams to per-device ``MockDevice``s.

    ``interpreter`` is any object with a ``devs`` dict (a ``DeviceManager`` in
    this script). A ``MockDevice`` is created and registered in that dict the
    first time a device id is seen.
    """

    def __init__(self, interpreter):
        self.i = interpreter
        self.transport = None
        super().__init__()

    def connection_made(self, transport):
        """Remember the transport so devices can send requests through it."""
        self.transport = transport

    def datagram_received(self, data, addr):
        """Decode a JSON datagram and forward it to the device it names.

        Datagrams that are not JSON objects carrying a hashable ``dev`` key
        are reported and dropped rather than raising inside the event loop.
        """
        try:
            m = loads(data.decode())
            dev = m["dev"]
            device = self.i.devs.get(dev)
        except (ValueError, KeyError, TypeError):
            # ValueError covers UnicodeDecodeError and JSONDecodeError;
            # TypeError covers non-dict payloads and unhashable ids.
            print("ignoring malformed datagram from", addr)
            return
        if device is None:
            # Give each device its own ophyd name; a single shared name
            # ("mock") makes the devices indistinguishable by name.
            device = MockDevice("mock{}".format(dev), dev, self)
            self.i.devs[dev] = device
        device.on_message(m)
class MockDevice(Device):
    """ophyd flyer stand-in for one device simulated by ``ServerProtocol``.

    ``kickoff``/``complete`` ask the server to start/stop streaming and return
    statuses that ``on_message`` finishes once the server's datagrams confirm
    the change. Streamed values are buffered for ``collect``.
    """

    # Seconds a kickoff/complete status may stay pending before it fails.
    state_timeout = 5

    def __init__(self, name, dev, protocol, **kwargs):
        """
        Parameters
        ----------
        name : str
            ophyd name for this device.
        dev
            Device id used in datagrams exchanged with the server.
        protocol
            Object whose ``transport`` sends requests to the server (a
            ``ClientProtocol`` in this script).
        """
        self.dev = dev
        self.protocol = protocol
        self._data = deque()
        # Placeholders until the first kickoff()/complete(). A bare Status()
        # starts unfinished; the done= argument is deprecated in newer ophyd.
        self._stopping = Status()
        self._starting = Status()
        super().__init__(name=name, **kwargs)

    def on_message(self, message):
        """Handle one datagram from the server addressed to this device.

        A message with ``value`` means the device is streaming. The start
        status is finished (the first time only) and the value is buffered.
        A message without ``value`` means the device is idle, and the stop
        status is finished.
        """
        print(self, message)
        if "value" in message:
            if not self._starting.done:
                self._starting.set_finished()
                print("started.")
            self._data.append(message["value"])
        elif not self._stopping.done:
            self._stopping.set_finished()

    def collect(self):
        """Yield the values buffered since the last ``kickoff``.

        NOTE(review): these are raw floats, whereas bluesky's flyer
        ``collect`` protocol expects event-like dicts (plus a matching
        ``describe_collect``). Confirm before using with ``Msg("collect")``.
        """
        yield from self._data

    def kickoff(self):
        """Ask the server to start streaming this device.

        Returns a status that finishes when the first value arrives, or fails
        after ``state_timeout`` seconds. It deliberately does not block: if
        kickoff runs on the event-loop thread, blocking would prevent the
        confirming datagram from ever being received.
        """
        self._data = deque()
        self._starting = DeviceStatus(self, timeout=self.state_timeout)
        self._send_state(True)
        print("starting...")
        return self._starting

    def complete(self):
        """Ask the server to stop streaming this device.

        Returns a status that finishes when the server reports the device
        idle, or fails after ``state_timeout`` seconds.
        """
        self._stopping = DeviceStatus(self, timeout=self.state_timeout)
        self._send_state(False)
        return self._stopping

    def _send_state(self, state):
        """Send a ``{"dev": <id>, "state": state}`` request to the server."""
        self.protocol.transport.sendto(
            dumps({"dev": self.dev, "state": state}).encode(),
            ("localhost", SERVER_PORT),
        )
class DeviceManager:
    """Owns the client UDP endpoint and the registry of discovered devices.

    ``devs`` maps device ids (as sent by the server) to device objects.
    ``ClientProtocol`` fills it as datagrams arrive.
    """

    def __init__(self, event_loop):
        self.event_loop = event_loop
        self.devs = {}
        # Populated by run(); None until the client endpoint exists.
        self.transport = None
        self.protocol = None

    async def run(self):
        """Bind the client endpoint on ``CLIENT_PORT`` and start listening."""
        self.transport, self.protocol = await self.event_loop.create_datagram_endpoint(
            partial(ClientProtocol, interpreter=self),
            local_addr=("localhost", CLIENT_PORT),
        )

    def on_message(self, message):
        """Print ``message`` and forward it to the device it names.

        Raises KeyError if ``message["dev"]`` is not registered in ``devs``.
        """
        print(self, message)
        self.devs[message["dev"]].on_message(message)
if __name__ == "__main__":
    import time

    from ophyd.sim import det
    from bluesky import RunEngine
    from bluesky.callbacks.best_effort import BestEffortCallback
    from bluesky.log import config_bluesky_logging
    from bluesky.plans import count
    from bluesky.preprocessors import fly_during_wrapper
    from bluesky.run_engine import get_bluesky_event_loop

    async def start_server():
        """Create the mock server endpoint (3 devices) on SERVER_PORT."""
        await LOOP.create_datagram_endpoint(
            partial(ServerProtocol, 3), local_addr=("localhost", SERVER_PORT)
        )

    LOOP = get_bluesky_event_loop()
    MANAGER = DeviceManager(LOOP)
    # Keep references so the tasks cannot be garbage-collected before running.
    TASKS = [LOOP.create_task(MANAGER.run()), LOOP.create_task(start_server())]
    # This will call bluesky.run_engine._ensure_event_loop_running
    # which will run_forever LOOP on a background thread.
    RE = RunEngine({})
    # Send all metadata/data captured to the BestEffortCallback.
    bec = BestEffortCallback()
    RE.subscribe(bec)

    # Devices are discovered asynchronously from the server's broadcasts on
    # the loop thread. Wait (up to 5 s) for the first one, give the rest of
    # that broadcast burst a moment, then snapshot the registry so it cannot
    # change size while the plan iterates it.
    deadline = time.monotonic() + 5
    while not MANAGER.devs and time.monotonic() < deadline:
        time.sleep(0.1)
    time.sleep(0.5)
    flyers = list(MANAGER.devs.values())

    try:
        RE(fly_during_wrapper(count([det], num=10, delay=1), flyers))
    except KeyboardInterrupt:
        pass
    finally:
        # LOOP runs on another thread; stop() must be scheduled onto it.
        LOOP.call_soon_threadsafe(LOOP.stop)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.