Skip to content

Instantly share code, notes, and snippets.

@sagelywizard
Last active January 12, 2018 17:30
Show Gist options
  • Save sagelywizard/555830c879e94dae1b6c0b613377f5cd to your computer and use it in GitHub Desktop.
Save sagelywizard/555830c879e94dae1b6c0b613377f5cd to your computer and use it in GitHub Desktop.
Lyrics sampler for few-shot-music-gen
#!/usr/bin/python3
import os
import time
import numpy as np
def parse_lyrics_file(filename):
    """Return every whitespace-separated token in a lyrics file, in file order.

    The file is decoded with the platform default text encoding; undecodable
    bytes are dropped (``errors='ignore'``) instead of raising.

    Args:
        filename: path to a text file.

    Returns:
        list of str: the tokens of all lines, concatenated in order.
    """
    tokens = []
    # `with` guarantees the handle is closed (the original leaked it).
    with open(filename, 'r', errors='ignore') as lyrics_file:
        for line in lyrics_file:
            # extend() is amortised O(len(line)); `tokens = tokens + ...`
            # rebuilt the whole list for every line (quadratic overall).
            tokens.extend(line.split())
    return tokens
class Episode(object):
    """One few-shot batch of support/query songs encoded as word-id arrays.

    ``batch_size`` distinct artists are drawn from ``data``. For each artist,
    ``support_size + query_size`` distinct songs are drawn from
    ``root/<artist>``. The first ``support_size`` songs fill the support set
    and the remainder fill the query set. Each song is tokenised with
    ``parser``, truncated to ``max_len`` tokens and mapped through
    ``word_ids`` into a zero-initialised int32 array, so shorter songs are
    right-padded with 0.

    Attributes:
        support: np.int32 array of shape (batch_size, support_size, max_len).
        query: np.int32 array of shape (batch_size, query_size, max_len).

    Raises:
        RuntimeError: if an artist directory under ``root`` does not exist.
        ValueError: raised by ``np.random.choice`` if ``data`` has fewer than
            ``batch_size`` artists or an artist has fewer than
            ``support_size + query_size`` songs.
        KeyError: if a token is missing from ``word_ids`` (assuming it is a
            plain dict).
    """

    def __init__(self, root, data, batch_size, support_size, query_size,
                 max_len, word_ids, parser):
        self.support = np.zeros((batch_size, support_size, max_len), dtype=np.int32)
        self.query = np.zeros((batch_size, query_size, max_len), dtype=np.int32)
        sample_size = support_size + query_size
        artists = np.random.choice(data, size=batch_size, replace=False)
        for batch, artist in enumerate(artists):
            directory = os.path.join(root, artist)
            if not os.path.exists(directory):
                raise RuntimeError('artist directory not found: %s' % directory)
            # os.listdir order is filesystem-dependent; sorting makes the draw
            # reproducible under a fixed np.random seed.
            songs = sorted(os.listdir(directory))
            # Sampling without replacement keeps support and query disjoint.
            sample = np.random.choice(songs, size=sample_size, replace=False)
            self._fill(self.support[batch], directory, sample[:support_size],
                       max_len, word_ids, parser)
            self._fill(self.query[batch], directory, sample[support_size:],
                       max_len, word_ids, parser)

    @staticmethod
    def _fill(rows, directory, songs, max_len, word_ids, parser):
        """Write the word ids of each song into the matching row of ``rows``.

        ``rows`` is a 2-D view (songs x max_len) into the episode array, so
        writing into it updates the episode in place.
        """
        for song_idx, song in enumerate(songs):
            tokens = parser(os.path.join(directory, song))[:max_len]
            ids = [word_ids[token] for token in tokens]
            # Positions past len(ids) keep their 0 padding.
            rows[song_idx, :len(ids)] = ids
class EpisodeSampler(object):
    """Sample few-shot lyrics episodes from one split of a dataset directory.

    ``root`` is expected to contain one sub-directory per artist, each holding
    that artist's songs as individual files.

    Construction does two things. First, it loads the vocabulary from
    ``root/word_ids.csv`` (when ``persist_ids`` is set and the file exists).
    Otherwise it builds the vocabulary by parsing every song file and, when
    ``persist_ids`` is set, writes it back.

    Second, it loads the artist list for ``split`` from ``root/<split>.csv``
    (when ``persist_split`` is set and the file exists). Otherwise it builds
    a fresh random train/val/test partition of the eligible artists and, when
    ``persist_split`` is set, writes all three CSVs.

    When a partition is built, an artist is eligible only if its directory
    has at least ``support_size + query_size`` entries. Artist lists loaded
    from CSV are used as-is.

    Args:
        root: dataset directory.
        split: one of 'train', 'val', 'test'.
        batch_size, support_size, query_size, max_len: forwarded to Episode.
        split_proportions: relative train/val/test weights. Train and val
            counts are truncated; test receives the remaining artists.
        persist_split: read/write the split CSVs under ``root``.
        persist_ids: read/write ``root/word_ids.csv``.
        parser: callable mapping a song path to a list of tokens.

    Raises:
        RuntimeError: unknown ``split`` or missing ``root``.
    """

    def __init__(self, root, split, batch_size, support_size, query_size,
                 max_len, split_proportions=(8, 1, 1), persist_split=True,
                 persist_ids=True, parser=parse_lyrics_file):
        self.root = root
        self.split = split
        self.batch_size = batch_size
        self.support_size = support_size
        self.query_size = query_size
        self.max_len = max_len
        self.parser = parser
        self.word_ids = {}
        if split not in ['train', 'val', 'test']:
            raise RuntimeError('unknown split: %s' % split)
        if not os.path.exists(root):
            raise RuntimeError('data directory not found')
        self.word_ids = self._load_word_ids(persist_ids)
        self.data = self._load_split(split_proportions, persist_split)

    def _load_word_ids(self, persist_ids):
        """Return the word -> id map, reading/writing word_ids.csv as configured."""
        word_ids_path = os.path.join(self.root, 'word_ids.csv')
        if persist_ids and os.path.exists(word_ids_path):
            word_ids = {}
            with open(word_ids_path, 'r') as word_ids_csv:
                for line in word_ids_csv:
                    # Split once: the id never contains a comma, the word may.
                    row = line.rstrip('\n').split(',', 1)
                    word_ids[row[1]] = int(row[0])
            return word_ids
        word_ids = self._build_word_ids()
        if persist_ids:
            # `with` flushes and closes the file. The original never closed
            # it, so the CSV could be left incomplete.
            with open(word_ids_path, 'w') as word_ids_csv:
                for word, word_id in word_ids.items():
                    word_ids_csv.write('%s,%s\n' % (word_id, word))
        return word_ids

    def _build_word_ids(self):
        """Assign consecutive ids, in first-seen order, to every song token."""
        print('Parsing lyrics...')
        word_ids = {}
        for directory, _, filenames in os.walk(self.root):
            # Files directly under root are this sampler's own CSVs
            # (word_ids.csv, train/val/test.csv), never songs: Episode only
            # reads root/<artist>/<song>. Parsing them would add ids and
            # artist names to the vocabulary.
            if directory == self.root:
                continue
            for filename in filenames:
                filepath = os.path.join(directory, filename)
                if os.path.isdir(filepath):
                    continue
                for word in self.parser(filepath):
                    if word not in word_ids:
                        # NOTE(review): ids start at 0, but Episode also pads
                        # with 0, so padding is indistinguishable from the
                        # first word seen. Confirm downstream handles this.
                        word_ids[word] = len(word_ids)
        print('done')
        return word_ids

    def _load_split(self, split_proportions, persist_split):
        """Return the artists of self.split, reading/writing split CSVs as configured."""
        split_csv_path = os.path.join(self.root, '%s.csv' % self.split)
        if persist_split and os.path.exists(split_csv_path):
            with open(split_csv_path, 'r') as split_csv:
                return [line.strip() for line in split_csv]
        artists = self._eligible_artists()
        total = sum(split_proportions)
        train_count = int(float(split_proportions[0]) / total * len(artists))
        val_count = int(float(split_proportions[1]) / total * len(artists))
        np.random.shuffle(artists)
        splits = {
            'train': artists[:train_count],
            'val': artists[train_count:train_count + val_count],
            'test': artists[train_count + val_count:],
        }
        if persist_split:
            for name, members in splits.items():
                with open(os.path.join(self.root, '%s.csv' % name), 'w') as csv_file:
                    csv_file.write('\n'.join(members))
        return splits[self.split]

    def _eligible_artists(self):
        """Return sorted artist dirs that have at least support_size + query_size entries."""
        sample_size = self.support_size + self.query_size
        artists = []
        skipped_count = 0
        # Sorted so a seeded np.random.shuffle yields the same split on any
        # filesystem (os.listdir order is arbitrary).
        for artist in sorted(os.listdir(self.root)):
            artist_dir = os.path.join(self.root, artist)
            if not os.path.isdir(artist_dir):
                continue
            if len(os.listdir(artist_dir)) >= sample_size:
                artists.append(artist)
            else:
                skipped_count += 1
        if skipped_count > 0:
            print("%s artists don't have K+K'=%s songs. Using %s artists" % (
                skipped_count, sample_size, len(artists)))
        return artists

    def __len__(self):
        """Number of artists in this split."""
        return len(self.data)

    def __repr__(self):
        return 'EpisodeSampler("%s", "%s")' % (self.root, self.split)

    def get_episode(self):
        """Draw a fresh random Episode from this split's artists."""
        return Episode(
            self.root,
            self.data,
            self.batch_size,
            self.support_size,
            self.query_size,
            self.max_len,
            self.word_ids,
            self.parser
        )
if __name__ == '__main__':
    # Smoke test: build a sampler over ./lyrics_data and time one episode draw.
    root = './lyrics_data'
    split = 'train'
    batch_size = 10
    support_size = 10
    query_size = 10
    max_len = 100
    sampler = EpisodeSampler(root, split, batch_size, support_size, query_size,
                             max_len)
    # perf_counter is monotonic and high-resolution; time.time() follows the
    # wall clock and can jump if the system time is adjusted.
    start = time.perf_counter()
    episode = sampler.get_episode()
    elapsed = time.perf_counter() - start
    print(episode.support.shape)
    print(episode.query.shape)
    print('Elapsed: %s' % elapsed)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment