Skip to content

Instantly share code, notes, and snippets.

Created August 24, 2022 22:40
Show Gist options
  • Save JD-P/fc29020d13581ad8cd11608f2b3555cb to your computer and use it in GitHub Desktop.
Save JD-P/fc29020d13581ad8cd11608f2b3555cb to your computer and use it in GitHub Desktop.
# Install Instructions
# git clone https://github.com/harmonai-org/sample-generator
# git clone --recursive https://github.com/crowsonkb/v-diffusion-pytorch
# pip install ipywidgets==7.7.1
# cd v-diffusion-pytorch
# pip install -r requirements.txt
# cd ..
# cd sample-generator
# pip install .
# cd ..
# pip install tqdm
# pip install matplotlib
# pip install requests
# mkdir models
# pip3 install torch==1.10.1+cu113 torchvision==0.11.2+cu113 torchaudio==0.10.1+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html
# python3 <this_script>.py --steps 200 --batch-size 32 --num 65
from contextlib import contextmanager
from copy import deepcopy
import math
from pathlib import Path
import requests
import os
import sys
import gc
from diffusion import sampling
import torch
from torch import optim, nn
from torch.nn import functional as F
from torch.utils import data
from tqdm import trange
from einops import rearrange
import torchaudio
from audio_diffusion.models import DiffusionAttnUnet1D
import numpy as np
import argparse
import random
import matplotlib.pyplot as plt
import IPython.display as ipd
from audio_diffusion.utils import Stereo, PadCrop
from glob import glob
# Local directory that holds the downloaded model checkpoints.
model_path = "./models"

# Command-line options: which model to use and how much audio to sample.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--model-name",
    type=str,
    default="unlocked-250k",
    help='The model to download ["ravearchive-50k", "jmann-small-190k", "maestro-150k", "unlocked-250k"]',
)
parser.add_argument(
    "-n", "--num",
    type=int,
    default=8,
    help="The number of samples to generate during the run.",
)
parser.add_argument(
    "--steps",
    type=int,
    default=100,
    help="The number of diffusion steps per sample.",
)
parser.add_argument(
    "-bs", "--batch-size",
    type=int,
    default=8,
    help="The number of samples to generate per batch.",
)
# Parsed once at import time; later code also stores derived settings on it.
args = parser.parse_args()
#@title Model code
class DiffusionUncond(nn.Module):
    """Unconditional diffusion model wrapper.

    Holds a ``DiffusionAttnUnet1D`` network (``diffusion``), a deep copy of it
    taken at construction time (``diffusion_ema``), and a scrambled
    one-dimensional Sobol engine (``rng``).

    Args:
        global_args: Settings object forwarded unchanged to
            ``DiffusionAttnUnet1D``.
    """

    def __init__(self, global_args):
        # nn.Module refuses to register submodules until its own __init__ has
        # run, so this call must come before any attribute assignment below.
        super().__init__()

        self.diffusion = DiffusionAttnUnet1D(global_args, n_attn_layers = 4)
        # Independent copy of the network; presumably the EMA weights slot
        # that a checkpoint fills in - confirm against the training code.
        self.diffusion_ema = deepcopy(self.diffusion)
        self.rng = torch.quasirandom.SobolEngine(1, scramble=True)
import matplotlib.pyplot as plt
import IPython.display as ipd
def load_to_device(path, sr):
    """Load an audio file, resample it to ``sr`` if needed, and move it to ``device``.

    Args:
        path: Location of the audio file, passed straight to ``torchaudio.load``.
        sr: Target sample rate. The audio is resampled only when the file's
            own rate differs from it.

    Returns:
        The audio tensor returned by ``torchaudio.load`` (resampled if
        necessary), placed on the module-level ``device``.
    """
    audio, file_sr = torchaudio.load(path)
    if sr != file_sr:
        audio = torchaudio.transforms.Resample(file_sr, sr)(audio)
    # The source statement was truncated to "audio ="; moving the tensor to
    # the module-level ``device`` is what the function name calls for.
    audio = audio.to(device)
    return audio
def get_alphas_sigmas(t):
    """Return the scaling factors for the clean signal (alpha) and the noise (sigma).

    Args:
        t: Tensor of timesteps.

    Returns:
        A tuple ``(alpha, sigma)`` with ``alpha = cos(t * pi / 2)`` and
        ``sigma = sin(t * pi / 2)``, each shaped like ``t``.
    """
    angle = t * math.pi / 2
    return torch.cos(angle), torch.sin(angle)
def get_crash_schedule(t):
    """Remap a tensor of timesteps onto the "crash" schedule.

    The noise level is set to ``sigma = sin(t * pi / 2) ** 2``, the matching
    signal level to ``alpha = sqrt(1 - sigma ** 2)``, and the pair is turned
    back into timesteps with ``alpha_sigma_to_t``.

    Args:
        t: Tensor of timesteps.

    Returns:
        Tensor of remapped timesteps, shaped like ``t``.
    """
    sigma = torch.sin(t * math.pi / 2) ** 2
    alpha = (1 - sigma ** 2) ** 0.5
    return alpha_sigma_to_t(alpha, sigma)
def t_to_alpha_sigma(t):
    """Return the scaling factors for the clean signal and for the noise.

    Args:
        t: Tensor of timesteps.

    Returns:
        A tuple ``(alpha, sigma)`` equal to
        ``(cos(t * pi / 2), sin(t * pi / 2))``.
    """
    half_turns = t * math.pi / 2
    return torch.cos(half_turns), torch.sin(half_turns)
def alpha_sigma_to_t(alpha, sigma):
    """Return the timestep that corresponds to a pair of scaling factors.

    Args:
        alpha: Tensor of scaling factors for the clean signal.
        sigma: Tensor of scaling factors for the noise.

    Returns:
        Tensor ``atan2(sigma, alpha) / pi * 2``.
    """
    return torch.atan2(sigma, alpha) / math.pi * 2
#@title Args
# Default audio settings, kept as module-level names.
sample_size, sample_rate, latent_dim = 65536, 48000, 0

# Copy the defaults onto the parsed CLI namespace so the model constructor,
# which receives `args`, can read them from there.
args.sample_size, args.sample_rate, args.latent_dim = (
    sample_size,
    sample_rate,
    latent_dim,
)
from urllib.parse import urlparse
import hashlib
#@title Create the model
#@markdown If you have a custom fine-tuned model, choose "custom" above and enter a path to the model checkpoint here
custom_ckpt_path = ''#@param {type: 'string'}

# Registry of the downloadable checkpoints, keyed by model name. Each entry has:
#   downloaded  - set to True once a usable local copy is known to exist
#   sha         - expected SHA-256 hex digest of the checkpoint file
#   uri_list    - download locations, tried in order
#   sample_rate - sample rate copied onto args for this model
#   sample_size - sample size copied onto args for this model
# NOTE(review): every uri_list entry below is an empty string, so the derived
# file name is empty and nothing can be downloaded until the real checkpoint
# URLs are filled in.
models_map = {
    "ravearchive-50k": {
        'downloaded': False,
        'sha': "4cd36b7071110c339649a118ca3ce9c5ac6d70e263f21ea9870d239eff5cb5e4",
        'uri_list': [""],
        'sample_rate': 48000,
        'sample_size': 65536,
    },
    "jmann-small-190k": {
        'downloaded': False,
        'sha': "1e2a23a54e960b80227303d0495247a744fa1296652148da18a4da17c3784e9b",
        'uri_list': [""],
        'sample_rate': 48000,
        'sample_size': 65536,
    },
    "maestro-150k": {
        'downloaded': False,
        'sha': "49d9abcae642e47c2082cec0b2dce95a45dc6e961805b6500204e27122d09485",
        'uri_list': [""],
        'sample_rate': 16000,
        'sample_size': 65536,
    },
    "unlocked-250k": {
        'downloaded': False,
        'sha': "af337c8416732216eeb52db31dcc0d49a8d48e2b3ecaa524cb854c36b5a3503a",
        'uri_list': [""],
        'sample_rate': 16000,
        'sample_size': 65536,
    },
}

#@markdown If you're having issues with model downloads, check this to compare the SHA:
check_model_SHA = True #@param{type:"boolean"}
def get_model_filename(diffusion_model_name):
    """Return the checkpoint file name for a model in ``models_map``.

    The name is the last path component of the model's first download URI.

    Args:
        diffusion_model_name: Key into ``models_map``.

    Returns:
        The file name as a string; empty when the URI has no path component.

    Raises:
        KeyError: If ``diffusion_model_name`` is not in ``models_map``.
    """
    model_uri = models_map[diffusion_model_name]['uri_list'][0]
    model_filename = os.path.basename(urlparse(model_uri).path)
    return model_filename
def _sha256_of_file(path, chunk_size=1 << 20):
    """Return the SHA-256 hex digest of the file at ``path``, read in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        # Checkpoints are large; hashing chunk by chunk avoids holding the
        # whole file in memory.
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def download_model(diffusion_model_name, uri_index=0):
    """Make sure the checkpoint for ``diffusion_model_name`` exists locally.

    An existing file is verified against the expected SHA-256 when
    ``check_model_SHA`` is set, and otherwise trusted as-is. When no usable
    copy is present, every URI in the model's ``uri_list`` is tried in order
    until one download succeeds. Progress and failures are reported with
    ``print``; the model's ``'downloaded'`` flag in ``models_map`` records
    the outcome.

    Args:
        diffusion_model_name: Key into ``models_map``, or ``'custom'``, in
            which case nothing is done.
        uri_index: Unused; kept so existing callers keep working.

    Raises:
        KeyError: If the name is neither ``'custom'`` nor in ``models_map``.
    """
    if diffusion_model_name == 'custom':
        return

    model_info = models_map[diffusion_model_name]
    model_filename = get_model_filename(diffusion_model_name)
    model_local_path = os.path.join(model_path, model_filename)
    # isfile rather than exists: an empty file name resolves to the models
    # directory itself, which must not be mistaken for a checkpoint.
    have_local_file = os.path.isfile(model_local_path)

    if have_local_file and check_model_SHA:
        print(f'Checking {diffusion_model_name} File')
        file_sha = _sha256_of_file(model_local_path)
        print(f'SHA: {file_sha}')
        if file_sha == model_info['sha']:
            print(f'{diffusion_model_name} SHA matches')
            model_info['downloaded'] = True
        else:
            print(f"{diffusion_model_name} SHA doesn't match. Will redownload it.")
            # Clear any stale flag so the download below really happens.
            model_info['downloaded'] = False
    elif (have_local_file and not check_model_SHA) or model_info['downloaded']:
        print(f'{diffusion_model_name} already downloaded. If the file is corrupt, enable check_model_SHA.')
        model_info['downloaded'] = True

    if model_info['downloaded']:
        return

    os.makedirs(model_path, exist_ok=True)
    # Download to a side file first so an aborted transfer never leaves a
    # truncated checkpoint under the final name.
    partial_path = model_local_path + '.part'
    for model_uri in model_info['uri_list']:
        try:
            with requests.get(model_uri, stream=True) as r:
                r.raise_for_status()
                with open(partial_path, "wb") as outfile:
                    for chunk in r.iter_content(chunk_size=1 << 20):
                        outfile.write(chunk)
            os.replace(partial_path, model_local_path)
        except (requests.exceptions.RequestException, OSError):
            if os.path.exists(partial_path):
                os.remove(partial_path)
            print(f'{diffusion_model_name} model download from {model_uri} failed. Will try any fallback uri.')
            continue
        print(f'SHA: {_sha256_of_file(model_local_path)}')
        model_info['downloaded'] = True
        return
    print(f'{diffusion_model_name} download failed.')
# Work out which checkpoint to load and, for the built-in models, the audio
# settings that go with it.
if args.model_name == "custom":
    ckpt_path = custom_ckpt_path
else:
    # Only the built-in models have an entry in models_map; without this
    # branch the "custom" choice failed with a KeyError.
    model_info = models_map[args.model_name]
    download_model(args.model_name)
    ckpt_path = f'{model_path}/{get_model_filename(args.model_name)}'
    args.sample_size = model_info["sample_size"]
    args.sample_rate = model_info["sample_rate"]

print("Creating the model...")
model = DiffusionUncond(args)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Load the trained weights; without this the model samples from random
# initialisation and ckpt_path is never used.
# NOTE(review): assumes the checkpoint is a dict with a "state_dict" entry
# (PyTorch Lightning layout) - confirm against the published checkpoints.
model.load_state_dict(torch.load(ckpt_path, map_location=device)["state_dict"])
model = model.requires_grad_(False).to(device)
print("Model created")

# # Remove non-EMA
del model.diffusion
# Sampling uses the remaining copy of the network.
model_fn = model.diffusion_ema
# Split the requested sample count into whole batches plus one leftover batch
# (final_batch_size is 0 when the count divides evenly).
full_batches, final_batch_size = divmod(args.num, args.batch_size)
def gen_batch(batch_size):
    """Generate one batch of audio samples from random noise.

    Args:
        batch_size: Number of samples to generate.

    Returns:
        Tensor produced by ``sampling.iplms_sample`` from noise of shape
        ``[batch_size, 2, args.sample_size]``, hard-clipped to ``[-1, 1]``.
    """
    # Generate random noise to sample from
    noise = torch.randn([batch_size, 2, args.sample_size]).to(device)
    # args.steps timesteps running from 1 towards 0; the final 0 is dropped.
    t = torch.linspace(1, 0, args.steps + 1, device=device)[:-1]
    step_list = get_crash_schedule(t)
    # Generate the samples from the noise
    generated = sampling.iplms_sample(model_fn, noise, step_list, {})
    # Hard-clip the generated audio
    generated = generated.clamp(-1, 1)
    return generated
def _save_samples(batch, next_index):
    """Write every sample in ``batch`` to a numbered .wav file.

    Files are named from ``next_index`` upwards, zero-padded to five digits,
    and written to the current working directory at ``args.sample_rate``.

    Returns:
        The index to use for the next sample.
    """
    for gen_sample in batch:
        print(f'sample {next_index:05d}.wav')
        # The source line had lost this call, leaving only its arguments
        # fused onto the print statement.
        torchaudio.save(f"{next_index:05d}.wav", gen_sample.cpu(), args.sample_rate)
        next_index += 1
    return next_index


# Output files are numbered from 1.
samples_saved = 1
for batch_idx in trange(full_batches):
    generated = gen_batch(args.batch_size)
    samples_saved = _save_samples(generated, samples_saved)
# Leftover samples when args.num is not a multiple of the batch size.
if final_batch_size:
    generated = gen_batch(final_batch_size)
    samples_saved = _save_samples(generated, samples_saved)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment