Last active
December 2, 2023 05:50
-
-
Save josezy/ec7cddb7de894c299d61aad5682fc0c7 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Python script to run YOLOv5 object detection on a Janus WebRTC stream | |
# To test it, go to https://janus.conf.meetecho.com/videoroomtest.html and stream video | |
# Then run the following command: python yolov5-webrtc.py https://janus.conf.meetecho.com/janus | |
import aiohttp | |
import argparse | |
import asyncio | |
import cv2 | |
import logging | |
import random | |
import string | |
import time | |
import torch | |
from aiortc import RTCPeerConnection, RTCSessionDescription | |
from aiortc.contrib.media import MediaRecorder | |
from aiortc.mediastreams import MediaStreamError | |
# Peer connections opened by subscribe(); closed by the script's shutdown code.
pcs = set()

# YOLOv5 detector loaded through torch.hub from the "ultralytics/yolov5" repo.
# Other variants (yolov5n ... yolov5x6) or a custom model can be swapped in here.
model = torch.hub.load("ultralytics/yolov5", "yolov5s")
class YoloV5Detector:
    """Run YOLOv5 on a received video track and show the results with OpenCV.

    One task pulls frames from the track and a second task runs detection.
    They share a single-slot queue that always holds the newest frame not yet
    detected, so a slow model skips frames instead of lagging behind the
    live stream.
    """

    def __init__(self):
        self._track = None
        self._track_task = None
        self._detection_task = None
        # Single slot: only the most recent frame is worth running detection on.
        self._queue = asyncio.Queue(maxsize=1)

    def __str__(self):
        return "YoloV5Detector"

    def addTrack(self, track):
        """Use *track* (a media track exposing ``recv()``) as the frame source."""
        self._track = track

    async def start(self):
        """Start the receive and detection tasks; no-op if already started.

        If no track has been added, log a warning and do nothing rather than
        launch a receive task that would immediately fail on ``None.recv()``.
        """
        if self._track is None:
            logging.warning("%s: no video track to process", self)
            return
        if self._track_task is None and self._detection_task is None:
            self._track_task = asyncio.ensure_future(self._run_track(self._track))
            self._detection_task = asyncio.ensure_future(self._run_detection())

    async def stop(self):
        """Cancel both tasks, wait for them to finish and close OpenCV windows."""
        tasks = [t for t in (self._track_task, self._detection_task) if t is not None]
        for task in tasks:
            task.cancel()
        # Let the cancellations land so no task is still pending when the event
        # loop shuts down; their results/exceptions are deliberately discarded.
        await asyncio.gather(*tasks, return_exceptions=True)
        self._track_task = None
        self._detection_task = None
        self._track = None
        cv2.destroyAllWindows()

    def _put_latest(self, frame):
        """Queue *frame*, discarding any older frame still waiting in the slot."""
        try:
            self._queue.get_nowait()
        except asyncio.QueueEmpty:
            pass
        # Cannot be full now: this is the only producer and nothing awaited
        # between emptying the slot and refilling it.
        self._queue.put_nowait(frame)

    async def _run_track(self, track):
        """Receive frames from *track* until it ends."""
        while True:
            try:
                frame = await track.recv()
            except MediaStreamError:
                logging.error("MediaStreamError: Track ended")
                return
            self._put_latest(frame)

    async def _run_detection(self):
        """Detect objects in queued frames and display the annotated images.

        Frame conversion, inference and box rendering run in the loop's
        default executor so the event loop keeps running the other tasks
        (track receive, Janus polling) meanwhile. The OpenCV GUI calls stay on
        the loop thread. Returns when 'q' is pressed in the display window.
        """
        loop = asyncio.get_running_loop()

        def detect(frame):
            # YOLOv5's AutoShape expects RGB numpy images and render() returns
            # them in that same channel order; cv2.imshow expects BGR.
            rgb = frame.to_ndarray(format="rgb24")
            rendered = model(rgb).render()[0]
            return cv2.cvtColor(rendered, cv2.COLOR_RGB2BGR)

        while True:
            # Suspends until the track task supplies a frame (no busy polling).
            frame = await self._queue.get()
            annotated = await loop.run_in_executor(None, detect, frame)
            cv2.imshow("YoloV5 Object Detection", annotated)
            # waitKey also pumps the HighGUI event loop so the window repaints.
            if cv2.waitKey(1) & 0xFF == ord("q"):
                break
