Skip to content

Instantly share code, notes, and snippets.

@chiroptical
Created April 23, 2020 14:00
Show Gist options
  • Select an option

  • Save chiroptical/7c1af58170960328973a6cb265d897f0 to your computer and use it in GitHub Desktop.

Select an option

Save chiroptical/7c1af58170960328973a6cb265d897f0 to your computer and use it in GitHub Desktop.
#!/usr/bin/env python
""" 0-segment-audio.ray.py -- Using a rolling window to split audio files into segments with a specific duration
Usage:
0-segment-audio.ray.py [-hv] (-i <directory>) (-o <directory>) (-d <duration>) (-p <overlap>)
[-a -l <labels.csv>] (-r <address>) (-s <password>) (-c <cores_per_node>) [-b <batch_size>]
Positional Arguments:
Options:
-h --help Print this screen and exit
-v --version Print the version of 0-segment-audio.ray.py
-i --input_directory <directory> The input directory to search for audio files (`.{wav,WAV}`)
-o --output_directory <directory> The output directory for segments
-d --duration <duration> The segment duration in seconds
-p --overlap <overlap> Overlap of each segment in seconds
-a --annotations Only include segments which overlap with Raven annotations
-> Audio files without annotations are skipped when using this argument
-l --labels <labels.csv> When using `--annotations`, the `labels.csv` file provides a map from an
incorrect label (column `from`) to the corrected label (column `to`)
-r --ray_address <address> The ray cluster address
-s --ray_password <password> The ray cluster password
-c --cores_per_node <cores> Number of cores per node
-b --batch_size <batch_size> The batch size [default: 1]
"""
def check_is_integer(x, parameter):
    """Parse `x` as a strictly positive integer.

    Terminates the program with an error message naming `parameter`
    when `x` is not a positive whole number.
    """
    try:
        value = int(x)
        if value <= 0:
            raise ValueError
        return value
    except ValueError:
        # SystemExit is raised directly instead of calling `exit()`: the
        # behavior is identical (message to stderr, exit code 1), but it
        # does not depend on the interactive `site` helper being installed.
        raise SystemExit(f"Error: `{parameter}` should be a positive whole number! Got `{x}`")
from docopt import docopt
import numpy as np
import ray
from pathlib import Path
args = docopt(__doc__, version="0_segment_audio.ray.py version 0.0.1")

# Validate every numeric command-line argument up front so the script fails
# fast with a readable message instead of a traceback later on.
args["--duration"] = check_is_integer(args["--duration"], "--duration")
args["--overlap"] = check_is_integer(args["--overlap"], "--overlap")
args["--cores_per_node"] = check_is_integer(
    args["--cores_per_node"], "--cores_per_node"
)
if args["--batch_size"]:
    args["--batch_size"] = check_is_integer(args["--batch_size"], "--batch_size")
else:
    args["--batch_size"] = 1

try:
    ray.init(address=args["--ray_address"], redis_password=args["--ray_password"])
except Exception:
    # Was a bare `except:`, which also swallowed SystemExit and
    # KeyboardInterrupt; catching Exception keeps Ctrl-C working while
    # still reporting connection failures.
    raise SystemExit("Error: couldn't connect to Ray cluster")

num_nodes = len(ray.nodes())
input_p = Path(args["--input_directory"])
output_p = Path(args["--output_directory"])
# NOTE(review): rglob already recurses, so the "**/" prefix is redundant but
# harmless; only upper-case `.WAV` files match on case-sensitive filesystems.
all_wavs = list(input_p.rglob("**/*.WAV"))
# One chunk of input files per node in the Ray cluster.
chunks = np.array_split(all_wavs, num_nodes)
@ray.remote(num_cpus=args["--cores_per_node"])
def run_splitter(idx, chunk, args):
    """Ray task: split one chunk of WAV files and write its rows to a CSV.

    Writes `segments-{idx}.csv` into the output directory and returns True
    so the driver can block on completion via `ray.get`.
    """
    # Imported inside the task so the worker process (which may run on a
    # different node) resolves these modules, not the driver.
    from splitter import Splitter
    import torch

    dataset = Splitter(
        chunk,
        annotations=args["--annotations"],
        labels=args["--labels"],
        overlap=args["--overlap"],
        duration=args["--duration"],
        output_directory=args["--output_directory"],
    )
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args["--batch_size"],
        shuffle=False,
        num_workers=args["--cores_per_node"],
        collate_fn=dataset.collate_fn,
    )
    # Bug fix: the file handle was never bound (`with open(...):` lacked
    # `as f`), so `f.write` below raised NameError. The enumerate index was
    # also dropped — it was unused and shadowed the task's `idx` parameter.
    with open(f"{args['--output_directory']}/segments-{idx}.csv", "w") as f:
        for data in dataloader:
            for output in data:
                f.write(f"{output}\n")
    return True
# Dispatch one splitter task per chunk, then block until every task finishes.
futures = [run_splitter.remote(i, node_chunk, args) for i, node_chunk in enumerate(chunks)]
for future in futures:
    ray.get(future)

# Merge the per-task CSV fragments into a single segments.csv with a header
# row that matches the columns the splitter emitted.
per_task_csvs = output_p.rglob("segments-*.csv")
with open("segments.csv", "w") as combined:
    if args["--annotations"]:
        combined.write("Source,Annotations,Begin (s),End (s),Destination,Labels\n")
    else:
        combined.write("Source,Begin (s),End (s),Destination\n")
    for csv_path in per_task_csvs:
        with open(csv_path, "r") as fragment:
            combined.writelines(fragment)

In 0_segment_audio.ray.py the goal is to use Ray to chunk the work over multiple nodes. On line 69, you can see that I try to set the number of CPUs explicitly to control this. The behavior I experience is that I get 2 processes running the work (as expected), but they reside on the same node. Additionally, setting num_workers on line 87 appears to have no effect.

"""
The `Splitter` class describes a torch `DataSet` that splits audio files, using optional annotations, into overlapping sections of a certain duration
"""
import torch
import pandas as pd
from math import ceil, floor
from hashlib import md5
from librosa.output import write_wav
from librosa.core import load, get_duration
import sys
from pathlib import Path
import numpy as np
def get_segment(clip_begin, clip_end, samples, sr):
    """Slice `samples` to the window `[clip_begin, clip_end]` given in seconds.

    Returns the sample slice together with the integer begin and end sample
    indices (begin rounded down, end rounded up, at sample rate `sr`).
    """
    first = floor(clip_begin * sr)
    last = ceil(clip_end * sr)
    segment = samples[first:last]
    return segment, first, last
def get_md5_digest(s):
    """Return the hex MD5 digest of the UTF-8 encoding of `s`.

    Used to derive short, collision-resistant output filenames.
    """
    return md5(s.encode("utf-8")).hexdigest()
def annotations_with_overlaps_with_clip(df, begin, end):
    """Return the rows of `df` whose time interval overlaps the clip [begin, end].

    `df` must carry "begin time (s)" and "end time (s)" columns (Raven
    annotation table). An annotation overlaps when it starts inside the clip,
    ends inside the clip, or completely spans it.
    """
    starts_inside = (df["begin time (s)"] >= begin) & (df["begin time (s)"] < end)
    ends_inside = (df["end time (s)"] > begin) & (df["end time (s)"] <= end)
    # Bug fix: an annotation that starts before the clip AND ends after it
    # overlaps the entire clip but matched neither original condition and
    # was silently dropped.
    spans_clip = (df["begin time (s)"] < begin) & (df["end time (s)"] > end)
    return df[starts_inside | ends_inside | spans_clip]
class Splitter(torch.utils.data.Dataset):
    """Torch Dataset that splits audio files into overlapping segments.

    Each item corresponds to one source WAV file; `__getitem__` writes every
    produced segment into `output_directory` as a new WAV (named by an MD5
    digest of source/begin/end) and returns the CSV rows describing them.
    """

    def __init__(
        self,
        wavs,
        annotations=None,
        labels=None,
        overlap=1,
        duration=5,
        output_directory="segments",
    ):
        # wavs: iterable of pathlib.Path objects pointing at audio files
        self.wavs = list(wavs)
        # When truthy, only segments overlapping a Raven annotation are kept
        self.annotations = annotations
        # Optional CSV path mapping wrong labels (column `from`) to fixed
        # ones (column `to`)
        self.labels = labels
        if self.labels:
            self.labels_df = pd.read_csv(labels)
        self.overlap = overlap  # seconds shared by consecutive segments
        self.duration = duration  # segment length in seconds
        self.output_directory = output_directory

    def __len__(self):
        return len(self.wavs)

    def __getitem__(self, item_idx):
        wav = self.wavs[item_idx]
        # Raven tables share the part of the WAV stem before the first "."
        annotation_prefix = self.wavs[item_idx].stem.split(".")[0]
        if self.annotations:
            annotation_file = Path(
                f"{wav.parent}/{annotation_prefix}.Table.1.selections.txt.lower"
            )
            if not annotation_file.is_file():
                sys.stderr.write(f"Warning: Found no Raven annotations for {wav}\n")
                return {"data": []}
        # TODO: Need to feed audio related configurations to `load`
        wav_samples, wav_sample_rate = load(wav)
        wav_duration = get_duration(wav_samples, sr=wav_sample_rate)
        # Lookup table from sample index to its time in seconds
        wav_times = np.arange(0.0, wav_duration, wav_duration / len(wav_samples))
        if self.annotations:
            annotation_df = pd.read_csv(annotation_file, sep="\t").sort_values(
                by=["begin time (s)"]
            )
            if self.labels:
                # Missing classes become "unknown", then every class is mapped
                # through the first matching `from` -> `to` row of labels_df.
                annotation_df["class"] = annotation_df["class"].fillna("unknown")
                annotation_df["class"] = annotation_df["class"].apply(
                    lambda cls: self.labels_df[self.labels_df["from"] == cls]["to"].values[
                        0
                    ]
                )
        # Rolling window: consecutive segments advance by (duration - overlap)
        num_segments = ceil(
            (wav_duration - self.overlap) / (self.duration - self.overlap)
        )
        outputs = []
        for idx in range(num_segments):
            if idx == num_segments - 1:
                # Anchor the final segment to the end of the file so it is
                # always full length (it may overlap its predecessor more).
                # NOTE(review): begin goes negative when the file is shorter
                # than `duration` — confirm inputs are at least that long.
                end = wav_duration
                begin = end - self.duration
            else:
                begin = self.duration * idx - self.overlap * idx
                end = begin + self.duration
            if self.annotations:
                overlaps = annotations_with_overlaps_with_clip(
                    annotation_df, begin, end
                )
            # Digest of source path + window keeps output names unique
            unique_string = f"{wav}-{begin}-{end}"
            destination = f"{self.output_directory}/{get_md5_digest(unique_string)}"
            if self.annotations:
                # Only keep segments that overlap at least one annotation
                if overlaps.shape[0] > 0:
                    segment_samples, segment_sample_begin, segment_sample_end = get_segment(
                        begin, end, wav_samples, wav_sample_rate
                    )
                    write_wav(f"{destination}.WAV", segment_samples, wav_sample_rate)
                    if idx == num_segments - 1:
                        to_append = f"{wav},{annotation_file},{wav_times[segment_sample_begin]},{wav_times[-1]},{destination}.WAV"
                    else:
                        to_append = f"{wav},{annotation_file},{wav_times[segment_sample_begin]},{wav_times[segment_sample_end]},{destination}.WAV"
                    # Pipe-joined set of labels present in this segment
                    to_append += f",{'|'.join(overlaps['class'].unique())}"
                    outputs.append(to_append)
            else:
                segment_samples, segment_sample_begin, segment_sample_end = get_segment(
                    begin, end, wav_samples, wav_sample_rate
                )
                write_wav(f"{destination}.WAV", segment_samples, wav_sample_rate)
                if idx == num_segments - 1:
                    to_append = f"{wav},{wav_times[segment_sample_begin]},{wav_times[-1]},{destination}.WAV"
                else:
                    to_append = f"{wav},{wav_times[segment_sample_begin]},{wav_times[segment_sample_end]},{destination}.WAV"
                outputs.append(to_append)
        return {"data": outputs}

    @staticmethod
    def collate_fn(batch):
        """Flatten the per-file row lists into one iterable of CSV rows.

        Bug fix: declared as a staticmethod — the original plain `def
        collate_fn(batch)` became a bound method when accessed as
        `dataset.collate_fn`, so the dataset instance was passed as `batch`
        and the real batch was rejected as an extra argument.
        """
        # Local import: the module never imported `chain`, so the original
        # raised NameError the first time a batch was collated.
        from itertools import chain

        return chain.from_iterable([x["data"] for x in batch])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment