Created June 11, 2023 03:08
Something I made while playing around with audiocraft. It basically keeps chaining generated music together more or less indefinitely.
import argparse
import asyncio
from contextlib import contextmanager
from functools import cached_property
from typing import Any, Callable, Generator, Optional, Tuple

import pyaudio
import torch
from audiocraft.data.audio import audio_write
from audiocraft.models.musicgen import MusicGen
from torchaudio.backend.soundfile_backend import load
from websockets.server import WebSocketServerProtocol, serve


class MusicGenerator:
    def __init__(self, model: MusicGen):
        self.model = model
        self.wav_tensor: Optional[torch.Tensor] = None
        self.stream = None
        self.description = ""
        self.volume = 0.5
        self.melody: Optional[Tuple[torch.Tensor, int]] = None
        self.prompt_duration = 1.0
        self.output_device_index: Optional[int] = None
        self._stop = asyncio.Event()

    @cached_property
    def _commands(self) -> dict[str, Callable]:
        return {
            "/exit": lambda _: self._stop.set(),
            "/duration": lambda input_str: self.model.set_generation_params(
                duration=float(input_str.split(" ", 1)[1])
            ),
            "/volume": lambda input_str: setattr(
                self, "volume", float(input_str.split(" ", 1)[1])
            ),
            "/melody": lambda input_str: self.set_melody(input_str.split(" ", 1)[1]),
            "/clear_melody": lambda _: setattr(self, "melody", None),
            # Schedule the save as a task so the coroutine actually runs;
            # audio_write expects at most a [C, T] tensor, so drop the batch dim.
            "/save": lambda input_str: asyncio.create_task(
                asyncio.to_thread(
                    audio_write,
                    input_str.split(" ", 1)[1],
                    self.wav_tensor[0].cpu(),
                    self.model.sample_rate,
                    strategy="loudness",
                )
            )
            if self.wav_tensor is not None
            else None,
        }

    def set_melody(self, filepath: str):
        self.melody = load(filepath)

    @contextmanager
    def _open_stream(self) -> Generator[pyaudio.Stream, Any, None]:
        pa = pyaudio.PyAudio()
        stream = pa.open(
            format=pyaudio.paFloat32,
            channels=1,
            rate=self.model.sample_rate,
            output=True,
            output_device_index=self.output_device_index,
        )
        self.stream = stream
        try:
            yield stream
        finally:
            self.stream.close()
            pa.terminate()

    async def _generate_loop(self):
        while not self._stop.is_set():
            input_str: str = await asyncio.to_thread(
                input, f"{self.description or 'Description'}> "
            )
            command = input_str.lower().split(" ", 1)[0]
            if command in self._commands:
                try:
                    self._commands[command](input_str)
                except Exception as e:
                    print(e)
            else:
                self.description = input_str.strip() or self.description
                # await self._generate()

    async def _generate(self):
        if self.melody:
            try:
                melody, melody_sr = self.melody
                output = await asyncio.to_thread(
                    self.model.generate_with_chroma,
                    [self.description],
                    melody.expand(1, -1, -1),
                    melody_sr,
                    progress=True,
                )
            except SystemError as e:
                print(e)
                self.melody = None
                return await self._generate()
        else:
            output = await asyncio.to_thread(
                self.model.generate, [self.description], progress=True
            )
        self.wav_tensor = output.cpu()

    async def _play_loop(self, stream: pyaudio.Stream):
        while stream.is_active() and not self._stop.is_set():
            if self.wav_tensor is None:
                await asyncio.sleep(0)
                continue
            # Apply the current volume before handing the samples to PyAudio.
            await asyncio.to_thread(
                stream.write, (self.wav_tensor * self.volume).numpy().tobytes()
            )
        self._stop.set()

    async def _continuation_loop(self):
        while not self._stop.is_set():
            if self.wav_tensor is None:
                await asyncio.sleep(0)
                continue
            prompt_frames = int(self.prompt_duration * self.model.sample_rate)
            # self.model.generation_params["remove_prompts"] = True
            output = await asyncio.to_thread(
                self.model.generate_continuation,
                self.wav_tensor[:, :, -prompt_frames:],
                self.model.sample_rate,
                [self.description],
            )
            self.wav_tensor = output.cpu()

    async def start(self):
        with self._open_stream() as stream:
            if self.description:
                await self._generate()
            await asyncio.gather(
                asyncio.create_task(self._generate_loop()),
                asyncio.create_task(self._play_loop(stream)),
                asyncio.create_task(self._continuation_loop()),
            )


def set_pad_mode_recursive(target, pad_mode: str, _done_list=set(), name="$"):
    if target in _done_list:
        return
    _done_list.add(target)
    if hasattr(target, "pad_mode"):
        print(f"{name}({target}).pad_mode = {pad_mode}")
        target.pad_mode = pad_mode
    for key in dir(target):
        child = getattr(target, key)
        # print(key, child.__class__.__name__)
        if isinstance(child, torch.nn.Module):
            set_pad_mode_recursive(child, pad_mode, _done_list, f"{name}.{key}")


async def main():
    parser = argparse.ArgumentParser(description="Jukebox powered by AudioCraft")
    parser.add_argument(
        "--model", default="melody", type=str, help="Name of the pretrained model."
    )
    parser.add_argument(
        "--device", default="cuda", type=str, help="Device to use for generation."
    )
    parser.add_argument(
        "--description",
        default="Simple Music",
        type=str,
        help="Initial description for the music generator.",
    )
    parser.add_argument(
        "--duration", default=15, type=float, help="Duration for music generation."
    )
    parser.add_argument(
        "--volume", default=0.5, type=float, help="Volume to play the generated music."
    )
    parser.add_argument(
        "--melody", default=None, type=str, help="Path to the melody to use."
    )
    parser.add_argument(
        "--continuous-overlap",
        default=15,
        type=float,
        help="Overlap duration for continuous generation.",
    )
    parser.add_argument(
        "--list-audio-devices",
        action="store_true",
        help="List all available audio devices.",
    )
    parser.add_argument(
        "--output-device-index",
        default=None,
        type=int,
        help="Index of the output device to use.",
    )
    args = parser.parse_args()

    if args.list_audio_devices:
        pa = pyaudio.PyAudio()
        for i in range(pa.get_device_count()):
            info = pa.get_device_info_by_index(i)
            if info["maxOutputChannels"] == 0:
                continue
            print(info)
        return

    model = MusicGen.get_pretrained(args.model, args.device)
    model.set_generation_params(duration=args.duration)
    set_pad_mode_recursive(model, "circular")

    generator = MusicGenerator(model)
    generator.description = args.description
    generator.volume = args.volume
    generator.output_device_index = args.output_device_index
    # Note: --continuous-overlap is parsed above but never applied;
    # prompt_duration keeps its default of 1.0 second.
    if args.melody:
        generator.set_melody(args.melody)

    async def handle_websocket(websocket: WebSocketServerProtocol):
        # Each connection sends a single text message that becomes the new description.
        description = await websocket.recv()
        print(f"\n{description}")
        generator.description = description

    async with serve(handle_websocket, "localhost", 8001):
        await generator.start()


if __name__ == "__main__":
    asyncio.run(main())
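
For reference, the script listens on ws://localhost:8001 and treats each incoming text message as the new generation description. Below is a minimal client sketch using the same websockets package; the send_description helper and the prompt string are illustrative only, not part of the gist.

import asyncio

import websockets


async def send_description(text: str) -> None:
    # Connect to the running jukebox and push a new description prompt.
    async with websockets.connect("ws://localhost:8001") as ws:
        await ws.send(text)


if __name__ == "__main__":
    # Example prompt only; any text works as the new description.
    asyncio.run(send_description("calm lo-fi piano with soft rain"))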