# Install Instructions
# git clone https://github.com/harmonai-org/sample-generator
# git clone --recursive https://github.com/crowsonkb/v-diffusion-pytorch
# pip install ipywidgets==7.7.1
# cd v-diffusion-pytorch
# pip install -r requirements.txt
# cd ..
# cd sample-generator
# pip install .
# cd ..
# pip install tqdm
# pip install matplotlib
# pip install requests
# mkdir models
# pip3 install torch==1.10.1+cu113 torchvision==0.11.2+cu113 torchaudio==0.10.1+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html
# python3 dance_diffusion.py --steps 200 --batch-size 32 --num 65
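#
# Optional sanity check before running (assumes the CUDA 11.3 wheels above
# installed cleanly; adjust for your CUDA version if not):
# python3 -c "import torch; print(torch.__version__, torch.cuda.is_available())"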
from contextlib import contextmanager
from copy import deepcopy
import math
from pathlib import Path
import requests
import os
import sys
import gc

sys.path.append("v-diffusion-pytorch")

from diffusion import sampling
import torch
from torch import optim, nn
from torch.nn import functional as F
from torch.utils import data
from tqdm import trange
from einops import rearrange
import torchaudio
from audio_diffusion.models import DiffusionAttnUnet1D
import numpy as np
import argparse
import random
import matplotlib.pyplot as plt
import IPython.display as ipd
from audio_diffusion.utils import Stereo, PadCrop
from glob import glob
model_path = "./models"

parser = argparse.ArgumentParser()
parser.add_argument("--model-name", type=str, default="unlocked-250k",
                    help='The model to download ["ravearchive-50k", "jmann-small-190k", "maestro-150k", "unlocked-250k"]')
parser.add_argument("-n", "--num", type=int, default=8,
                    help="The number of samples to generate during the run.")
parser.add_argument("--steps", type=int, default=100,
                    help="The number of diffusion steps per sample.")
parser.add_argument("-bs", "--batch-size", type=int, default=8,
                    help="The number of samples to generate per batch.")
args = parser.parse_args()
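# Example invocation (illustrative; any of the four model names above works):
# python3 dance_diffusion.py --model-name maestro-150k -n 16 -bs 8 --steps 100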
#@title Model code
class DiffusionUncond(nn.Module):
    def __init__(self, global_args):
        super().__init__()
        self.diffusion = DiffusionAttnUnet1D(global_args, n_attn_layers=4)
        self.diffusion_ema = deepcopy(self.diffusion)
        self.rng = torch.quasirandom.SobolEngine(1, scramble=True)
def load_to_device(path, sr):
    # Helper (unused in this script): load an audio file, resample it to the
    # model's sample rate, and move it to the global `device` defined below.
    audio, file_sr = torchaudio.load(path)
    if sr != file_sr:
        audio = torchaudio.transforms.Resample(file_sr, sr)(audio)
    audio = audio.to(device)
    return audio
def get_alphas_sigmas(t):
    """Returns the scaling factors for the clean signal (alpha) and for the
    noise (sigma), given a timestep."""
    return torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2)

def get_crash_schedule(t):
    sigma = torch.sin(t * math.pi / 2) ** 2
    alpha = (1 - sigma ** 2) ** 0.5
    return alpha_sigma_to_t(alpha, sigma)

def t_to_alpha_sigma(t):
    """Returns the scaling factors for the clean signal and for the noise,
    given a timestep."""
    return torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2)

def alpha_sigma_to_t(alpha, sigma):
    """Returns a timestep, given the scaling factors for the clean signal and
    for the noise."""
    return torch.atan2(sigma, alpha) / math.pi * 2
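# These schedules parameterize (alpha, sigma) on the unit circle, so
# alpha^2 + sigma^2 = 1 and t is just the angle: t = atan2(sigma, alpha) * 2 / pi.
# The "crash" schedule squares sigma and recomputes alpha so the identity still
# holds. A quick check you could run (optional):
# t = torch.linspace(0, 1, 5)
# alpha, sigma = get_alphas_sigmas(t)
# assert torch.allclose(alpha ** 2 + sigma ** 2, torch.ones_like(t))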
#@title Args
sample_size = 65536
sample_rate = 48000
latent_dim = 0

args.sample_size = sample_size
args.sample_rate = sample_rate
args.latent_dim = latent_dim

from urllib.parse import urlparse
import hashlib
#@title Create the model
#@markdown If you have a custom fine-tuned model, choose "custom" above and enter a path to the model checkpoint here
custom_ckpt_path = ''  #@param {type: 'string'}
models_map = {
    "ravearchive-50k": {
        'downloaded': False,
        'sha': "4cd36b7071110c339649a118ca3ce9c5ac6d70e263f21ea9870d239eff5cb5e4",
        'uri_list': ["https://model-server.zqevans2.workers.dev/ravearchive-uncond-50k.ckpt"],
        'sample_rate': 48000,
        'sample_size': 65536,
    },
    "jmann-small-190k": {
        'downloaded': False,
        'sha': "1e2a23a54e960b80227303d0495247a744fa1296652148da18a4da17c3784e9b",
        'uri_list': ["https://model-server.zqevans2.workers.dev/jmann-small-190k.ckpt"],
        'sample_rate': 48000,
        'sample_size': 65536,
    },
    "maestro-150k": {
        'downloaded': False,
        'sha': "49d9abcae642e47c2082cec0b2dce95a45dc6e961805b6500204e27122d09485",
        'uri_list': ["https://model-server.zqevans2.workers.dev/maestro-uncond-150k.ckpt"],
        'sample_rate': 16000,
        'sample_size': 65536,
    },
    "unlocked-250k": {
        'downloaded': False,
        'sha': "af337c8416732216eeb52db31dcc0d49a8d48e2b3ecaa524cb854c36b5a3503a",
        'uri_list': ["https://model-server.zqevans2.workers.dev/unlocked-uncond-250k.ckpt"],
        'sample_rate': 16000,
        'sample_size': 65536,
    },
}
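# To point the script at another hosted checkpoint, you could add an entry of
# the same shape (the name, sha, and URL here are placeholders, not real):
# models_map["my-model-100k"] = {
#     'downloaded': False,
#     'sha': "<sha256 of the checkpoint file>",
#     'uri_list': ["https://example.com/my-model-100k.ckpt"],
#     'sample_rate': 48000,
#     'sample_size': 65536,
# }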
#@markdown If you're having issues with model downloads, check this to compare the SHA:
check_model_SHA = True  #@param{type:"boolean"}

def get_model_filename(diffusion_model_name):
    model_uri = models_map[diffusion_model_name]['uri_list'][0]
    model_filename = os.path.basename(urlparse(model_uri).path)
    return model_filename
def download_model(diffusion_model_name, uri_index=0):
    if diffusion_model_name != 'custom':
        model_filename = get_model_filename(diffusion_model_name)
        model_local_path = os.path.join(model_path, model_filename)
        if os.path.exists(model_local_path) and check_model_SHA:
            print(f'Checking {diffusion_model_name} file')
            with open(model_local_path, "rb") as f:
                file_bytes = f.read()
            file_hash = hashlib.sha256(file_bytes).hexdigest()
            print(f'SHA: {file_hash}')
            if file_hash == models_map[diffusion_model_name]['sha']:
                print(f'{diffusion_model_name} SHA matches')
                models_map[diffusion_model_name]['downloaded'] = True
            else:
                print(f"{diffusion_model_name} SHA doesn't match. Will redownload it.")
        elif (os.path.exists(model_local_path) and not check_model_SHA) or models_map[diffusion_model_name]['downloaded']:
            print(f'{diffusion_model_name} already downloaded. If the file is corrupt, enable check_model_SHA.')
            models_map[diffusion_model_name]['downloaded'] = True
        if not models_map[diffusion_model_name]['downloaded']:
            for model_uri in models_map[diffusion_model_name]['uri_list']:
                with open(model_local_path, "wb") as outfile:
                    r = requests.get(model_uri)
                    outfile.write(r.content)
                    outfile.flush()
                with open(model_local_path, "rb") as f:
                    file_bytes = f.read()
                file_hash = hashlib.sha256(file_bytes).hexdigest()
                print(f'SHA: {file_hash}')
                if os.path.exists(model_local_path):
                    models_map[diffusion_model_name]['downloaded'] = True
                    return
                else:
                    print(f'{diffusion_model_name} model download from {model_uri} failed. Will try any fallback uri.')
            print(f'{diffusion_model_name} download failed.')
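# Equivalent manual check of a cached checkpoint from the shell, if you prefer
# to verify outside the script (filename taken from the uri_list above):
# sha256sum models/maestro-uncond-150k.ckpt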
if args.model_name == "custom":
    ckpt_path = custom_ckpt_path
else:
    model_info = models_map[args.model_name]
    download_model(args.model_name)
    ckpt_path = f'{model_path}/{get_model_filename(args.model_name)}'
    args.sample_size = model_info["sample_size"]
    args.sample_rate = model_info["sample_rate"]

print("Creating the model...")
model = DiffusionUncond(args)
model.load_state_dict(torch.load(ckpt_path)["state_dict"])
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.requires_grad_(False).to(device)
print("Model created")
# Remove the non-EMA weights; only the EMA copy is used for sampling
del model.diffusion
model_fn = model.diffusion_ema

torch.cuda.empty_cache()
gc.collect()
full_batches = args.num // args.batch_size
final_batch_size = args.num % args.batch_size

def gen_batch(batch_size):
    # Generate random noise to sample from
    noise = torch.randn([batch_size, 2, args.sample_size]).to(device)
    t = torch.linspace(1, 0, args.steps + 1, device=device)[:-1]
    step_list = get_crash_schedule(t)
    # Generate the samples from the noise
    generated = sampling.iplms_sample(model_fn, noise, step_list, {})
    # Hard-clip the generated audio
    generated = generated.clamp(-1, 1)
    return generated
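# gen_batch returns a [batch_size, 2, sample_size] tensor of stereo audio in
# [-1, 1]. For reproducible output you could seed the RNG first (optional):
# torch.manual_seed(0)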
samples_saved = 1
for batch_idx in trange(full_batches):
    generated = gen_batch(args.batch_size)
    for ix, gen_sample in enumerate(generated):
        print(f'sample {samples_saved:05d}.wav')
        torchaudio.save(f"{samples_saved:05d}.wav", gen_sample.cpu(), args.sample_rate)
        samples_saved += 1

# Generate any remainder that didn't fill a full batch
if final_batch_size:
    generated = gen_batch(final_batch_size)
    for ix, gen_sample in enumerate(generated):
        print(f'sample {samples_saved:05d}.wav')
        torchaudio.save(f"{samples_saved:05d}.wav", gen_sample.cpu(), args.sample_rate)
        samples_saved += 1
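# The .wav files land in the current working directory. One way to inspect a
# result afterwards (optional):
# audio, sr = torchaudio.load("00001.wav")
# print(audio.shape, sr)  # e.g. torch.Size([2, 65536]) at the model sample rate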