笑い男 (Laughing Man)
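This script captures video from a camera, detects faces with MediaPipe Face Detection, pastes an animated "Laughing Man" GIF (img_mark_04.gif, expected in the same directory as the script) over each detected face, and publishes the resulting video to WebRTC SFU Sora through the Sora Python SDK (sora_sdk). Besides sora_sdk, it relies on OpenCV (cv2), mediapipe, numpy, and Pillow.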
import argparse
import json
import math
import os
from pathlib import Path

import cv2
import mediapipe as mp
import numpy as np
from PIL import Image, ImageSequence
from sora_sdk import Sora


class LogoStreamer:
    def __init__(
        self,
        signaling_urls,
        role,
        channel_id,
        metadata,
        camera_id,
        video_width,
        video_height,
    ):
        self.mp_face_detection = mp.solutions.face_detection

        self.sora = Sora()
        self.video_source = self.sora.create_video_source()
        self.connection = self.sora.create_connection(
            signaling_urls=signaling_urls,
            role=role,
            channel_id=channel_id,
            metadata=metadata,
            video_source=self.video_source,
        )
        self.connection.on_disconnect = self.on_disconnect

        self.video_capture = cv2.VideoCapture(camera_id)
        if video_width is not None:
            self.video_capture.set(cv2.CAP_PROP_FRAME_WIDTH, video_width)
        if video_height is not None:
            self.video_capture.set(cv2.CAP_PROP_FRAME_HEIGHT, video_height)

        self.running = True

        # Load the GIF that will be pasted over detected faces
        self.load_gif(Path(__file__).parent.joinpath("img_mark_04.gif"))

    def load_gif(self, filepath):
        gif = Image.open(filepath)
        self.gif_frames = []
        # Process each frame of the GIF
        for frame in ImageSequence.Iterator(gif):
            # Convert the frame to RGBA so its transparency is preserved
            rgba_frame = frame.convert("RGBA")
            self.gif_frames.append(rgba_frame)
        self.current_frame = 0

    def get_next_gif_frame(self):
        frame = self.gif_frames[self.current_frame]
        self.current_frame = (self.current_frame + 1) % len(self.gif_frames)
        return frame

    def on_disconnect(self, error_code, message):
        print(f"Disconnected from Sora: error_code='{error_code}' message='{message}'")
        self.running = False

    def run(self):
        self.connection.connect()

        try:
            # Prepare the MediaPipe face detector
            with self.mp_face_detection.FaceDetection(
                model_selection=0, min_detection_confidence=0.5
            ) as face_detection:
                while self.running and self.video_capture.isOpened():
                    self.run_one_frame(face_detection)
        except KeyboardInterrupt:
            pass
        finally:
            self.connection.disconnect()
            self.video_capture.release()

    def run_one_frame(self, face_detection):
        # Grab a frame from the camera
        success, frame = self.video_capture.read()
        if not success:
            return

        # Mark the frame read-only so MediaPipe can process it without copying
        frame.flags.writeable = False
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

        # Detect faces with MediaPipe
        results = face_detection.process(frame)

        frame_height, frame_width, _ = frame.shape
        pil_image = Image.fromarray(frame)
        if results.detections:
            for detection in results.detections:
                location = detection.location_data
                if not location.HasField("relative_bounding_box"):
                    continue
                bb = location.relative_bounding_box

                # Convert the normalized bounding box back to pixel coordinates
                w_px = math.floor(bb.width * frame_width)
                h_px = math.floor(bb.height * frame_height)
                x_px = min(math.floor(bb.xmin * frame_width), frame_width - 1)
                y_px = min(math.floor(bb.ymin * frame_height), frame_height - 1)

                # Enlarge and shift the detected region so the logo covers the whole head
                fixed_w_px = math.floor(w_px * 2.6)
                fixed_h_px = math.floor(h_px * 2.6)
                fixed_x_px = max(0, math.floor(x_px - (fixed_w_px - w_px) / 2))
                fixed_y_px = max(0, math.floor(y_px - (fixed_h_px - h_px) / 1.5))

                # Take the next GIF frame and resize it to the adjusted region
                gif_frame = self.get_next_gif_frame().resize((fixed_w_px, fixed_h_px))
                # Paste it onto the PIL image, using its alpha channel as the mask
                pil_image.paste(gif_frame, (fixed_x_px, fixed_y_px), gif_frame)

        frame.flags.writeable = True
        frame = np.array(pil_image)
        frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)

        # Hand the composited frame to the Sora video source (WebRTC)
        self.video_source.on_captured(frame)
        return


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # Required arguments
    default_signaling_urls = os.getenv("SORA_SIGNALING_URLS")
    parser.add_argument(
        "--signaling-urls",
        default=default_signaling_urls,
        type=str,
        nargs="+",
        required=not default_signaling_urls,
        help="Signaling URLs",
    )
    default_channel_id = os.getenv("SORA_CHANNEL_ID")
    parser.add_argument(
        "--channel-id",
        default=default_channel_id,
        required=not default_channel_id,
        help="Channel ID",
    )

    # Optional arguments
    parser.add_argument("--metadata", help="Metadata as a JSON string")
    parser.add_argument(
        "--camera-id", type=int, default=0, help="Camera ID passed to cv2.VideoCapture()"
    )
    parser.add_argument(
        "--video-width",
        type=int,
        default=os.getenv("SORA_VIDEO_WIDTH"),
        help="Hint for the input camera frame width",
    )
    parser.add_argument(
        "--video-height",
        type=int,
        default=os.getenv("SORA_VIDEO_HEIGHT"),
        help="Hint for the input camera frame height",
    )
    args = parser.parse_args()

    metadata = None
    if args.metadata:
        metadata = json.loads(args.metadata)

    streamer = LogoStreamer(
        signaling_urls=args.signaling_urls,
        role="sendonly",
        channel_id=args.channel_id,
        metadata=metadata,
        camera_id=args.camera_id,
        video_height=args.video_height,
        video_width=args.video_width,
    )
    streamer.run()
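
As a rough usage sketch (everything below is illustrative: the module name, signaling URL, and channel ID are placeholders, not values from this gist), the LogoStreamer class can also be driven directly instead of through the CLI; alternatively, run the script itself with --signaling-urls and --channel-id, or set SORA_SIGNALING_URLS and SORA_CHANNEL_ID in the environment.

from laughing_man import LogoStreamer  # hypothetical module name for the file above

streamer = LogoStreamer(
    signaling_urls=["wss://sora.example.com/signaling"],  # placeholder signaling URL
    role="sendonly",
    channel_id="laughing-man-sample",  # placeholder channel ID
    metadata=None,
    camera_id=0,       # first camera found by OpenCV
    video_width=1280,  # optional capture-size hints; pass None to skip them
    video_height=720,
)
streamer.run()  # blocks until disconnect or Ctrl-C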