Skip to content

Instantly share code, notes, and snippets.

@JD-P
Created August 24, 2022 22:40
Show Gist options
  • Save JD-P/fc29020d13581ad8cd11608f2b3555cb to your computer and use it in GitHub Desktop.
Save JD-P/fc29020d13581ad8cd11608f2b3555cb to your computer and use it in GitHub Desktop.
# Install Instructions
# git clone https://github.com/harmonai-org/sample-generator
# git clone --recursive https://github.com/crowsonkb/v-diffusion-pytorch
# pip install ipywidgets==7.7.1
# cd v-diffusion-pytorch
# pip install -r requirements.txt
# cd ..
# cd sample-generator
# pip install .
# cd ..
# pip install tqdm
# pip install matplotlib
# pip install requests
# mkdir models
# pip3 install torch==1.10.1+cu113 torchvision==0.11.2+cu113 torchaudio==0.10.1+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html
# python3 dance_diffusion.py --steps 200 --batch-size 32 --num 65
from contextlib import contextmanager
from copy import deepcopy
import math
from pathlib import Path
import requests
import os
import sys
import gc
sys.path.append("v-diffusion-pytorch")
from diffusion import sampling
import torch
from torch import optim, nn
from torch.nn import functional as F
from torch.utils import data
from tqdm import trange
from einops import rearrange
import torchaudio
from audio_diffusion.models import DiffusionAttnUnet1D
import numpy as np
import argparse
import random
import matplotlib.pyplot as plt
import IPython.display as ipd
from audio_diffusion.utils import Stereo, PadCrop
from glob import glob
model_path = "./models"  # directory that downloaded checkpoints are written to


def _positive_int(text):
    """argparse ``type=`` callable: parse ``text`` as an int that must be >= 1.

    Without this, ``--batch-size 0`` raises ZeroDivisionError when the batches
    are split, ``--steps 0`` yields an empty sampling schedule, and a negative
    ``--num`` still produces a non-empty "remainder" batch (``-5 % 8 == 3``).
    """
    value = int(text)
    if value < 1:
        raise argparse.ArgumentTypeError(f"must be a positive integer, got {value}")
    return value


parser = argparse.ArgumentParser()
# Keep these choices in sync with the keys of models_map (plus "custom").
parser.add_argument("--model-name", type=str, default="unlocked-250k",
                    choices=["ravearchive-50k", "jmann-small-190k", "maestro-150k",
                             "unlocked-250k", "custom"],
                    help='The model to download ["ravearchive-50k", "jmann-small-190k", "maestro-150k", "unlocked-250k"], or "custom" to load custom_ckpt_path.')
parser.add_argument("-n", "--num", type=_positive_int, default=8,
                    help="The number of samples to generate during the run.")
parser.add_argument("--steps", type=_positive_int, default=100,
                    help="The number of diffusion steps per sample.")
parser.add_argument("-bs", "--batch-size", type=_positive_int, default=8,
                    help="The number of samples to generate per batch.")
args = parser.parse_args()
# Model code
class DiffusionUncond(nn.Module):
    """Wrapper holding an unconditional DiffusionAttnUnet1D and its EMA twin.

    The attribute names ``diffusion`` and ``diffusion_ema`` determine the
    state_dict keys, so they must match the checkpoints passed to
    ``load_state_dict``.
    """

    def __init__(self, global_args):
        super().__init__()
        unet = DiffusionAttnUnet1D(global_args, n_attn_layers=4)
        self.diffusion = unet
        # Independent copy of the network; receives the EMA weights when a
        # checkpoint is loaded.
        self.diffusion_ema = deepcopy(unet)
        # Scrambled 1-D Sobol sequence generator; not referenced elsewhere in
        # this file.
        self.rng = torch.quasirandom.SobolEngine(1, scramble=True)
import matplotlib.pyplot as plt
import IPython.display as ipd
def load_to_device(path, sr):
    """Load the audio file at ``path``, resampled to ``sr`` Hz, onto ``device``.

    Resampling is skipped when the file's native rate already equals ``sr``.
    Relies on the module-level ``device`` being defined before it is called.
    """
    waveform, native_sr = torchaudio.load(path)
    if native_sr != sr:
        resampler = torchaudio.transforms.Resample(native_sr, sr)
        waveform = resampler(waveform)
    return waveform.to(device)
def get_alphas_sigmas(t):
    """Return ``(alpha, sigma)`` for timestep(s) ``t`` on the cosine schedule.

    ``alpha`` scales the clean signal and ``sigma`` scales the noise; since
    they are cos/sin of the same angle, ``alpha**2 + sigma**2 == 1``.
    """
    theta = t * math.pi / 2
    return torch.cos(theta), torch.sin(theta)
def get_crash_schedule(t):
    """Warp linear timesteps ``t`` through the "crash" noise schedule.

    The noise level is ``sin(t*pi/2)**2`` and the signal level is chosen so
    that ``signal**2 + noise**2 == 1``. The pair is mapped back to a timestep
    with ``alpha_sigma_to_t``.
    """
    noise_level = torch.sin(t * math.pi / 2) ** 2
    signal_level = (1 - noise_level ** 2) ** 0.5
    return alpha_sigma_to_t(signal_level, noise_level)
def t_to_alpha_sigma(t):
    """Map timestep(s) ``t`` to ``(alpha, sigma)`` on the cosine schedule.

    ``alpha`` scales the clean signal and ``sigma`` scales the noise.
    """
    angle = t * math.pi / 2
    alpha = torch.cos(angle)
    sigma = torch.sin(angle)
    return alpha, sigma
def alpha_sigma_to_t(alpha, sigma):
    """Invert the cosine schedule: recover the timestep from ``(alpha, sigma)``.

    ``atan2`` keeps the result well defined when ``alpha`` is 0 (t == 1).
    """
    angle = torch.atan2(sigma, alpha)
    return angle / math.pi * 2
# Args: default generation settings, attached to ``args`` because ``args`` is
# what gets handed to DiffusionUncond, and gen_batch reads args.sample_size.
# When a named model is selected, sample_size and sample_rate are overwritten
# later with that model's values from models_map.
sample_size = 65536  # samples per generated clip
sample_rate = 48000  # rate passed to torchaudio.save
latent_dim = 0
vars(args).update(sample_size=sample_size, sample_rate=sample_rate, latent_dim=latent_dim)
from urllib.parse import urlparse
import hashlib
# Create the model.
# To use a custom fine-tuned model, run with `--model-name custom` and enter
# the path to the model checkpoint here.
custom_ckpt_path = ''  #@param {type: 'string'}


def _model_entry(sha, uri, sample_rate, sample_size=65536):
    """Build one models_map record for a single-URI downloadable checkpoint."""
    return {
        'downloaded': False,
        'sha': sha,
        'uri_list': [uri],
        'sample_rate': sample_rate,
        'sample_size': sample_size,
    }


# Known checkpoints: expected SHA-256, download URI(s), and the sample rate and
# clip length to generate with.
models_map = {
    "ravearchive-50k": _model_entry(
        "4cd36b7071110c339649a118ca3ce9c5ac6d70e263f21ea9870d239eff5cb5e4",
        "https://model-server.zqevans2.workers.dev/ravearchive-uncond-50k.ckpt",
        48000,
    ),
    "jmann-small-190k": _model_entry(
        "1e2a23a54e960b80227303d0495247a744fa1296652148da18a4da17c3784e9b",
        "https://model-server.zqevans2.workers.dev/jmann-small-190k.ckpt",
        48000,
    ),
    "maestro-150k": _model_entry(
        "49d9abcae642e47c2082cec0b2dce95a45dc6e961805b6500204e27122d09485",
        "https://model-server.zqevans2.workers.dev/maestro-uncond-150k.ckpt",
        16000,
    ),
    "unlocked-250k": _model_entry(
        "af337c8416732216eeb52db31dcc0d49a8d48e2b3ecaa524cb854c36b5a3503a",
        "https://model-server.zqevans2.workers.dev/unlocked-uncond-250k.ckpt",
        16000,
    ),
}

# When True, an existing checkpoint file is hashed and compared against the
# expected SHA above before it is reused. Leave enabled if downloads misbehave.
check_model_SHA = True  #@param{type:"boolean"}
def get_model_filename(diffusion_model_name):
    """Return the local checkpoint file name for a ``models_map`` entry.

    The name is the last path component of the entry's first download URI.
    """
    return os.path.basename(
        urlparse(models_map[diffusion_model_name]['uri_list'][0]).path
    )
def _sha256_of_file(path, chunk_size=1 << 20):
    """Return the hex SHA-256 of the file at ``path``, reading it in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def _fetch_to_file(uri, dest_path, chunk_size=1 << 20, timeout=60):
    """Stream ``uri`` into ``dest_path`` and return the SHA-256 of the bytes written.

    Raises ``requests.RequestException`` on connection errors, timeouts and
    non-2xx HTTP responses, so an error page is never mistaken for a checkpoint.
    """
    digest = hashlib.sha256()
    with requests.get(uri, stream=True, timeout=timeout) as response:
        response.raise_for_status()
        with open(dest_path, "wb") as outfile:
            for chunk in response.iter_content(chunk_size=chunk_size):
                outfile.write(chunk)
                digest.update(chunk)
    return digest.hexdigest()


def download_model(diffusion_model_name, uri_index=0):
    """Make sure the checkpoint for ``diffusion_model_name`` exists in ``model_path``.

    An existing file is reused if its SHA-256 matches ``models_map`` (when
    ``check_model_SHA`` is set), or unconditionally (when it is not). Otherwise
    each URI in the entry's ``uri_list`` is tried in turn. The download goes to
    a ``.part`` file that is SHA-checked (when ``check_model_SHA`` is set) and
    only then moved into place. On success the entry's ``'downloaded'`` flag is
    set to True. If every URI fails, a message is printed and the function
    returns. "custom" models are not managed here.

    ``uri_index`` is accepted for backward compatibility and is unused.
    """
    if diffusion_model_name == 'custom':
        return
    entry = models_map[diffusion_model_name]
    model_local_path = os.path.join(model_path, get_model_filename(diffusion_model_name))

    if os.path.exists(model_local_path) and check_model_SHA:
        print(f'Checking {diffusion_model_name} File')
        file_hash = _sha256_of_file(model_local_path)
        print(f'SHA: {file_hash}')
        if file_hash == entry['sha']:
            print(f'{diffusion_model_name} SHA matches')
            entry['downloaded'] = True
        else:
            print(f"{diffusion_model_name} SHA doesn't match. Will redownload it.")
    elif (os.path.exists(model_local_path) and not check_model_SHA) or entry['downloaded']:
        print(f'{diffusion_model_name} already downloaded. If the file is corrupt, enable check_model_SHA.')
        entry['downloaded'] = True

    if entry['downloaded']:
        return

    # The script's install notes create this directory by hand; do it here too.
    os.makedirs(model_path, exist_ok=True)
    # Download to a side file so an interrupted or rejected download never
    # replaces (or masquerades as) a usable checkpoint.
    partial_path = model_local_path + '.part'
    for model_uri in entry['uri_list']:
        try:
            file_hash = _fetch_to_file(model_uri, partial_path)
        except (requests.RequestException, OSError) as err:
            print(f'{diffusion_model_name} model download from {model_uri} failed ({err}). Will try any fallback uri.')
            continue
        print(f'SHA: {file_hash}')
        if check_model_SHA and file_hash != entry['sha']:
            print(f"{diffusion_model_name} SHA doesn't match after download from {model_uri}. Will try any fallback uri.")
            continue
        os.replace(partial_path, model_local_path)
        entry['downloaded'] = True
        return

    try:
        os.remove(partial_path)
    except FileNotFoundError:
        pass
    print(f'{diffusion_model_name} download failed.')
# Resolve which checkpoint to load. Named models are fetched (and optionally
# SHA-checked) into model_path by download_model; "custom" uses
# custom_ckpt_path as given.
if args.model_name == "custom":
    if not custom_ckpt_path:
        parser.error('--model-name custom requires custom_ckpt_path to be set in this script')
    ckpt_path = custom_ckpt_path
else:
    model_info = models_map[args.model_name]
    download_model(args.model_name)
    ckpt_path = os.path.join(model_path, get_model_filename(args.model_name))
    # Generate with the clip length and sample rate recorded for this model.
    args.sample_size = model_info["sample_size"]
    args.sample_rate = model_info["sample_rate"]

print("Creating the model...")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = DiffusionUncond(args)
# Load onto the CPU first so checkpoints saved from GPU tensors also load on
# CPU-only machines; the model is moved to `device` right after.
checkpoint = torch.load(ckpt_path, map_location='cpu')
model.load_state_dict(checkpoint["state_dict"])
del checkpoint  # drop the loaded checkpoint dict so it can be garbage-collected
model = model.requires_grad_(False).to(device)
print("Model created")

# Sampling only uses the EMA weights, so drop the raw (non-EMA) network.
del model.diffusion
model_fn = model.diffusion_ema
torch.cuda.empty_cache()
gc.collect()

# Split args.num samples into full batches plus one smaller remainder batch.
full_batches, final_batch_size = divmod(args.num, args.batch_size)
def gen_batch(batch_size):
    """Sample ``batch_size`` stereo clips from the EMA model ``model_fn``.

    Starts from Gaussian noise of shape ``(batch_size, 2, args.sample_size)``
    on ``device``. Runs ``args.steps`` IPLMS steps along the crash schedule and
    returns the sampler output hard-clipped to ``[-1, 1]``.
    """
    # Allocate the starting noise directly on the target device instead of
    # creating it on the CPU and copying it over.
    noise = torch.randn([batch_size, 2, args.sample_size], device=device)
    # args.steps timesteps from t=1 down to (excluding) t=0, warped through the
    # crash schedule.
    t = torch.linspace(1, 0, args.steps + 1, device=device)[:-1]
    step_list = get_crash_schedule(t)
    generated = sampling.iplms_sample(model_fn, noise, step_list, {})
    # Hard-clip the generated audio.
    return generated.clamp(-1, 1)
# Generate args.num samples: `full_batches` batches of args.batch_size, plus a
# final smaller batch for any remainder. Each clip is written to the current
# working directory as 00001.wav, 00002.wav, ...
from tqdm import tqdm

batch_sizes = [args.batch_size] * full_batches
if final_batch_size:
    batch_sizes.append(final_batch_size)

samples_saved = 1  # 1-based index of the next file to write
# One loop for every batch (including the remainder), so the progress bar
# covers the whole run and the save logic exists in one place.
for current_batch_size in tqdm(batch_sizes):
    for gen_sample in gen_batch(current_batch_size):
        wav_name = f"{samples_saved:05d}.wav"
        print(f'sample {wav_name}')
        torchaudio.save(wav_name, gen_sample.cpu(), args.sample_rate)
        samples_saved += 1
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment