""" To use: install LLM studio (or Ollama), clone OpenVoice, run this script in the OpenVoice directory | |
git clone https://github.com/myshell-ai/OpenVoice | |
cd OpenVoice | |
git clone https://huggingface.co/myshell-ai/OpenVoice | |
cp -r OpenVoice/* . | |
pip install openai-whisper pynput pyaudio openai
"""
from dataclasses import dataclass
from enum import Enum, auto
from typing import Optional
import os
import random
import re
import time

import numpy as np
import pyaudio
import torch
import whisper
from openai import OpenAI
from pynput import keyboard

# OpenVoice modules (importable when this script sits in the OpenVoice directory)
import se_extractor
from api import BaseSpeakerTTS, ToneColorConverter
from utils import split_sentences_latin

class Mode(Enum):
    IDLE = auto()
    RECORD = auto()
    PREDEFINED_1 = auto()
    PREDEFINED_2 = auto()
    PREDEFINED_3 = auto()
    PREDEFINED_4 = auto()

@dataclass
class Pred:
    instruction: Optional[str] = None
    answer: Optional[str] = None
    follow_up: Optional[str] = None  # template in which '@@@' is replaced by the transcribed answer

MAX_ANSWER_SENTENCES = 2
MAX_HISTORY = 5
SYSTEM_MESSAGE = "You are Aura, a clinical guide provided with a body, preferences and opinions. You are quite cold and stress-inducing. Don't be positive. Never optimistic, always dark and pessimistic. You work at the health clinic, a reinsertion clinic for non-socially integrated people. KEEP YOUR RESPONSES VERY SHORT AND CONVERSATIONAL. Don't ask questions unless instructed to do so."
PREDEFINED = {
    Mode.IDLE: Pred(),
    Mode.RECORD: Pred(),
    Mode.PREDEFINED_1: Pred(
        instruction="Can you welcome a new group of patients to the clinic. They are feeling a bit anxious. Try NOT to put them at ease but to stress them.",
        answer="Hello, welcome to the clinic. I am Aura, your guide. Let's go down the hall together. Everything will be fine, don't worry. It's always fine, even when it's not."
    ),
    Mode.PREDEFINED_2: Pred(
        instruction="Can you ask the patient if he/she is feeling well",
        answer="How are you feeling today?"
    ),
    Mode.PREDEFINED_3: Pred(
        instruction="Now ask specifically patient number two if he/she can give a random word",
        answer=None,
        follow_up="The patient provided the answer: '@@@'. Can you make a short comment on the word that the patient gave you. Don't ask questions, just make a comment or tell a personal story."
    ),
    Mode.PREDEFINED_4: Pred(
        instruction="Now can you ask specifically patient number three if he/she can give a random number",
        answer=None,
        follow_up="The patient provided the answer: '@@@'. Can you make a very negative and derogatory comment about this choice and highlight how bad an answer it is. Stay short."
    )
}

# SPEAKER_WAV = None
SPEAKER_WAV = [f"{os.path.dirname(__file__)}/resources/demo_speaker0.mp3",
               f"{os.path.dirname(__file__)}/resources/demo_speaker1.mp3",
               f"{os.path.dirname(__file__)}/resources/demo_speaker2.mp3"]

os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
llm_client = OpenAI(base_url="http://localhost:1234/v1", api_key="not-needed")
tts_en_ckpt_base = os.path.join(os.path.dirname(__file__), "checkpoints/base_speakers/EN")
tts_ckpt_converter = os.path.join(os.path.dirname(__file__), "checkpoints/converter")
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
tts_model = BaseSpeakerTTS(f'{tts_en_ckpt_base}/config.json', device=device)
tts_model.load_ckpt(f'{tts_en_ckpt_base}/checkpoint.pth')
tone_color_converter = ToneColorConverter(f'{tts_ckpt_converter}/config.json', device=device)
tone_color_converter.load_ckpt(f'{tts_ckpt_converter}/checkpoint.pth')
en_source_default_se = torch.load(f"{tts_en_ckpt_base}/en_default_se.pth").to(device)
target_se = [se_extractor.get_se(s, tone_color_converter, target_dir='processed', vad=True)[0] for s in SPEAKER_WAV] if SPEAKER_WAV else None
sampling_rate = tts_model.hps.data.sampling_rate
mark = tts_model.language_marks.get("english", None)
asr_model = whisper.load_model("base.en")

def to_audio(t, k):
    """Synthesize `t` with the base speaker and, if target speakers are set, average the tone-color-converted variants."""
    audio_list = []
    stn_tst = tts_model.get_text(t, tts_model.hps, False)
    with torch.no_grad():
        x_tst = stn_tst.unsqueeze(0).to(tts_model.device)
        x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).to(tts_model.device)
        sid = torch.LongTensor([tts_model.hps.speakers["default"]]).to(tts_model.device)
        audio = tts_model.model.infer(x_tst, x_tst_lengths, sid=sid, noise_scale=0.667, noise_scale_w=0.6)[0][0, 0].data.cpu().float().numpy()
    if target_se is not None:
        # Convert the base audio towards each target voice and average all versions together
        audio_list = [audio]
        for target in target_se:
            split_converted = tone_color_converter.convert_from_tensor(audio=audio, src_se=en_source_default_se, tgt_se=target)
            audio_list.append(split_converted)
        min_len = min([len(s) for s in audio_list])
        audio_list = [sum(s[:min_len] for s in audio_list) / len(audio_list)]
        # Earlier experiments, kept for reference; `k` is only used by the round-robin variant below.
        # splited_audio = np.array_split(audio, len(audio)//35000) if len(audio) > 35000 else [audio]
        # for split in splited_audio:
        #     split_list = [split]
        #     for target in target_se:
        #         split_converted = tone_color_converter.convert_from_tensor(audio=split, src_se=en_source_default_se, tgt_se=target)
        #         split_list.append(split_converted)
        #     drop_random = random.randrange(len(split_list))
        #     split_list.pop(drop_random)
        #     min_len = min([len(s) for s in split_list])
        #     audio_list.append(sum(s[:min_len] for s in split_list)/len(split_list))
        # splited_audio = np.array_split(audio, len(audio)//35000) if len(audio) > 35000 else [audio]
        # for split in splited_audio:
        #     split_converted = tone_color_converter.convert_from_tensor(audio=split, src_se=en_source_default_se, tgt_se=target_se[k%len(target_se)])
        #     audio_list.append(split_converted)
        #     k += 1
    else:
        audio_list.append(audio)
    data = tts_model.audio_numpy_concat(audio_list, sr=sampling_rate).tobytes()
    return data, k

def play_audio(text) -> str:
    repeat = False

    def on_press(key):
        nonlocal repeat
        if key == keyboard.KeyCode.from_char('r'):
            repeat = True

    listener = keyboard.Listener(on_press=on_press)
    listener.start()
    p = pyaudio.PyAudio()
    stream = p.open(format=pyaudio.paFloat32, channels=1, rate=sampling_rate, output=True)
    texts = split_sentences_latin(text)
    # Limit the spoken answer to at most MAX_ANSWER_SENTENCES sentences
    if len(texts) > MAX_ANSWER_SENTENCES:
        texts = texts[:MAX_ANSWER_SENTENCES]
    k = 0
    for t in texts:
        t = re.sub(r'([a-z])([A-Z])', r'\1 \2', t)  # split camelCase words so the TTS reads them naturally
        t = f'[{mark}]{t}[{mark}]'
        data, k = to_audio(t, k)
        stream.write(data)
        if repeat:
            # 'r' was pressed during playback: say the current sentence again
            repeat = False
            data, k = to_audio(t, k)
            stream.write(data)
    stream.stop_stream()
    stream.close()
    p.terminate()
    listener.stop()
    return ''.join(texts)

def record_and_transcribe_audio(follow_up_instruction=None):
    print("follow-up instruction: ", follow_up_instruction)
    mode = Mode.IDLE
    result = ""

    def on_press(key):
        nonlocal mode
        if key == keyboard.Key.shift:
            mode = Mode.RECORD
        elif key == keyboard.KeyCode.from_char('a'):
            mode = Mode.PREDEFINED_1
        elif key == keyboard.KeyCode.from_char('b'):
            mode = Mode.PREDEFINED_2
        elif key == keyboard.KeyCode.from_char('c'):
            mode = Mode.PREDEFINED_3
        elif key == keyboard.KeyCode.from_char('d'):
            mode = Mode.PREDEFINED_4

    def on_release(key):
        nonlocal mode
        if key == keyboard.Key.shift and mode == Mode.RECORD:
            mode = Mode.IDLE
            return False
        elif key == keyboard.KeyCode.from_char('a') and mode == Mode.PREDEFINED_1:
            mode = Mode.IDLE
            return False
        elif key == keyboard.KeyCode.from_char('b') and mode == Mode.PREDEFINED_2:
            mode = Mode.IDLE
            return False
        elif key == keyboard.KeyCode.from_char('c') and mode == Mode.PREDEFINED_3:
            mode = Mode.IDLE
            return False
        elif key == keyboard.KeyCode.from_char('d') and mode == Mode.PREDEFINED_4:
            mode = Mode.IDLE
            return False

    listener = keyboard.Listener(on_press=on_press, on_release=on_release)
    listener.start()
    print('Hold shift to record, or press a/b/c/d for a predefined answer...')
    while mode == Mode.IDLE:
        time.sleep(0.1)
        # The listener stops itself when a key is released while we are still idle; restart it so we keep waiting
        if not listener.is_alive() and mode == Mode.IDLE:
            listener = keyboard.Listener(on_press=on_press, on_release=on_release)
            listener.start()
    if mode == Mode.RECORD:
        print('Start recording...')
    else:
        print("Running predefined instructions")
        return PREDEFINED[mode].instruction, mode
    p = pyaudio.PyAudio()
    stream = p.open(format=pyaudio.paInt16, channels=1, rate=16000, frames_per_buffer=1024, input=True)
    frames = []
    while mode == Mode.RECORD:
        data = stream.read(1024, exception_on_overflow=False)
        frames.append(np.frombuffer(data, dtype=np.int16))
    print('Finished recording')
    data = np.hstack(frames).astype(np.float32) / 32768.0
    result = asr_model.transcribe(data)['text']
    print(result)
    stream.stop_stream()
    stream.close()
    p.terminate()
    # Insert the transcribed answer into the follow-up instruction template
    if follow_up_instruction is not None:
        result = follow_up_instruction.replace('@@@', result)
        print("follow-up instruction: ", result)
    return result, Mode.RECORD

def conversation():
    conversation_history = [{'role': 'system', 'content': SYSTEM_MESSAGE}]
    mode = Mode.IDLE
    while True:
        user_input, mode = record_and_transcribe_audio(PREDEFINED[mode].follow_up)
        print("mode: ", mode)
        conversation_history.append({'role': 'user', 'content': user_input})
        if mode != Mode.RECORD and PREDEFINED[mode].answer is not None:
            # We have a predefined answer
            chatbot_response = PREDEFINED[mode].answer
        else:
            # We generate a new answer with the local LLM
            response = llm_client.chat.completions.create(model="local-model", messages=conversation_history, stop=['###'])
            chatbot_response = response.choices[0].message.content
        conversation_history.append({'role': 'assistant', 'content': chatbot_response})
        print(conversation_history)
        play_audio(chatbot_response)
        # Keep the history bounded: the system message plus the most recent turns
        if len(conversation_history) > MAX_HISTORY:
            conversation_history = [{'role': 'system', 'content': SYSTEM_MESSAGE}] + conversation_history[-MAX_HISTORY+1:]


if __name__ == "__main__":
    conversation()