def transaction_id():
    """Return a random 12-letter ASCII identifier used to tag a Janus request."""
    alphabet = string.ascii_letters
    picks = [random.choice(alphabet) for _ in range(12)]
    return "".join(picks)
class JanusPlugin:
    """Handle to one plugin attached to a Janus session.

    Synchronous requests get their answer in the HTTP response body.
    Asynchronous requests are only acknowledged there; the real reply is
    later pushed into ``_queue`` by the session's long-poll task.
    """

    def __init__(self, session, url):
        # Events routed here by the owning session's poll task.
        self._queue = asyncio.Queue()
        self._session = session
        self._url = url

    async def _post(self, payload):
        """POST a plugin ``message`` built from *payload*.

        Returns ``(transaction, reply)``, where *reply* is the decoded JSON body.
        """
        message = {"janus": "message", "transaction": transaction_id()}
        message.update(payload)
        async with self._session._http.post(self._url, json=message) as response:
            reply = await response.json()
        return message["transaction"], reply

    async def _wait_for(self, transaction):
        """Return the next queued event whose transaction is *transaction*.

        Events with a missing or different transaction (e.g. unsolicited
        plugin notifications) are skipped instead of being taken as the reply.
        """
        while True:
            event = await self._queue.get()
            if event.get("transaction") == transaction:
                return event
            logging.debug("Ignoring unrelated plugin event: %r", event)

    async def send(self, payload):
        """Send an asynchronous plugin request and return its event reply.

        Raises:
            RuntimeError: if Janus does not acknowledge the request.
        """
        transaction, reply = await self._post(payload)
        if reply.get("janus") != "ack":
            raise RuntimeError("Janus did not acknowledge request: %r" % (reply,))
        return await self._wait_for(transaction)

    async def send_sync(self, payload):
        """Send a synchronous plugin request and return Janus' reply.

        Raises:
            RuntimeError: if the reply is not ``success`` or belongs to
                another transaction.
        """
        transaction, reply = await self._post(payload)
        if reply.get("janus") != "success" or reply.get("transaction") != transaction:
            raise RuntimeError("Unexpected Janus reply: %r" % (reply,))
        return reply
class JanusSession:
    """A Janus session driven over the HTTP REST API.

    Plugin events are fetched by a background long-poll task and routed to
    the queue of the plugin handle whose id matches the event's ``sender``.
    """

    def __init__(self, url):
        self._http = None
        self._poll_task = None
        self._plugins = {}
        self._root_url = url
        self._session_url = None

    @staticmethod
    def _check_reply(data, expected):
        """Raise RuntimeError unless *data* is a Janus reply of type *expected*."""
        if data.get("janus") != expected:
            raise RuntimeError(
                "Unexpected Janus reply (expected %r): %r" % (expected, data)
            )

    async def attach(self, plugin_name: str) -> "JanusPlugin":
        """Attach *plugin_name* to this session and return its handle."""
        message = {
            "janus": "attach",
            "plugin": plugin_name,
            "transaction": transaction_id(),
        }
        async with self._http.post(self._session_url, json=message) as response:
            data = await response.json()
        self._check_reply(data, "success")
        plugin_id = data["data"]["id"]
        plugin = JanusPlugin(self, self._session_url + "/" + str(plugin_id))
        self._plugins[plugin_id] = plugin
        return plugin

    async def create(self):
        """Open the HTTP client, create the Janus session and start polling."""
        self._http = aiohttp.ClientSession()
        message = {"janus": "create", "transaction": transaction_id()}
        async with self._http.post(self._root_url, json=message) as response:
            data = await response.json()
        self._check_reply(data, "success")
        session_id = data["data"]["id"]
        self._session_url = self._root_url + "/" + str(session_id)
        self._poll_task = asyncio.ensure_future(self._poll())

    async def destroy(self):
        """Stop polling, destroy the server-side session and close the HTTP client.

        The HTTP client is closed even if the destroy request fails, so an
        unreachable server or expired session does not leak the client.
        """
        if self._poll_task:
            self._poll_task.cancel()
            # Let the cancelled long-poll release its connection before the
            # client is closed below.
            await asyncio.gather(self._poll_task, return_exceptions=True)
            self._poll_task = None
        try:
            if self._session_url:
                message = {"janus": "destroy", "transaction": transaction_id()}
                async with self._http.post(self._session_url, json=message) as response:
                    data = await response.json()
                self._check_reply(data, "success")
        finally:
            self._session_url = None
            if self._http:
                await self._http.close()
                self._http = None

    async def _poll(self):
        """Long-poll the session forever, routing plugin events to their handles.

        Events from an unknown sender are printed; non-event messages are ignored.
        """
        while True:
            params = {"maxev": 1, "rid": int(time.time() * 1000)}
            async with self._http.get(self._session_url, params=params) as response:
                data = await response.json()
            if data["janus"] == "event":
                plugin = self._plugins.get(data["sender"])
                if plugin:
                    await plugin._queue.put(data)
                else:
                    print(data)
async def subscribe(session, room, recorder, detector, duration=600):
    """Subscribe to the first publisher in a Janus VideoRoom and process its video.

    Creates the Janus session and attaches the VideoRoom plugin. It then joins
    *room* as a subscriber of its first publishing participant and answers
    Janus' SDP offer. The received video track is fed to *detector* and, when
    *recorder* is not None, to *recorder*. Media then flows for *duration*
    seconds.

    Raises:
        RuntimeError: if *room* has no publishing participant.
    """
    await session.create()

    pc = RTCPeerConnection()
    pcs.add(pc)

    # A plain (non-async) handler runs while the event is emitted, so the track
    # is registered before setRemoteDescription() returns.
    @pc.on("track")
    def on_track(track):
        print("Track %s received" % track.kind)
        if track.kind == "video":
            if recorder is not None:
                recorder.addTrack(track)
            detector.addTrack(track)
        # Audio tracks are ignored; pass them to recorder.addTrack() to record them.

    # pick a feed to subscribe to
    plugin = await session.attach("janus.plugin.videoroom")
    response = await plugin.send_sync(
        {"body": {"request": "listparticipants", "room": room}}
    )
    participants = response["plugindata"]["data"]["participants"]
    # Only participants that are publishing can be subscribed to; a missing
    # "publisher" flag is treated as publishing.
    publishers = [p for p in participants if p.get("publisher", True)]
    if not publishers:
        raise RuntimeError("No active publisher in room %s" % room)
    feed = publishers[0]["id"]

    # subscribe
    response = await plugin.send(
        {"body": {"request": "join", "ptype": "subscriber", "room": room, "feed": feed}}
    )

    # apply offer
    await pc.setRemoteDescription(
        RTCSessionDescription(
            sdp=response["jsep"]["sdp"], type=response["jsep"]["type"]
        )
    )

    # send answer
    await pc.setLocalDescription(await pc.createAnswer())
    await plugin.send(
        {
            "body": {"request": "start"},
            "jsep": {
                "sdp": pc.localDescription.sdp,
                "trickle": False,
                "type": pc.localDescription.type,
            },
        }
    )

    if recorder is not None:
        await recorder.start()
    await detector.start()

    # exchange media for `duration` seconds (10 minutes by default)
    print("Exchanging media")
    await asyncio.sleep(duration)
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="YOLOv5 using Janus WebRTC")
    parser.add_argument("url", help="Janus root URL, e.g. http://localhost:8088/janus")
    parser.add_argument(
        "--room",
        type=int,
        default=1234,
        help="The video room ID to join (default: 1234).",
    )
    parser.add_argument("--record-to", help="Write received media to a file.")
    parser.add_argument("--verbose", "-v", action="count")
    args = parser.parse_args()

    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)

    # Create and install the event loop before building the detector: on
    # Python < 3.10 asyncio.Queue (created in YoloV5Detector.__init__) binds to
    # the current loop at construction time. This also avoids the deprecated
    # implicit-loop form of asyncio.get_event_loop().
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)

    # create signaling
    session = JanusSession(args.url)

    # create media sink
    recorder = MediaRecorder(args.record_to) if args.record_to else None

    yolov5_detector = YoloV5Detector()

    try:
        loop.run_until_complete(
            subscribe(session=session, room=args.room, recorder=recorder, detector=yolov5_detector)
        )
    except KeyboardInterrupt:
        pass
    finally:
        try:
            if recorder is not None:
                loop.run_until_complete(recorder.stop())
            loop.run_until_complete(yolov5_detector.stop())
            loop.run_until_complete(session.destroy())
        finally:
            # close peer connections even if an earlier cleanup step failed
            coros = [pc.close() for pc in pcs]
            loop.run_until_complete(asyncio.gather(*coros))
            loop.close()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment