In 0_segment_audio.ray.py the goal is to use Ray to chunk the work over multiple nodes.
On line 69, you see I try to set the number of CPUs explicitly to control this. The behavior
I experience is that I get 2 processes running the work (as expected) but they reside on the same node.
Additionally, setting num_workers on line 87 appears to have no effect.
Created
April 23, 2020 14:00
-
-
Save chiroptical/7c1af58170960328973a6cb265d897f0 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #!/usr/bin/env python | |
| """ 0-segment-audio.ray.py -- Using a rolling window to split audio files into segments with a specific duration | |
| Usage: | |
| 0-segment-audio.ray.py [-hv] (-i <directory>) (-o <directory>) (-d <duration>) (-p <overlap>) | |
| [-a -l <labels.csv>] (-r <address>) (-s <password>) (-c <cores_per_node>) [-b <batch_size>] | |
| Positional Arguments: | |
| Options: | |
| -h --help Print this screen and exit | |
| -v --version Print the version of 0-segment-audio.ray.py | |
| -i --input_directory <directory> The input directory to search for audio files (`.{wav,WAV}`) | |
| -o --output_directory <directory> The output directory for segments | |
| -d --duration <duration> The segment duration in seconds | |
| -p --overlap <overlap> Overlap of each segment in seconds | |
| -a --annotations Only include segments which overlap with Raven annotations | |
| -> Audio files without annotations are skipped when using this argument | |
| -l --labels <labels.csv> When using `--annotations`, the `labels.csv` file provides a map from an | |
| incorrect label (column `from`) to the corrected label (column `to`) | |
| -r --ray_address <address> The ray cluster address | |
| -s --ray_password <password> The ray cluster password | |
| -c --cores_per_node <cores> Number of cores per node | |
| -b --batch_size <batch_size> The batch size [default: 1] | |
| """ | |
def check_is_integer(x, parameter):
    """Parse `x` as a strictly positive integer or exit the program.

    `parameter` is the CLI flag name being validated; it is used only in
    the error message printed when `x` is not a positive whole number.
    """
    try:
        value = int(x)
    except ValueError:
        value = 0  # any non-positive value falls through to the error below
    if value <= 0:
        exit(f"Error: `{parameter}` should be a positive whole number! Got `{x}`")
    return value
| from docopt import docopt | |
| import numpy as np | |
| import ray | |
| from pathlib import Path | |
# --- CLI parsing, validation, and Ray cluster connection -------------------
args = docopt(__doc__, version="0_segment_audio.ray.py version 0.0.1")
args["--duration"] = check_is_integer(args["--duration"], "--duration")
args["--overlap"] = check_is_integer(args["--overlap"], "--overlap")
args["--cores_per_node"] = check_is_integer(
    args["--cores_per_node"], "--cores_per_node"
)
if args["--batch_size"]:
    args["--batch_size"] = check_is_integer(args["--batch_size"], "--batch_size")
else:
    args["--batch_size"] = 1
try:
    ray.init(address=args["--ray_address"], redis_password=args["--ray_password"])
except Exception:
    # Narrowed from a bare `except:` so SystemExit/KeyboardInterrupt still
    # propagate instead of being reported as a cluster-connection failure.
    exit("Error: couldn't connect to Ray cluster")
num_nodes = len(ray.nodes())
input_p = Path(args["--input_directory"])
output_p = Path(args["--output_directory"])
# NOTE(review): only `.WAV` is matched here although the usage text promises
# `.{wav,WAV}` -- confirm whether lowercase extensions should also be found.
all_wavs = list(input_p.rglob("**/*.WAV"))
# One chunk of files per cluster node; each chunk becomes one Ray task below.
chunks = np.array_split(all_wavs, num_nodes)
@ray.remote(num_cpus=args["--cores_per_node"])
def run_splitter(idx, chunk, args):
    """Split one chunk of wav files on a Ray worker.

    Builds a `Splitter` dataset over `chunk`, iterates it through a
    DataLoader, and writes one csv line per produced segment to
    `segments-{idx}.csv` in the output directory. Returns True on
    completion so the driver can block on the task result.
    """
    # Imported inside the task so the worker process resolves them locally.
    from splitter import Splitter
    import torch

    dataset = Splitter(
        chunk,
        annotations=args["--annotations"],
        labels=args["--labels"],
        overlap=args["--overlap"],
        duration=args["--duration"],
        output_directory=args["--output_directory"],
    )
    # NOTE(review): num_workers controls DataLoader subprocesses on this
    # worker, not Ray's node placement -- the `num_cpus` reservation above is
    # what Ray's scheduler uses. To spread tasks across nodes, consider
    # reserving a full node's CPUs per task or using placement groups.
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args["--batch_size"],
        shuffle=False,
        num_workers=args["--cores_per_node"],
        collate_fn=dataset.collate_fn,
    )
    # Bug fix: the file handle was never bound (`with open(...)` lacked
    # `as f`), so `f.write` below raised NameError. Also dropped the
    # `enumerate` whose counter shadowed the `idx` parameter used in the
    # output filename.
    with open(f"{args['--output_directory']}/segments-{idx}.csv", "w") as f:
        for data in dataloader:
            for output in data:
                f.write(f"{output}\n")
    return True
# Fan out one Ray task per chunk of wav files, then block until all finish.
chunk_ids = [run_splitter.remote(idx, chunk, args) for idx, chunk in enumerate(chunks)]
for chunk_id in chunk_ids:
    ray.get(chunk_id)
# Concatenate the per-task csv files into a single `segments.csv`,
# prepending the header that matches the columns the tasks emitted.
segments = output_p.rglob("segments-*.csv")
with open("segments.csv", "w") as f:
    header = (
        "Source,Annotations,Begin (s),End (s),Destination,Labels\n"
        if args["--annotations"]
        else "Source,Begin (s),End (s),Destination\n"
    )
    f.write(header)
    for segment in segments:
        with open(segment, "r") as g:
            f.writelines(g)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """ | |
| The `Splitter` class describes a torch `DataSet` that splits audio files, using optional annotations, into overlapping sections of a certain duration | |
| """ | |
| import torch | |
| import pandas as pd | |
| from math import ceil, floor | |
| from hashlib import md5 | |
| from librosa.output import write_wav | |
| from librosa.core import load, get_duration | |
| import sys | |
| from pathlib import Path | |
| import numpy as np | |
def get_segment(clip_begin, clip_end, samples, sr):
    """Slice `samples` to the window [clip_begin, clip_end] given in seconds.

    The window is widened outward (floor the start, ceil the end) so the
    requested time span is always fully contained. Returns the slice plus
    the integer start and end sample indices.
    """
    start_idx = floor(clip_begin * sr)
    stop_idx = ceil(clip_end * sr)
    return samples[start_idx:stop_idx], start_idx, stop_idx
def get_md5_digest(s):
    """Return the hex md5 digest of the UTF-8 encoding of string `s`."""
    return md5(s.encode("utf-8")).hexdigest()
def annotations_with_overlaps_with_clip(df, begin, end):
    """Return the rows of `df` whose time interval overlaps [begin, end].

    `df` is a Raven-style annotation table with "begin time (s)" and
    "end time (s)" columns. An annotation overlaps the clip when it starts
    inside it, ends inside it, or completely spans it.
    """
    starts_inside = (df["begin time (s)"] >= begin) & (df["begin time (s)"] < end)
    ends_inside = (df["end time (s)"] > begin) & (df["end time (s)"] <= end)
    # Bug fix: an annotation that begins before the clip and ends after it
    # overlaps the entire clip but matched neither original condition.
    spans_clip = (df["begin time (s)"] < begin) & (df["end time (s)"] > end)
    return df[starts_inside | ends_inside | spans_clip]
class Splitter(torch.utils.data.Dataset):
    """A torch Dataset that splits wav files into overlapping segments.

    Each item corresponds to one input wav file. `__getitem__` writes the
    segment wav files to `output_directory` (named by an md5 digest of the
    source path and time window) and returns csv rows describing them under
    the "data" key.
    """

    def __init__(
        self,
        wavs,
        annotations=None,
        labels=None,
        overlap=1,
        duration=5,
        output_directory="segments",
    ):
        # wavs: iterable of pathlib.Path objects for the input audio files
        self.wavs = list(wavs)
        # annotations: when truthy, only segments overlapping Raven
        # annotations are emitted; files without annotations are skipped
        self.annotations = annotations
        # labels: optional csv path mapping incorrect labels (column `from`)
        # to corrected labels (column `to`)
        self.labels = labels
        if self.labels:
            self.labels_df = pd.read_csv(labels)
        # overlap and duration are both in seconds
        self.overlap = overlap
        self.duration = duration
        self.output_directory = output_directory

    def __len__(self):
        return len(self.wavs)

    def __getitem__(self, item_idx):
        wav = self.wavs[item_idx]
        # Annotation files share the stem up to the first "." with the wav.
        annotation_prefix = self.wavs[item_idx].stem.split(".")[0]
        if self.annotations:
            annotation_file = Path(
                f"{wav.parent}/{annotation_prefix}.Table.1.selections.txt.lower"
            )
            if not annotation_file.is_file():
                sys.stderr.write(f"Warning: Found no Raven annotations for {wav}\n")
                return {"data": []}
        # TODO: Need to feed audio related configurations to `load`
        wav_samples, wav_sample_rate = load(wav)
        wav_duration = get_duration(wav_samples, sr=wav_sample_rate)
        # Per-sample timestamps; indexed below by absolute sample position.
        wav_times = np.arange(0.0, wav_duration, wav_duration / len(wav_samples))
        if self.annotations:
            annotation_df = pd.read_csv(annotation_file, sep="\t").sort_values(
                by=["begin time (s)"]
            )
            if self.labels:
                annotation_df["class"] = annotation_df["class"].fillna("unknown")
                # Map each class through the `from` -> `to` label corrections.
                # NOTE(review): assumes every class appears in labels.csv;
                # an unmapped class raises IndexError here -- confirm.
                annotation_df["class"] = annotation_df["class"].apply(
                    lambda cls: self.labels_df[self.labels_df["from"] == cls]["to"].values[
                        0
                    ]
                )
        num_segments = ceil(
            (wav_duration - self.overlap) / (self.duration - self.overlap)
        )
        outputs = []
        for idx in range(num_segments):
            if idx == num_segments - 1:
                # Anchor the final segment to the end of the file so it is
                # always a full `duration` long (it may overlap the previous
                # segment by more than `overlap`).
                end = wav_duration
                begin = end - self.duration
            else:
                begin = self.duration * idx - self.overlap * idx
                end = begin + self.duration
            if self.annotations:
                overlaps = annotations_with_overlaps_with_clip(
                    annotation_df, begin, end
                )
            # Hash the (source, window) pair for a collision-resistant name.
            unique_string = f"{wav}-{begin}-{end}"
            destination = f"{self.output_directory}/{get_md5_digest(unique_string)}"
            if self.annotations:
                if overlaps.shape[0] > 0:
                    segment_samples, segment_sample_begin, segment_sample_end = get_segment(
                        begin, end, wav_samples, wav_sample_rate
                    )
                    write_wav(f"{destination}.WAV", segment_samples, wav_sample_rate)
                    if idx == num_segments - 1:
                        to_append = f"{wav},{annotation_file},{wav_times[segment_sample_begin]},{wav_times[-1]},{destination}.WAV"
                    else:
                        to_append = f"{wav},{annotation_file},{wav_times[segment_sample_begin]},{wav_times[segment_sample_end]},{destination}.WAV"
                    to_append += f",{'|'.join(overlaps['class'].unique())}"
                    outputs.append(to_append)
            else:
                segment_samples, segment_sample_begin, segment_sample_end = get_segment(
                    begin, end, wav_samples, wav_sample_rate
                )
                write_wav(f"{destination}.WAV", segment_samples, wav_sample_rate)
                if idx == num_segments - 1:
                    to_append = f"{wav},{wav_times[segment_sample_begin]},{wav_times[-1]},{destination}.WAV"
                else:
                    to_append = f"{wav},{wav_times[segment_sample_begin]},{wav_times[segment_sample_end]},{destination}.WAV"
                outputs.append(to_append)
        return {"data": outputs}

    @staticmethod
    def collate_fn(batch):
        """Flatten the per-item "data" lists into one iterable of csv rows.

        Bug fix: the original definition took the batch in the `self` slot
        (no `self`/`@staticmethod`), so calling it as a bound method raised
        TypeError; `chain` was also never imported in this module.
        """
        from itertools import chain

        return chain.from_iterable([x["data"] for x in batch])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment