Last active
January 12, 2018 17:30
-
-
Save sagelywizard/555830c879e94dae1b6c0b613377f5cd to your computer and use it in GitHub Desktop.
Lyrics sampler for few-shot-music-gen
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/python3 | |
import os | |
import time | |
import numpy as np | |
def parse_lyrics_file(filename):
    """Tokenize a lyrics file into a flat list of whitespace-separated tokens.

    Args:
        filename: path to a lyrics text file.

    Returns:
        List of token strings in file order.
    """
    # BUG FIX: the original leaked the file handle and rebuilt the token
    # list on every line (quadratic). Use a context manager and extend().
    # errors='ignore' drops undecodable bytes rather than raising.
    tokens = []
    with open(filename, 'r', errors='ignore') as lyrics:
        for line in lyrics:
            tokens.extend(line.split())
    return tokens
class Episode(object):
    """One few-shot episode: support and query word-id tensors.

    Draws ``batch_size`` artists without replacement from ``data``, then for
    each artist draws ``support_size + query_size`` songs without replacement
    and fills two int32 arrays of shape (batch, songs, max_len) with word
    ids, zero-padded past each song's tokens.
    """

    def __init__(self, root, data, batch_size, support_size, query_size,
                 max_len, word_ids, parser):
        self.support = np.zeros((batch_size, support_size, max_len), dtype=np.int32)
        self.query = np.zeros((batch_size, query_size, max_len), dtype=np.int32)
        chosen = np.random.choice(data, size=batch_size, replace=False)
        for batch_idx, artist in enumerate(chosen):
            artist_dir = os.path.join(root, artist)
            if not os.path.exists(artist_dir):
                raise RuntimeError('artist directory not found: %s' % artist_dir)
            songs = os.listdir(artist_dir)
            picked = np.random.choice(
                songs, size=support_size + query_size, replace=False)
            # The first support_size picks form the support set; the
            # remainder form the query set.
            for target, subset in ((self.support, picked[:support_size]),
                                   (self.query, picked[support_size:])):
                for song_idx, song in enumerate(subset):
                    song_path = os.path.join(root, artist, song)
                    for tok_idx, token in enumerate(parser(song_path)[:max_len]):
                        target[batch_idx][song_idx][tok_idx] = word_ids[token]
class EpisodeSampler(object):
    """Draws few-shot episodes of artist lyrics from a data directory.

    ``root`` must contain one sub-directory per artist, each holding that
    artist's song files. Construction builds (or loads, when persisted) a
    global word->id vocabulary and a train/val/test artist split; episodes
    are then sampled from the artists of the requested split.
    """

    def __init__(self, root, split, batch_size, support_size, query_size,
                 max_len, split_proportions=(8, 1, 1), persist_split=True,
                 persist_ids=True, parser=parse_lyrics_file):
        """Build the vocabulary and the artist split for one dataset split.

        Args:
            root: dataset directory (one sub-directory per artist).
            split: one of 'train', 'val', 'test'.
            batch_size: artists per episode.
            support_size: support songs per artist (K).
            query_size: query songs per artist (K').
            max_len: tokens kept per song.
            split_proportions: relative train/val/test sizes.
            persist_split: read/write <split>.csv files under root.
            persist_ids: read/write word_ids.csv under root.
            parser: callable mapping a song file path to a token list.

        Raises:
            RuntimeError: for an unknown split or a missing root directory.
        """
        self.root = root
        self.split = split
        self.batch_size = batch_size
        self.support_size = support_size
        self.query_size = query_size
        self.max_len = max_len
        self.parser = parser
        self.word_ids = {}
        if split not in ('train', 'val', 'test'):
            raise RuntimeError('unknown split: %s' % split)
        if not os.path.exists(root):
            raise RuntimeError('data directory not found')
        self._load_word_ids(persist_ids)
        split_csv_path = os.path.join(root, '%s.csv' % split)
        if persist_split and os.path.exists(split_csv_path):
            # A persisted split already exists: reuse it for reproducibility.
            with open(split_csv_path, 'r') as split_csv:
                self.data = [line.strip() for line in split_csv]
        else:
            self.data = self._build_split(split_proportions, persist_split)

    def _load_word_ids(self, persist_ids):
        """Populate self.word_ids, reusing word_ids.csv when available."""
        word_ids_path = os.path.join(self.root, 'word_ids.csv')
        if persist_ids and os.path.exists(word_ids_path):
            # Rows are "id,word"; split(',', 1) keeps commas inside words.
            with open(word_ids_path, 'r') as ids_csv:
                for line in ids_csv:
                    row = line.rstrip('\n').split(',', 1)
                    self.word_ids[row[1]] = int(row[0])
            return
        print('Parsing lyrics...')
        curr_word_id = 0
        for directory, _, filenames in os.walk(self.root):
            for filename in filenames:
                filepath = os.path.join(directory, filename)
                # os.walk yields only files here, but keep the original guard.
                if not os.path.isdir(filepath):
                    for word in self.parser(filepath):
                        if word not in self.word_ids:
                            self.word_ids[word] = curr_word_id
                            curr_word_id += 1
        print('done')
        if persist_ids:
            # BUG FIX: the original never closed this handle, so the
            # vocabulary could be left unflushed/truncated on disk.
            with open(word_ids_path, 'w') as word_ids_csv:
                for word, word_id in self.word_ids.items():
                    word_ids_csv.write('%s,%s\n' % (word_id, word))

    def _build_split(self, split_proportions, persist_split):
        """Partition eligible artists into train/val/test; return this split.

        An artist is eligible only if it has at least support_size +
        query_size songs, so every episode can sample without replacement.
        """
        candidates = [artist for artist in os.listdir(self.root)
                      if os.path.isdir(os.path.join(self.root, artist))]
        min_songs = self.support_size + self.query_size
        artists = []
        skipped_count = 0
        for artist in candidates:
            if len(os.listdir(os.path.join(self.root, artist))) >= min_songs:
                artists.append(artist)
            else:
                skipped_count += 1
        if skipped_count > 0:
            print("%s artists don't have K+K'=%s songs. Using %s artists" % (
                skipped_count, min_songs, len(artists)))
        total = sum(split_proportions)
        train_count = int(float(split_proportions[0]) / total * len(artists))
        val_count = int(float(split_proportions[1]) / total * len(artists))
        np.random.shuffle(artists)
        splits = {
            'train': artists[:train_count],
            'val': artists[train_count:train_count + val_count],
            'test': artists[train_count + val_count:],
        }
        if persist_split:
            # BUG FIX: use context managers so each split file is closed.
            for name, members in splits.items():
                with open(os.path.join(self.root, '%s.csv' % name), 'w') as f:
                    f.write('\n'.join(members))
        return splits[self.split]

    def __len__(self):
        """Number of artists available in this split."""
        return len(self.data)

    def __repr__(self):
        return 'EpisodeSampler("%s", "%s")' % (self.root, self.split)

    def get_episode(self):
        """Sample and return a fresh random Episode from this split."""
        return Episode(
            self.root,
            self.data,
            self.batch_size,
            self.support_size,
            self.query_size,
            self.max_len,
            self.word_ids,
            self.parser
        )
if __name__ == '__main__':
    # Smoke test: build a sampler over the bundled lyrics data and time
    # how long drawing a single episode takes.
    sampler = EpisodeSampler(
        root='./lyrics_data',
        split='train',
        batch_size=10,
        support_size=10,
        query_size=10,
        max_len=100,
    )
    start = time.time()
    episode = sampler.get_episode()
    elapsed = time.time() - start
    print(episode.support.shape)
    print(episode.query.shape)
    print('Elapsed: %s' % elapsed)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment