Last active
February 25, 2024 16:13
-
-
Save litagin02/f07a5d7217c9efa4918de0812fd99cd3 to your computer and use it in GitHub Desktop.
Bert-VITS2のモデルマージするやつ(声音・感情表現それぞれを取っ替えたり混ぜたり)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import gradio as gr | |
import torch | |
from infer import get_net_g, infer | |
import utils | |
# State-dict key prefixes, grouped by which merge weight is applied to them
# in merge_models().
voice_keys = ["dec", "flow"]
speech_style_keys = ["enc_p"]
tempo_keys = ["sdp", "dp"]

# Directory holding the *.pth checkpoints to merge; config.json is read from
# the same directory.
models_dir = "merge"
config_path = os.path.join(models_dir, "config.json")

# Checkpoints offered in the model dropdowns at startup.
model_list = [
    os.path.join(models_dir, file_name)
    for file_name in os.listdir(models_dir)
    if file_name.endswith(".pth")
]
def tts(model_path, text):
    """Synthesize ``text`` with the checkpoint at ``model_path``.

    Hyperparameters are read from ``config_path``. The first speaker listed
    in ``hps.data.spk2id`` is used, and the language is fixed to ``"JP"``.

    Args:
        model_path: Path to the ``.pth`` checkpoint to load.
        text: Text to synthesize.

    Returns:
        A ``(sampling_rate, audio)`` tuple, where ``audio`` is whatever
        ``infer`` returns.

    Raises:
        ValueError: If ``hps.version`` is neither ``"2.1"`` nor ``"2.2"``.
    """
    hps = utils.get_hparams_from_file(config_path)

    # Pick the version-specific placeholder passed as `emotion`. Do this
    # before loading the model so an unsupported config fails fast with a
    # clear message instead of an UnboundLocalError after the model load.
    if hps.version == "2.1":
        emotion = 0
    elif hps.version == "2.2":
        emotion = ""
    else:
        raise ValueError(
            f"Unsupported model version {hps.version!r}; "
            'only "2.1" and "2.2" are supported.'
        )

    speaker_name = next(iter(hps.data.spk2id.keys()))
    device = "cuda" if torch.cuda.is_available() else "cpu"

    net_g = get_net_g(
        model_path=model_path,
        version=hps.version,
        device=device,
        hps=hps,
    )

    with torch.no_grad():
        audio = infer(
            text=text,
            sdp_ratio=0.2,
            noise_scale=0.6,
            noise_scale_w=0.8,
            length_scale=1,
            sid=speaker_name,
            language="JP",
            hps=hps,
            net_g=net_g,
            device=device,
            emotion=emotion,
        )
    return (hps.data.sampling_rate, audio)
def merge_models(
    model_path_a, model_path_b, voice_weight, speech_style_weight, tempo_weight
):
    """Merge model B into model A with a separate weight per parameter group.

    Model A is the base. Every entry of A's ``"model"`` state dict whose key
    starts with one of the configured prefixes is replaced by
    ``a * (1 - weight) + b * weight``, so a weight of 0 keeps A's value and a
    weight of 1 takes B's. Entries matching no prefix, and every other
    top-level entry of checkpoint A, are kept from A unchanged.

    Args:
        model_path_a: Path to the base checkpoint (model A).
        model_path_b: Path to the checkpoint blended in (model B).
        voice_weight: Weight for keys starting with a ``voice_keys`` prefix.
        speech_style_weight: Weight for keys starting with a
            ``speech_style_keys`` prefix.
        tempo_weight: Weight for keys starting with a ``tempo_keys`` prefix.

    Returns:
        Path of the saved checkpoint, ``<models_dir>/merged_model.pth``.
        The file name is fixed, so each merge overwrites the previous one.

    Raises:
        KeyError: If a key selected for merging is missing from model B.
    """
    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only load checkpoints from sources you trust.
    model_a = torch.load(model_path_a, map_location="cpu")
    model_b = torch.load(model_path_b, map_location="cpu")

    # str.startswith accepts a tuple of prefixes; groups are checked in the
    # same priority order as before (voice, then speech style, then tempo).
    weighted_prefixes = (
        (tuple(voice_keys), voice_weight),
        (tuple(speech_style_keys), speech_style_weight),
        (tuple(tempo_keys), tempo_weight),
    )

    state_a = model_a["model"]
    state_b = model_b["model"]

    # Copy the state dict itself: dict.copy() on the checkpoint is shallow,
    # so writing through it would also mutate model A's own state dict.
    merged_state = state_a.copy()
    for key, value_a in state_a.items():
        for prefixes, weight in weighted_prefixes:
            if key.startswith(prefixes):
                merged_state[key] = value_a * (1 - weight) + state_b[key] * weight
                break

    merged_model = model_a.copy()
    merged_model["model"] = merged_state

    merged_model_path = os.path.join(models_dir, "merged_model.pth")
    torch.save(merged_model, merged_model_path)
    return merged_model_path
def refresh_models():
    """Rescan ``models_dir`` and return refreshed dropdowns for models A and B.

    Returns:
        Two ``gr.Dropdown`` objects whose choices are the ``.pth`` files
        currently present in ``models_dir``.
    """
    pth_names = (name for name in os.listdir(models_dir) if name.endswith(".pth"))
    checkpoint_paths = [os.path.join(models_dir, name) for name in pth_names]
    dropdown_a = gr.Dropdown(choices=checkpoint_paths)
    dropdown_b = gr.Dropdown(choices=checkpoint_paths)
    return dropdown_a, dropdown_b
# Markdown shown at the top of the Gradio page (passed to gr.Markdown below):
# a short description of the tool and its usage notes, written in Japanese.
initial_md = """
# Bert-VITS2 モデルマージツール
2つのBert-VITS2モデルから、声質・話し方・話す速さを取り替えたり混ぜたりするやつです。
確認したバージョンは2.1と2.2です。同じバージョン同士(2.1同士、2.2同士)でしか動きません。
挙動としてはモデルAを起点にするので、configファイル等はモデルAのものを使ってください。もしかしたら若干モデルAに結果が寄りがちかもしれないけど多分そんなに変わらないです。
## 使い方
`merge`フォルダを作って、直下に混ぜたい`*.pth`ファイルを置いてください。またモデルAのconfig.jsonファイルも同じところに置いてください。
"""
# Build the Gradio interface.
def _weight_slider(label):
    """Create a blend-weight slider ranging 0-1 in steps of 0.1, starting at 0."""
    return gr.Slider(label=label, value=0, minimum=0, maximum=1, step=0.1)


with gr.Blocks() as demo:
    gr.Markdown(initial_md)
    with gr.Row():
        # Left column: model selection, blend weights and the merge action.
        with gr.Column():
            with gr.Row():
                model_a = gr.Dropdown(
                    choices=model_list, label="モデルA (pthファイル)", scale=2
                )
                model_b = gr.Dropdown(
                    choices=model_list, label="モデルB (pthファイル)", scale=2
                )
                refresh_button = gr.Button("モデルリストの再読み込み", scale=1)
                # Re-scan the models directory and update both dropdowns.
                refresh_button.click(fn=refresh_models, outputs=[model_a, model_b])

            voice_slider = _weight_slider("声質")
            speech_style_slider = _weight_slider("話し方(抑揚・感情表現等)")
            tempo_slider = _weight_slider("話す速さ・リズム・テンポ")

            merge_button = gr.Button("マージ")
            merged_model_output = gr.Textbox(label="マージされたモデルのパス")

            # Define what the merge button does: merge the two selected
            # models and show the path of the saved result.
            merge_inputs = [
                model_a,
                model_b,
                voice_slider,
                speech_style_slider,
                tempo_slider,
            ]
            merge_button.click(
                fn=merge_models,
                inputs=merge_inputs,
                outputs=merged_model_output,
            )

        # Right column: try the merged model with a quick TTS run.
        with gr.Column():
            gr.Markdown("### マージされたモデルでのTTSテスト")
            input_text = gr.Textbox(label="テキスト")
            play_button = gr.Button("再生")
            tts_output = gr.Audio(label="オーディオ出力")
            # Synthesize the entered text with the checkpoint whose path is
            # shown in the merged-model textbox.
            play_button.click(
                fn=tts,
                inputs=[merged_model_output, input_text],
                outputs=tts_output,
            )

demo.launch(inbrowser=True)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment