## | |
# Dynamic time warping of TTS audio & phoneme textgrids onto T12 audio envelopes. | |
## | |
import os, sys, glob | |
import torch | |
from tqdm import tqdm | |
import numpy as np | |
from textgrids import TextGrid | |
import torch.nn as nn, scipy | |
import torch.nn.functional as F | |
from torch.utils.data import DataLoader | |
import torchmetrics | |
import datetime | |
import zarr | |
from scipy.spatial.distance import cdist | |
import jiwer | |
import sklearn | |
from unidecode import unidecode | |
import matplotlib.pyplot as plt | |
from matplotlib.colors import SymLogNorm | |
import librosa | |
import re, sys, pickle | |
import string | |
from collections import defaultdict | |
import scipy | |
SCRIPT_DIR = os.path.dirname(os.path.dirname(os.getcwd())) | |
sys.path.append(SCRIPT_DIR) | |
from dataloaders import persist_to_file | |
from functools import lru_cache | |
from scipy.io.wavfile import write as write_wav | |
from IPython.display import Audio | |
from align import align_from_distances | |
from helpers import sentence_to_fn | |
from data_utils import read_phonemes, mel_spectrogram | |
from bark import SAMPLE_RATE as TTS_SAMPLE_RATE | |
import os, subprocess | |
# os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "backend:cudaMallocAsync" # no OOM | |
# os.environ["CUDA_VISIBLE_DEVICES"] = "" | |
hostname = subprocess.run("hostname", capture_output=True) | |
ON_SHERLOCK = hostname.stdout[:2] == b"sh" | |
# SCRIPT_DIR = "/home/tyler/code/silent_speech" | |
normalizers_file = os.path.join(SCRIPT_DIR, "normalizers.pkl") | |
mfcc_norm, emg_norm = pickle.load(open(normalizers_file,'rb')) | |
## | |
if ON_SHERLOCK: | |
T12_dir ='/scratch/users/tbenst/T12_data' | |
else: | |
T12_dir = '/data/data/T12_data' | |
datadir = os.path.join(T12_dir, 'competitionData') | |
sentences_dir = os.path.join(T12_dir, 'sentences') | |
TTS_dir = os.path.join(os.path.dirname(datadir), 'synthetic_audio', 'TTS') | |
# format like `t12.2022.04.28.mat` | |
train_files = glob.glob(datadir + '*/train/*') | |
test_files = glob.glob(datadir + '*/test/*') | |
# format like `t12.2022.04.28_sentences.mat` | |
sentences_files = glob.glob(sentences_dir + '/*') | |
# map each competitionData mat file to its corresponding sentences mat file | |
competition_file_mapping = {} | |
days = [] | |
for train_file in train_files + test_files: | |
day = train_file.split('.mat')[0].split('t12.')[1] | |
sentences_file = [sf for sf in sentences_files if day in sf][0] | |
competition_file_mapping[train_file] = sentences_file | |
days.append(day) | |
unique = np.unique(days) | |
session_mapping = dict(zip(unique, np.arange(len(unique)))) | |
print('Unique days:', len(session_mapping.keys())) | |
# INFO: these days are missing: | |
# - 2022.08.16 | |
# - 2022.09.06 | |
# - 2022.09.08 | |
# - 2022.09.13 | |
## | |
# load all mat files into dict of filename: mat | |
mat_files = {} | |
for f in tqdm(sorted(train_files + test_files + sentences_files)): | |
mat_files[f] = scipy.io.loadmat(f) | |
## | |
def get_competition_to_sentence_mapping_per_file(mat_files, train_files, test_files, competition_file_mapping): | |
"Map competitionData files to sentences files." | |
sentence_mapping_per_file = {} | |
for tf in tqdm((train_files + test_files)): | |
mat = mat_files[tf] | |
mat2 = mat_files[competition_file_mapping[tf]] | |
matching_indices = [] | |
for i, sentence in enumerate(mat['sentenceText']): | |
sentence = sentence.rstrip() # strip whitespace at end | |
matching_sentence = None | |
prev_match_idx = matching_indices[-1] if len(matching_indices) > 0 else 0 | |
# start at previous matching index to avoid matching the same sentence twice | |
for j, sentence2 in enumerate(mat2['sentences'][prev_match_idx:], start=prev_match_idx): | |
sentence2 = str(sentence2[0][0]).rstrip() # strip whitespace at end | |
if sentence == sentence2: | |
matching_sentence = sentence2 | |
matching_indices.append(j) | |
break | |
if matching_sentence is None: | |
raise Exception(f"No matching sentence found for sentence {i} in mat.\n{sentence}\n{sentence2}") | |
if len(matching_indices) != len(set(matching_indices)): | |
print(f"Error: There are {len(matching_indices) - len(set(matching_indices))} identical indices in matching_indices.") | |
# find identical index | |
for i in range(len(matching_indices)): | |
for j in range(len(matching_indices)): | |
if i != j and matching_indices[i] == matching_indices[j]: | |
print(f"Identical index ({matching_indices[i]}) at: {i} and {j}") | |
assert len(matching_indices) == len(set(matching_indices)), "There are identical indices in matching_indices." | |
sentence_mapping_per_file[tf] = matching_indices | |
return sentence_mapping_per_file | |
# map indices in competitionData mat files to indices in sentences mat files | |
# e.g. | |
# mat = '/data/data/T12_data/competitionData/test/t12.2022.07.21.mat' | |
# competition_file_mapping[mat] == '/data/data/T12_data/sentences/t12.2022.07.21_sentences.mat' | |
# sentence_mapping_per_file[mat][0] == 400 | |
# means that the first sentence in the competitionData mat file corresponds to the 400th sentence in the sentences mat file | |
competition_to_sentence_mapping_per_file = get_competition_to_sentence_mapping_per_file( | |
mat_files, train_files, test_files, competition_file_mapping) | |
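##
# Quick spot check of the index mapping built above (illustrative only; the file picked
# here is arbitrary). Each entry maps a competitionData trial index to the corresponding
# trial index in the matching sentences file.
example_comp_file = sorted(competition_to_sentence_mapping_per_file.keys())[0]
print(f"{example_comp_file} -> {competition_file_mapping[example_comp_file]}")
print("first 5 sentence indices:", competition_to_sentence_mapping_per_file[example_comp_file][:5])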
## | |
# create dictionary with sentence as key, and list of tuples (file, index) as value | |
# only use competitionData mat files | |
# e.g. | |
# sentence_mapping['hello world'] == [('/data/data/T12_data/competitionData/test/t12.2022.07.21.mat', 400)] | |
def create_sentence_mapping(mat_files, competition_to_sentence_mapping_per_file, competition_file_mapping): | |
sentence_mapping = {} | |
for k, v in tqdm(competition_to_sentence_mapping_per_file.items()): | |
mat = mat_files[k] | |
mat2 = mat_files[competition_file_mapping[k]] | |
for i, idx in enumerate(v): | |
sentence = mat['sentenceText'][i] | |
sentence = sentence.rstrip() # strip whitespace at end | |
# sentence = unidecode(sentence) # remove accents | |
# sentence = re.sub(r'[^a-zA-Z0-9\s]', '', sentence) # remove punctuation | |
# sentence = sentence.lower() | |
if sentence not in sentence_mapping: | |
sentence_mapping[sentence] = [] | |
sentence_mapping[sentence].append((k, idx)) | |
return sentence_mapping | |
sentence_mapping = create_sentence_mapping(mat_files, competition_to_sentence_mapping_per_file, competition_file_mapping) | |
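##
# Hedged example of the resulting structure (the key below is whichever sentence happens
# to come first; the actual contents depend on the local data).
example_sentence = next(iter(sentence_mapping))
print(f"{example_sentence!r} -> {sentence_mapping[example_sentence]}")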
## | |
repeated_utterances = {k: v for k, v in sentence_mapping.items() if len(v) > 1} | |
## | |
# speaking mode of each mat file | |
speaking_modes = {} | |
for f, mat in tqdm(mat_files.items()): | |
if not 'competitionData' in f: | |
mode = mat['speakingMode'][0] | |
if mode == 'attempted nonvocal speaking': | |
speaking_modes[f] = "silent" | |
elif mode == 'attempted speaking': | |
speaking_modes[f] = "vocalized" | |
else: | |
raise Exception(f"Unknown speakingMode: {mode} from file {f}") | |
for comp_file, sentence_file in tqdm(competition_file_mapping.items()): | |
speaking_modes[comp_file] = speaking_modes[sentence_file] | |
speaking_modes | |
## | |
# what fraction of silent sentences are also vocalized? | |
silent_count = 0 | |
vocalized_count = 0 | |
parallel_count = 0 | |
for sentence, files in tqdm(sentence_mapping.items()): | |
is_silent = False | |
is_vocalized = False | |
for f, idx in files: | |
if speaking_modes[f] == "silent": | |
is_silent = True | |
elif speaking_modes[f] == "vocalized": | |
is_vocalized = True | |
else: | |
raise Exception(f"Unknown speakingMode: {speaking_modes[f]} from file {f}") | |
if is_silent and is_vocalized: | |
parallel_count += 1 | |
elif is_silent: | |
silent_count += 1 | |
elif is_vocalized: | |
vocalized_count += 1 | |
print(f"{silent_count=}, {vocalized_count=}, {parallel_count=}") | |
print(f"Percent of silent sentences that are also vocalized: {parallel_count / silent_count* 100:.2f}%") | |
print(f"Number of repeated utterances: {len(repeated_utterances)}") | |
## | |
# number of sentences in train and test | |
print(f"Number of utterances in train: {sum([len(v) for k, v in competition_to_sentence_mapping_per_file.items() if 'competitionData/train' in k])}") | |
print(f"Number of utterances in test: {sum([len(v) for k, v in competition_to_sentence_mapping_per_file.items() if 'competitionData/test' in k])}") | |
num_utterances_in_both = 0 | |
for k,v in repeated_utterances.items(): | |
has_train = False | |
has_test = False | |
for f, idx in v: | |
if 'competitionData/train' in f: | |
has_train = True | |
elif 'competitionData/test' in f: | |
has_test = True | |
if has_train and has_test: | |
num_utterances_in_both += 1 | |
print(f"Number of sentences in both train & test: {num_utterances_in_both}") | |
## | |
# align competitionData neural data to sentences neural data | |
def window_middle_signal(signal): | |
"""Extract the middle 50% of a signal.""" | |
length = len(signal) | |
start_idx = int(0.25 * length) | |
end_idx = int(0.75 * length) | |
return signal[start_idx:end_idx] | |
def compute_offset_1d(signal1, reference_signal): | |
# Window the signal to get the middle 50% | |
# windowed_signal1 = window_middle_signal(signal1) | |
# Compute 1D cross-correlation | |
# cross_corr_1d = np.correlate(reference_signal, windowed_signal1, mode='full') | |
cross_corr_1d = np.correlate(reference_signal, signal1, mode='full') | |
# Get the index of the peak of the 1D cross-correlation | |
idx_peak = np.argmax(cross_corr_1d) | |
print(f"{idx_peak}") | |
# Compute the offset. The offset represents how much to shift signal1 | |
# to align it with the reference signal. | |
offset = (len(signal1) - 1) - idx_peak | |
return offset | |
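# Minimal synthetic check of compute_offset_1d (values below are made up): a copy of a
# random reference signal delayed by 7 samples should recover an offset of 7.
_rng = np.random.default_rng(0)
_ref = _rng.standard_normal(500)
_delayed = np.concatenate([np.zeros(7), _ref])[:500]
assert compute_offset_1d(_delayed, _ref) == 7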
# for mat_file, sentenceIdxs in tqdm(list(competition_to_sentence_mapping_per_file.items())[10:]): | |
for mat_file, sentenceIdxs in tqdm(list(competition_to_sentence_mapping_per_file.items())[0:]): | |
mat = mat_files[mat_file] | |
sentence_file = competition_file_mapping[mat_file] | |
sentence_mat = mat_files[sentence_file] | |
comp_mat = mat_files[mat_file] | |
go_cues = sentence_mat['goTrialEpochs'] | |
for compIdx, sentenceIdx in enumerate(sentenceIdxs): | |
# TODO: why is this failing..? | |
try: | |
sentence_dat = sentence_mat['spikePow'][go_cues[sentenceIdx,0]:go_cues[sentenceIdx,1]] | |
comp_dat = comp_mat['spikePow'][0,compIdx] | |
assert np.all(np.isclose(sentence_dat, comp_dat)) | |
except Exception as e: | |
print(f"Neural data does not match for sentence {sentenceIdx} in file {sentence_file} and compIdx {compIdx} in file {mat_file}") | |
offset = compute_offset_1d(sentence_dat.mean(axis=1), comp_dat.mean(axis=1)) | |
# raise e | |
sentence_dat = sentence_dat[offset:] | |
# known issue sadly, two sentences can't be fixed like this | |
# assert np.all(np.isclose(sentence_dat, comp_dat)), f"still doesn't match with offset {offset}" | |
# print(f"fixed alignment to {mat_file} with offset {offset}") | |
# break | |
# go_cue = go_cues[idx] | |
# neural_data = mat['spikePow'] | |
# offset = 23 | |
## | |
# show the bad alignment / skipped samples | |
plt.plot(comp_dat.mean(axis=1)[0:]) | |
# plt.plot(sentence_dat.mean(axis=1)[offset:]) | |
plt.plot(sentence_dat.mean(axis=1)[0:]) | |
plt.legend(["competitionData", "sentences"]) | |
plt.title(f"Neural data for sentence {sentenceIdx} in sentences and sentence {compIdx} in file\n{mat_file}") | |
plt.ylabel("mean spike power") | |
plt.xlabel("timestep (20ms)") | |
## | |
def compute_audio_envelope(audio, sample_rate=44100, frame_size_ms=20, hop_size_ms=20): | |
""" | |
Compute the audio envelope using the FFT magnitude spectrum. | |
Parameters: | |
- audio: 1D numpy array of audio samples. | |
- sample_rate: Sampling rate of the audio (default is 44.1kHz). | |
- frame_size_ms: Size of the frames in milliseconds for FFT computation (default is 20ms). | |
- hop_size_ms: Step size between frames in milliseconds (default is 20 ms).
Returns: | |
- envelope: Audio envelope as a numpy array. | |
""" | |
# Calculate frame size and hop size in samples | |
frame_size_samples = int(sample_rate * frame_size_ms / 1000) | |
hop_size_samples = int(sample_rate * hop_size_ms / 1000) | |
# Apply a window function (Hann window) | |
window = np.hanning(frame_size_samples) | |
# Compute the number of frames | |
num_frames = 1 + (len(audio) - frame_size_samples) // hop_size_samples | |
# Compute the envelope | |
envelope = np.zeros(num_frames) | |
for i in range(num_frames): | |
start_idx = i * hop_size_samples | |
end_idx = start_idx + frame_size_samples | |
frame = audio[start_idx:end_idx] * window | |
# Compute magnitude spectrum of the frame | |
magnitude_spectrum = np.abs(np.fft.rfft(frame)) | |
envelope[i] = np.mean(magnitude_spectrum) | |
return envelope | |
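# Synthetic sanity check (illustrative): one second of a ramped 440 Hz tone at 30 kHz with
# 20 ms frames and a 20 ms hop should yield 50 envelope values that grow over time.
_t = np.linspace(0, 1, 30000, endpoint=False)
_tone = np.linspace(0, 1, 30000) * np.sin(2 * np.pi * 440 * _t)
_env = compute_audio_envelope(_tone, sample_rate=30000, frame_size_ms=20, hop_size_ms=20)
assert _env.shape == (50,) and _env[-1] > _env[0]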
i = 0 | |
for mat_file, sentenceIdxs in tqdm(competition_to_sentence_mapping_per_file.items()): | |
mat = mat_files[mat_file] | |
sentence_file = competition_file_mapping[mat_file] | |
print(sentence_file) | |
sentence_mat = mat_files[sentence_file] | |
i += 1 | |
if i == 21: | |
break | |
blockIdx = 0 | |
aud = np.squeeze(sentence_mat['audio'][0,blockIdx]) | |
aud = compute_audio_envelope(aud, sample_rate=30000, frame_size_ms=20, hop_size_ms=20) | |
fig, axs = plt.subplots(2, 1, figsize=(10, 10)) | |
axs[0].plot(sentence_mat['audioEnvelope'][:10000]) | |
axs[0].set_title('Reference Audio Envelope') | |
axs[1].plot(aud[:10000]) | |
axs[1].set_title('Computed Audio Envelope') | |
fig.suptitle(sentence_file) | |
plt.show() | |
## | |
S = 300 | |
N = 450 | |
fig, axs = plt.subplots(2, 1, figsize=(10, 10)) | |
axs[0].plot(sentence_mat['audioEnvelope'][S:N]) | |
axs[0].set_title('Reference Audio Envelope') | |
axs[1].plot(aud[S:N]) | |
axs[1].set_title('Computed Audio Envelope') | |
plt.show() | |
## | |
S = 0 | |
S = 220 | |
N = 400 | |
# N = 10000 | |
# num_mels = 42 | |
num_mels = 80 | |
mspec = mel_spectrogram(torch.tensor(sentence_mat['audio'][0,blockIdx], dtype=torch.float32), | |
# n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax | |
2048, num_mels, 30000, 30000//50, 30000//25, 0, 8000, center=False) | |
mspec = mspec.squeeze().log2().numpy() | |
# mspec = mfcc_norm.normalize(mspec.T).T # kinda weird | |
fig, axs = plt.subplots(2, 1, figsize=(10, 10)) | |
axs[0].imshow(sentence_mat['audioFeatures'][S:N].T, aspect='auto', origin='lower') | |
axs[0].set_title('Audio Features') | |
axs[0].set_ylabel("Mel Frequency Cepstral Coefficients") | |
axs[1].imshow(mspec[:,S:N], aspect='auto', origin='lower') | |
axs[1].set_title('Computed Mel Spectrogram')
axs[1].set_ylabel("Mel frequency bin")
axs[1].set_xlabel("# frames (20ms each)") | |
plt.show() | |
## | |
# concatenate all audio arrays | |
# audio = np.concatenate([np.squeeze(sentence_mat['audio'][0,blockIdx]) for blockIdx in range(len(sentence_mat['audio'][0]))]) | |
# audio_envelope = compute_audio_envelope(audio, sample_rate=30000, frame_size_ms=20) | |
# mspec = mel_spectrogram(torch.tensor(audio[None], dtype=torch.float32), | |
# # n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax | |
# 2048, 42, 30000, 30000//50, 30000//50, 0, 8000, center=False).squeeze() | |
def load_TTS_data(sentence, directory=TTS_dir, ms_per_frame=10): | |
"""Load corresponding TTS audio and phoneme labels for a given sentence.""" | |
tts_audio_path = sentence_to_fn(sentence, directory, ext=".wav") | |
textgrid = sentence_to_fn(sentence, directory, ext=".TextGrid") | |
# even though the neural data is binned at 20 ms, run the alignment at 10 ms bins
phonemes = read_phonemes(textgrid, ms_per_frame=ms_per_frame) | |
audio, sample_rate = librosa.load(tts_audio_path) | |
return audio, phonemes, sample_rate | |
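# Hedged usage example: the sentence below is arbitrary and its TTS .wav/.TextGrid may not
# exist locally, so failures are caught and skipped rather than treated as errors.
try:
    _tts_aud, _tts_phones, _tts_sr = load_TTS_data("I am a little tired today.")
    print(f"TTS audio: {_tts_aud.shape} @ {_tts_sr} Hz, phoneme frames: {_tts_phones.shape}")
except Exception as _e:
    print(f"TTS example unavailable ({_e}); continuing.")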
# TODO: get phoneme labels for each timestep T of competitionData | |
# Main loop for processing sentences | |
# for each vocalized sentence, get audio spectrogram | |
# read parallel TTS audio + phoneme labels | |
# Run DTW on spectrograms of each vocalized sentence (audio) to TTS audio | |
# accumulate phoneme labels for each timestep of neural data | |
def go_cue_to_block_and_rel_index(go_cue_idx, block_start_idxs): | |
"Given a go_cue index, return the corresponding block index and relative index." | |
block_idx = np.where(block_start_idxs <= go_cue_idx)[0][-1] | |
rel_idx = go_cue_idx - block_start_idxs[block_idx] | |
return block_idx, rel_idx | |
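# Tiny synthetic check (block boundaries below are made up): an index maps to the block that
# starts at or before it, with the offset relative to that block's first timestep.
_starts = np.array([0, 100, 250])
assert go_cue_to_block_and_rel_index(30, _starts) == (0, 30)
assert go_cue_to_block_and_rel_index(120, _starts) == (1, 20)
assert go_cue_to_block_and_rel_index(250, _starts) == (2, 0)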
# compare to audio envelope to check if correct | |
idx = 28 | |
block_start_idxs = np.concatenate([[0], 1 + np.where(np.diff(sentence_mat['blockNum'][:,0]))[0]]) | |
go_cue_idx = sentence_mat['goTrialEpochs'][idx,0] | |
block_idx, relIdx = go_cue_to_block_and_rel_index(go_cue_idx, block_start_idxs) | |
plt.plot(sentence_mat['audioEnvelope'][go_cue_idx:go_cue_idx+150]) | |
aud = sentence_mat['audio'][0,block_idx][0] | |
aud = compute_audio_envelope(aud, sample_rate=30000, frame_size_ms=20) | |
aud = aud[relIdx:relIdx+150] | |
print(aud.shape) | |
plt.plot(aud) | |
# not matching anymore, prob just deleted some things..? | |
plt.legend(["audioEnvelope reference", "audio check"]) | |
## | |
# check audio length | |
for mat_file in sentences_files: | |
sentence_mat = mat_files[mat_file] | |
T = sentence_mat['spikePow'].shape[0] | |
neural_seconds = T * 20 / 1000 | |
nAudio = np.sum([m[0].shape[0] for m in sentence_mat['audio'][0]]) | |
audio_seconds = nAudio / 30000 | |
if np.abs(neural_seconds - audio_seconds) > 0.5: | |
print(f"==== {mat_file} ====") | |
print(f"{T=}\n{neural_seconds=}") | |
print(f"{nAudio=}\n{audio_seconds=}") | |
break | |
# sentences/t12.2022.06.28_sentences.mat audio block 5 has length of zero | |
## | |
########################## THIS IS THE KEY FUNCTION ######################################## | |
# save a npz file | |
# iterate sentences mat files | |
ms_per_frame = 20 | |
nframes_per_sec = 1000 // ms_per_frame | |
mat_sentences = {} | |
mat_block = {} | |
mat_mspecs = {} | |
mat_tts_mspecs = {} | |
mat_tts_phonemes = {} | |
mat_aligned_mspecs = {} | |
mat_aligned_phonemes = {} | |
mat_spikePow = {} | |
mat_tx1 = {} | |
mat_tx2 = {} | |
mat_tx3 = {} | |
mat_tx4 = {} | |
mat_speakingMode = {} | |
mat_audioEnvelope = {} | |
mat_dataset_partition = {} | |
total_T = 0 | |
n_sentences = 0 | |
# for mat_file in sentences_files: | |
for mat_file in competition_file_mapping.keys(): | |
comp_mat = mat_files[mat_file] | |
# total_T += sentence_mat['spikePow'].shape[0] | |
total_T += np.sum(np.array([m.shape[0] for m in comp_mat['spikePow'][0]])) | |
# n_sentences += len(sentence_mat['sentences']) | |
n_sentences += len(comp_mat['sentenceText']) | |
# npz = np.load('/data/data/T12_data/synthetic_audio/2023-08-20_T12_dataset.npz', | |
# allow_pickle=True) | |
# np.sum([n.shape[0] for n in npz['spikePow']]) / len(npz['spikePow']) | |
# 311.06 | |
# rolling window of ~20 sentences.
# The average sentence is 312 timesteps during the go cue and 555.8 timesteps in total.
# The 20-sentence rolling z-score idea is from Willett et al. 2023;
# this implementation differs slightly in that the window is constant in time steps rather than in sentences.
window_size = int(np.ceil(total_T/n_sentences)) * 20 | |
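##
# For reference, a minimal sketch of the trial-based variant described in Willett et al. 2023
# (z-score each trial against its trailing ~20 trials). This helper is illustrative only and
# is not used below; the script instead uses a window that is constant in time steps
# (moving_mean / moving_std / moving_zscore defined next).
def zscore_against_trailing_trials(trial_feats, history, n_trials=20, eps=1e-6):
    "trial_feats: T x N array for one trial; history: list of earlier trials' T x N arrays."
    recent = history[-n_trials:] if len(history) > 0 else [trial_feats]
    stacked = np.concatenate(recent, axis=0)
    return (trial_feats - stacked.mean(axis=0)) / (stacked.std(axis=0) + eps)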
## | |
def moving_mean(x, window): | |
"For T x N matrix, compute the rolling mean over window timesteps." | |
x_mean = torch.mean(x.unfold(0,window,1), dim=2) | |
# print(f"{x_mean.shape=}") | |
# use first mean for first window-1 timesteps | |
# technically acausal for first window-1 timesteps | |
x_new = torch.zeros_like(x) | |
x_new[:window-1] = x_mean[0] | |
x_new[window-1:] = x_mean | |
return x_new | |
def moving_std(x, window): | |
"For T x N matrix, compute the rolling mean over window timesteps." | |
x_std = torch.std(x.unfold(0,window,1), dim=2) | |
# use first mean for first window-1 timesteps | |
# technically acausal for first window-1 timesteps | |
x_new = torch.zeros_like(x) | |
x_new[:window-1] = x_std[0] | |
x_new[window-1:] = x_std | |
return x_new | |
def moving_zscore(x, window, eps=1e-6): | |
"For T x N matrix, compute the rolling z-score over window timesteps." | |
x_mean = moving_mean(x, window) | |
x_std = moving_std(x, window) | |
zscored = (x - x_mean) / (x_std + eps) | |
return zscored | |
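# Quick CPU-only check on toy data: once the window is full, moving_mean should equal the
# trailing mean computed by hand, and earlier rows reuse the first full-window value.
_x = torch.arange(12, dtype=torch.float32).reshape(6, 2)
_m = moving_mean(_x, 3)
assert torch.allclose(_m[2], _x[:3].mean(dim=0))
assert torch.allclose(_m[5], _x[3:].mean(dim=0))
assert torch.allclose(_m[0], _m[2])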
# sentence_mat = mat_files[sentences_files[0]] | |
sentence_mat = mat_files[sentences_files[-1]] | |
# movet = moving_mean(sentence_mat['spikePow'], window_size) | |
s,e = sentence_mat['goTrialEpochs'][0] | |
# sp = sentence_mat['spikePow'][:window_size] | |
sp = sentence_mat['tx1'][:window_size] | |
spc = torch.from_numpy(sp).float().cuda() | |
# fig, axs = plt.subplots(4, 1, figsize=(10, 10)) | |
# axs[0].plot(sp[:,0]) | |
# axs[1].plot(spc[:,0].cpu().numpy()) | |
# axs[2].plot(spc.unfold(0,500,1)[0,0].cpu().numpy()) | |
# axs[3].plot(spc[:,0].cpu().numpy()) | |
spm = moving_mean(spc, window_size).cpu().numpy() | |
spz = moving_zscore(spc, window_size,1).cpu().numpy() | |
spm = moving_std(spc, window_size).cpu().numpy() | |
spg = moving_zscore( | |
torch.from_numpy(scipy.ndimage.gaussian_filter1d(sp.astype(np.float32),2,axis=0)).cuda(), | |
window_size,1).cpu().numpy() | |
fig, axs = plt.subplots(4, 1, figsize=(10, 10)) | |
axs[0].plot(sp[:,0]) | |
axs[0].set_title("original") | |
axs[1].plot((sp[:,0]-np.mean(sp[:,0]))/sp[:,0].std()) | |
axs[1].set_title("proper zscore") | |
axs[2].plot(spz[:,0]) | |
axs[2].set_title("moving zscore") | |
axs[3].plot(spg[:,0]) | |
axs[3].set_title("gaussian smoothing") | |
## | |
# plt.imshow(sp[s:e]) | |
# plt.imshow((spikePow[s:e]-sp.mean(axis=0)[None])/sp.std(axis=0)[None]) | |
# plt.colorbar() | |
## | |
saw_bad_audio = False | |
max_bad_sentence_count = 4 # 4 known bad sentences | |
bad_sentence_count = 0 | |
# for mat_file in tqdm(sentences_files): | |
# for mat_file in sentences_files: | |
for comp_file, sentence_file in list(competition_file_mapping.items()): | |
comp_mat = mat_files[comp_file] | |
sentence_mat = mat_files[sentence_file] | |
is_train = '/competitionData/train/' in comp_file | |
is_val = '/competitionData/test/' in comp_file | |
assert is_train or is_val | |
assert not (is_train and is_val) | |
mat_name = os.path.split(comp_file)[-1] | |
block_start_idxs = np.concatenate([ | |
[0], | |
1 + np.where(np.diff(sentence_mat['blockNum'].squeeze()))[0] | |
]) | |
block_end_idxs = np.concatenate([ | |
block_start_idxs[1:], | |
[len(sentence_mat['blockNum'].squeeze())] | |
]) | |
# last_block_idx = sentence_mat['blockNum'][:,0][-1] # test set (doesn't always start at 1 / skips numbers) | |
last_block_idx = len(sentence_mat['blockList']) - 1 # last block is test set | |
audio_block = [] | |
for i in range(len(sentence_mat['audio'][0])): | |
aud = sentence_mat['audio'][0,i][0] | |
if aud.shape[0] == 0: | |
bidx = sentence_mat['blockList'].squeeze()[i] | |
print(f"WARNING: got bad audio for block {bidx} (idx {i}) in {sentence_file}") | |
# assert not saw_bad_audio, "there should only be one..." | |
saw_bad_audio = True | |
audio_block.append(None) | |
else: | |
aud = librosa.util.buf_to_float(aud) | |
aud = np.clip(aud, -1, 1) | |
audio_block.append(aud) | |
# audio_block = [librosa.util.buf_to_float(sentence_mat['audio'][0,i][0]) for i in range(len(sentence_mat['audio'][0]))] | |
# volume_block = [compute_audio_envelope(aud, sample_rate=30000, frame_size_ms=20) for aud in audio_block] | |
volume_block = [] | |
for aud in audio_block: | |
if aud is None: | |
volume_block.append(None) | |
else: | |
volume_block.append(compute_audio_envelope(aud, sample_rate=30000, frame_size_ms=20)) | |
# for a in audio_block: | |
# print(f"AB min: {np.min(a)}, max: {np.max(a)}") | |
mspec_block = [] | |
for aud in audio_block: | |
if aud is None: | |
mspec_block.append(None) | |
else: | |
au = torch.tensor(aud[None], dtype=torch.float32).cuda().clip(-1,1) | |
mspec_block.append(mel_spectrogram(au, | |
# n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax | |
2048, 80, 30000, 30000//nframes_per_sec, 30000//(nframes_per_sec//2), 0, 8000, center=False).squeeze().T) | |
# mspec_block = [mel_spectrogram(torch.tensor(aud[None], dtype=torch.float32).cuda(), | |
# # n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax | |
# 2048, 80, 30000, 30000//nframes_per_sec, 30000//(nframes_per_sec//2), 0, 8000, center=False).squeeze() | |
# for aud in audio_block] | |
if speaking_modes[comp_file] == "silent": | |
mat_speakingMode[mat_name] = "silent" | |
elif speaking_modes[comp_file] == "vocalized": | |
mat_speakingMode[mat_name] = "vocalized" | |
else: | |
raise Exception(f"Unknown speakingMode: {speaking_modes[comp_file]} from file {comp_file}") | |
assert len(audio_block) < 100 | |
# always append | |
sentences = [] | |
spikePow = [] | |
tx1 = [] | |
tx2 = [] | |
tx3 = [] | |
tx4 = [] | |
dataset_partition = [] | |
block = [] | |
# try to append | |
mspecs = [] | |
tts_mspecs = [] | |
tts_phonemes = [] | |
aligned_mspecs = [] | |
aligned_phonemes = [] | |
audioEnvelope = [] | |
session_spikePow = scipy.ndimage.gaussian_filter1d(sentence_mat['spikePow'],2,axis=0) | |
session_tx1 = scipy.ndimage.gaussian_filter1d(sentence_mat['tx1'],2,axis=0) | |
session_tx2 = scipy.ndimage.gaussian_filter1d(sentence_mat['tx2'],2,axis=0) | |
session_tx3 = scipy.ndimage.gaussian_filter1d(sentence_mat['tx3'],2,axis=0) | |
session_tx4 = scipy.ndimage.gaussian_filter1d(sentence_mat['tx4'],2,axis=0) | |
session_spikePow = torch.from_numpy(session_spikePow).float().cuda() | |
session_tx1 = torch.from_numpy(session_tx1).float().cuda() | |
session_tx2 = torch.from_numpy(session_tx2).float().cuda() | |
session_tx3 = torch.from_numpy(session_tx3).float().cuda() | |
session_tx4 = torch.from_numpy(session_tx4).float().cuda() | |
spikePow_mean = moving_mean(session_spikePow, window_size) | |
tx1_mean = moving_mean(session_tx1, window_size) | |
tx2_mean = moving_mean(session_tx2, window_size) | |
tx3_mean = moving_mean(session_tx3, window_size) | |
tx4_mean = moving_mean(session_tx4, window_size) | |
spikePow_std = moving_std(session_spikePow, window_size) + 1 | |
tx1_std = moving_std(session_tx1, window_size) + 1 | |
tx2_std = moving_std(session_tx2, window_size) + 1 | |
tx3_std = moving_std(session_tx3, window_size) + 1 | |
tx4_std = moving_std(session_tx4, window_size) + 1 | |
# session_spikePow = session_spikePow.cpu().numpy() | |
# session_tx1 = session_tx1.cpu().numpy() | |
# session_tx2 = session_tx2.cpu().numpy() | |
# session_tx3 = session_tx3.cpu().numpy() | |
# session_tx4 = session_tx4.cpu().numpy() | |
# TODO: try taking sqrt before zscore | |
for compIdx, sentenceIdx in enumerate(tqdm(competition_to_sentence_mapping_per_file[comp_file])): | |
sentence = sentence_mat['sentences'][sentenceIdx][0][0] | |
sentence = sentence.rstrip() | |
go_cue_idx = sentence_mat['goTrialEpochs'][sentenceIdx] | |
sentence_spikePow = session_spikePow[go_cue_idx[0]:go_cue_idx[1]] | |
# sanity check alignment | |
sentence_dat = sentence_mat['spikePow'][go_cue_idx[0]:go_cue_idx[1]] | |
comp_dat = comp_mat['spikePow'][0,compIdx] | |
try: | |
assert np.all(np.isclose(sentence_dat, comp_dat)) # also errors if not same dim, hence try block | |
except Exception: | |
bad_sentence_count += 1 | |
if bad_sentence_count > max_bad_sentence_count: | |
raise Exception(f"Too many bad un-aligned sentences between {comp_file} and {sentence_file}") | |
sentence_tx1 = session_tx1[go_cue_idx[0]:go_cue_idx[1]] | |
sentence_tx2 = session_tx2[go_cue_idx[0]:go_cue_idx[1]] | |
sentence_tx3 = session_tx3[go_cue_idx[0]:go_cue_idx[1]] | |
sentence_tx4 = session_tx4[go_cue_idx[0]:go_cue_idx[1]] | |
wi = go_cue_idx[1] - 1 | |
sentence_spikePow = (sentence_spikePow - spikePow_mean[wi]) / spikePow_std[wi] | |
sentence_tx1 = (sentence_tx1 - tx1_mean[wi]) / tx1_std[wi] | |
sentence_tx2 = (sentence_tx2 - tx2_mean[wi]) / tx2_std[wi] | |
sentence_tx3 = (sentence_tx3 - tx3_mean[wi]) / tx3_std[wi] | |
sentence_tx4 = (sentence_tx4 - tx4_mean[wi]) / tx4_std[wi] | |
sentence_spikePow = sentence_spikePow.cpu().numpy() | |
sentence_tx1 = sentence_tx1.cpu().numpy() | |
sentence_tx2 = sentence_tx2.cpu().numpy() | |
sentence_tx3 = sentence_tx3.cpu().numpy() | |
sentence_tx4 = sentence_tx4.cpu().numpy() | |
sentences.append(sentence) | |
spikePow.append(sentence_spikePow) | |
tx1.append(sentence_tx1) | |
tx2.append(sentence_tx2) | |
tx3.append(sentence_tx3) | |
tx4.append(sentence_tx4) | |
block.append(sentence_mat['blockNum'][go_cue_idx[0]]) | |
# note block_idx here starts at 0, but doesn't in sentence_mat['blockList'] | |
block_idx, startIdx = go_cue_to_block_and_rel_index(go_cue_idx[0], block_start_idxs) | |
block_idx2, stopIdx = go_cue_to_block_and_rel_index(go_cue_idx[1], block_start_idxs) | |
if stopIdx == 0: | |
block_idx2, stopIdx = go_cue_to_block_and_rel_index(go_cue_idx[1]-1, block_start_idxs) | |
stopIdx += 1 | |
assert block_idx == block_idx2 | |
if block_idx == last_block_idx: | |
assert is_val | |
dataset_partition.append("val") | |
else: | |
try: | |
assert is_train | |
dataset_partition.append("train") | |
except AssertionError: | |
print(f"{comp_file} should be train data, but we will use it as val data to match") | |
dataset_partition.append("val") | |
try: | |
tts_audio, tts_phones, sample_rate = load_TTS_data(sentence, ms_per_frame=ms_per_frame) | |
except FileNotFoundError: | |
print("Skipping as could not read file (prob TextGrid) for sentence: ", sentence) | |
mspecs.append(None) | |
tts_mspecs.append(None) | |
tts_phonemes.append(None) | |
aligned_mspecs.append(None) | |
aligned_phonemes.append(None) | |
audioEnvelope.append(None) | |
continue | |
# print(f"TTS min audio: {np.min(tts_audio)}, max audio: {np.max(tts_audio)}") | |
tts_volume = compute_audio_envelope(tts_audio, sample_rate=sample_rate, frame_size_ms=20) | |
tts_au = torch.tensor(tts_audio, dtype=torch.float32).cuda()[None].clip(-1,1) | |
tts_mspec = mel_spectrogram(tts_au, | |
2048, 80, sample_rate, sample_rate//nframes_per_sec, sample_rate//(nframes_per_sec//2), 0, 8000, center=False).squeeze().T | |
tts_mspecs.append(tts_mspec.cpu().numpy()) | |
tts_phonemes.append(tts_phones) | |
if mspec_block[block_idx] is None: | |
# we're missing audio data | |
audioEnvelope.append(None) | |
mspecs.append(None) | |
aligned_mspecs.append(None) | |
aligned_phonemes.append(None) | |
continue | |
t12_mspec = mspec_block[block_idx][startIdx:stopIdx] | |
t12_volume = volume_block[block_idx][startIdx:stopIdx] | |
if speaking_modes[comp_file] == "vocalized": | |
# finally, run dynamic time warping between t12_mspec and tts_mspec | |
# good! | |
# dists = cdist(t12_mspec.T, tts_mspec.T) | |
dists = torch.cdist(t12_mspec, tts_mspec) | |
# bad... | |
# dists = 1 - torchmetrics.functional.pairwise_cosine_similarity(t12_mspec.T, tts_mspec.T).cpu().numpy() | |
# okay... | |
# dists = cdist(t12_volume[None].T, tts_volume[None].T) | |
alignment = align_from_distances(dists.cpu().numpy()) | |
audioEnvelope.append(t12_volume) | |
mspecs.append(t12_mspec.cpu().numpy()) | |
aligned_mspecs.append(tts_mspec[alignment].cpu().numpy()) | |
aligned_phonemes.append(tts_phones[alignment]) | |
else: | |
audioEnvelope.append(None) | |
mspecs.append(None) | |
aligned_mspecs.append(None) | |
aligned_phonemes.append(None) | |
# raise Exception("stop here") | |
mat_sentences[mat_name] = sentences | |
mat_block[mat_name] = block | |
mat_mspecs[mat_name] = mspecs | |
mat_tts_mspecs[mat_name] = tts_mspecs | |
mat_tts_phonemes[mat_name] = tts_phonemes | |
mat_aligned_mspecs[mat_name] = aligned_mspecs | |
mat_aligned_phonemes[mat_name] = aligned_phonemes | |
mat_spikePow[mat_name] = spikePow | |
mat_tx1[mat_name] = tx1 | |
mat_tx2[mat_name] = tx2 | |
mat_tx3[mat_name] = tx3 | |
mat_tx4[mat_name] = tx4 | |
mat_audioEnvelope[mat_name] = audioEnvelope | |
mat_dataset_partition[mat_name] = dataset_partition | |
## | |
for mtest in sorted(list(filter(lambda x: "test" in x, competition_to_sentence_mapping_per_file.keys()))): | |
stest = competition_file_mapping[mtest] | |
mat_test = mat_files[mtest] | |
sen_test = mat_files[stest] | |
fn = os.path.join(*mtest.split('/')[-2:]) | |
print(f"===== {fn} =====") | |
print("Second to last block: ", sen_test['blockTypes'][-2].squeeze()) | |
print("Last block: ", sen_test['blockTypes'][-1].squeeze()) | |
## | |
sentence_file_mapping = defaultdict(list) | |
for k,v in competition_file_mapping.items(): | |
sentence_file_mapping[v].append(k) | |
for sentence_file, comp_files in sorted(list(sentence_file_mapping.items())): | |
comp_mat0 = mat_files[comp_files[0]] | |
comp_mat1 = mat_files[comp_files[1]] | |
sentence_mat = mat_files[sentence_file] | |
print(f"===== {comp_files[0].split('/')[-1]} =====") | |
print(f"{''.join(comp_files[0].split('/')[-2])} length: {len(comp_mat0['sentenceText'])}") | |
print(f"{''.join(comp_files[1].split('/')[-2])} length: {len(comp_mat1['sentenceText'])}") | |
print(f"{''.join(sentence_file.split('/')[-2])} length: {len(sentence_mat['sentences'])}") | |
## | |
# save to Zarr | |
# flatten to 1D arrays of length num_sentences | |
num_sentences_per_mat = [] | |
flat_block = [] | |
flat_session = [] | |
flat_dataset_partition = [] | |
flat_sentences = [] | |
flat_mspecs = [] | |
flat_tts_mspecs = [] | |
flat_tts_phonemes = [] | |
flat_aligned_mspecs = [] | |
flat_aligned_phonemes = [] | |
flat_spikePow = [] | |
flat_tx1 = [] | |
flat_tx2 = [] | |
flat_tx3 = [] | |
flat_tx4 = [] | |
for mat_file, v in mat_mspecs.items(): | |
nsentences = len(v) | |
num_sentences_per_mat.append(nsentences) | |
flat_session.extend([mat_file] * nsentences) | |
flat_mspecs.extend(mat_mspecs[mat_file]) | |
assert len(mat_sentences[mat_file]) == nsentences | |
flat_sentences.extend(mat_sentences[mat_file]) | |
assert(len(mat_block[mat_file]) == nsentences) | |
flat_block.extend(mat_block[mat_file]) | |
assert len(mat_dataset_partition[mat_file]) == nsentences | |
flat_dataset_partition.extend(mat_dataset_partition[mat_file]) | |
assert len(mat_tts_mspecs[mat_file]) == nsentences | |
flat_tts_mspecs.extend(mat_tts_mspecs[mat_file]) | |
assert len(mat_tts_phonemes[mat_file]) == nsentences | |
flat_tts_phonemes.extend(mat_tts_phonemes[mat_file]) | |
assert len(mat_aligned_mspecs[mat_file]) == nsentences | |
flat_aligned_mspecs.extend(mat_aligned_mspecs[mat_file]) | |
assert len(mat_aligned_phonemes[mat_file]) == nsentences | |
flat_aligned_phonemes.extend(mat_aligned_phonemes[mat_file]) | |
assert len(mat_spikePow[mat_file]) == nsentences | |
flat_spikePow.extend(mat_spikePow[mat_file]) | |
assert len(mat_tx1[mat_file]) == nsentences | |
flat_tx1.extend(mat_tx1[mat_file]) | |
assert len(mat_tx2[mat_file]) == nsentences | |
flat_tx2.extend(mat_tx2[mat_file]) | |
assert len(mat_tx3[mat_file]) == nsentences | |
flat_tx3.extend(mat_tx3[mat_file]) | |
assert len(mat_tx4[mat_file]) == nsentences | |
flat_tx4.extend(mat_tx4[mat_file]) | |
## | |
cur_date = datetime.datetime.now().strftime("%Y-%m-%d") | |
path = os.path.join(os.path.dirname(datadir), "synthetic_audio", | |
f"{cur_date}_T12_dataset_comp_split.npz") | |
# mdict = { | |
# "session": session, "sentences": mat_sentences, | |
# "mspecs": flat_mspecs, "aligned_mspecs": flat_aligned_mspecs, "aligned_phonemes": flat_aligned_phonemes, | |
# "spikePow": flat_spikePow, "tx1": flat_tx1, "tx2": flat_tx2, "tx3": flat_tx3, "tx4": flat_tx4, | |
# } | |
mdict = { | |
"session": flat_session, "dataset_partition": flat_dataset_partition, "sentences": flat_sentences, | |
"block": flat_block, | |
"mspecs": flat_mspecs, "tts_mspecs": flat_tts_mspecs, "tts_phonemes": flat_tts_phonemes, | |
"aligned_tts_mspecs": flat_aligned_mspecs, "aligned_phonemes": flat_aligned_phonemes, | |
"spikePow": flat_spikePow, "tx1": flat_tx1, "tx2": flat_tx2, "tx3": flat_tx3, "tx4": flat_tx4, | |
} | |
mdict_arr = {} | |
for k,v in mdict.items(): | |
try: | |
mdict_arr[k] = np.array(v) | |
except: | |
# support ragged array | |
mdict_arr[k] = np.array(v, dtype=object) | |
np.savez(path, **mdict_arr) | |
# may not work | |
# zarr.save_group(path, **mdict) | |
# Prob don't need to run script below here unless exploring data | |
print(f"Saved T12 dataset to {path}") | |
# exit(0) | |
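##
# Verify that the just-saved npz round-trips; allow_pickle is required because the arrays
# are ragged (dtype=object). This mirrors how a downstream dataloader might read it back.
_npz = np.load(path, allow_pickle=True)
print("saved fields:", _npz.files)
print("number of saved sentences:", len(_npz["sentences"]))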
## | |
exit(0) | |
## | |
# spot check 6/28 since missing audio block 5 | |
# not sure if okay or not | |
# DTW alignment is not great overall... | |
idx = -24 | |
idx = -4 | |
fig, axs = plt.subplots(4, 1, figsize=(10, 9), sharex=True) | |
mat = mat_files['/data/data/T12_data/sentences/t12.2022.06.28_sentences.mat'] | |
ae = mat['audioEnvelope'] | |
start, stop = mat['goTrialEpochs'][idx] | |
axs[0].plot(ae[start:stop]) | |
axs[0].set_title('audio volume') | |
axs[1].imshow(mat['audioFeatures'][start:stop].T, aspect='auto', origin='lower') | |
axs[1].set_title('audioFeatures') | |
axs[2].imshow(mat_mspecs['t12.2022.06.28_sentences.mat'][idx], aspect='auto', origin='lower') | |
axs[2].set_title('T12 mspec') | |
axs[3].imshow(mat_aligned_mspecs['t12.2022.06.28_sentences.mat'][idx], aspect='auto', origin='lower') | |
axs[3].set_title('aligned TTS mspec') | |
## | |
fig, axs = plt.subplots(3, 1, figsize=(10, 6), sharey=True) | |
axs[0].imshow(t12_mspec.cpu().numpy(), aspect='auto', origin='lower') | |
axs[0].set_title('t12 mspec') | |
axs[1].imshow(tts_mspec[alignment].cpu().numpy(), aspect='auto', origin='lower') | |
axs[1].set_title('aligned TTS mspec') | |
axs[2].imshow(tts_mspec.cpu().numpy(), aspect='auto', origin='lower') | |
axs[2].set_title('TTS mspec') | |
plt.tight_layout() | |
plt.show() | |
## | |
plt.plot(alignment) | |
plt.title("DTW alignment") | |
plt.xlabel("T12 index") | |
plt.ylabel("TTS index") | |
## | |
fig, axs = plt.subplots(2, 1, figsize=(10, 6), sharex=True) | |
axs[0].plot(tts_phonemes) | |
axs[0].set_title("TTS phonemes") | |
axs[0].set_ylabel("phoneme") | |
axs[1].imshow(tts_mspec, aspect='auto', origin='lower') | |
axs[1].set_ylabel("MFCC") | |
axs[1].set_xlabel("time (20ms)") | |
plt.show() | |
## | |
fig, axs = plt.subplots(3, 1, figsize=(10, 9), sharex=True) | |
axs[0].imshow(t12_mspec, aspect='auto', origin='lower') | |
axs[0].set_ylabel("MFCC") | |
axs[0].set_title("T12 mspec") | |
axs[1].plot(tts_phonemes[alignment]) | |
axs[1].set_title("T12 (aligned) phonemes") | |
axs[1].set_ylabel("phoneme") | |
axs[2].imshow(tts_mspec[alignment], aspect='auto', origin='lower') | |
axs[2].set_title('aligned TTS mspec') | |
axs[2].set_ylabel("MFCC") | |
axs[2].set_xlabel("time (20ms)") | |
plt.tight_layout() | |
plt.show() | |
## | |
fig, axs = plt.subplots(3, 1, figsize=(10, 9), sharex=True) | |
axs[0].imshow(t12_mspec.cpu().numpy(), aspect='auto', origin='lower') | |
axs[0].set_ylabel("MFCC") | |
axs[0].set_title("T12 mspec") | |
axs[1].plot(mat_aligned_phonemes['/data/data/T12_data/sentences/t12.2022.05.05_sentences.mat'][-1]) | |
axs[1].set_title("T12 (aligned) phonemes") | |
axs[1].set_ylabel("phoneme") | |
axs[2].imshow(mat_aligned_mspecs['/data/data/T12_data/sentences/t12.2022.05.05_sentences.mat'][-1], aspect='auto', origin='lower') | |
axs[2].set_title('aligned TTS mspec') | |
axs[2].set_ylabel("MFCC") | |
axs[2].set_xlabel("time (20ms)") | |
plt.tight_layout() | |
plt.show() | |
## | |
def resample_idx(idx, orig_sr, target_sr): | |
return int(idx * target_sr / orig_sr) | |
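# Example: frame index 50 at 50 Hz (20 ms bins) corresponds to sample 30000 at 30 kHz.
assert resample_idx(50, 1000 / 20, 30000) == 30000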
s = resample_idx(startIdx, 1000/20, 30000) | |
e = resample_idx(stopIdx, 1000/20, 30000) | |
t12_audio = audio_block[block_idx][s:e] | |
ex_mspec = mel_spectrogram(torch.tensor(t12_audio[None], dtype=torch.float32), | |
2048, 80, 30000, 30000//50, 30000//25, 0, 8000, center=False).squeeze() | |
fig, axs = plt.subplots(2, 1, figsize=(10, 6), sharey=True) | |
axs[0].imshow(t12_mspec, aspect='auto', origin='lower') | |
axs[0].set_title('t12_mspec') | |
axs[1].imshow(ex_mspec, aspect='auto', origin='lower') | |
axs[1].set_title('again from audio (should match)') | |
Audio(t12_audio, rate=30000) | |
## | |
Audio(tts_audio, rate=sample_rate) | |
## | |
# for sentence, utterances in tqdm(sentence_mapping.items()): | |
# for mat_file, idx in utterances: | |
# sentence_file = competition_file_mapping[mat_file] | |
# sentence_mat = mat_files[sentence_file] | |
# # map goCue index to block index & relative audio index | |
# t12_audio = sentence_mat['audio'][0,blockIdx][0] | |
# if speaking_modes[mat_file] == "vocalized": | |
# tts_audio, tts_phonemes, sample_rate = load_TTS_data(sentence) | |
# print(f"{t12_audio.shape=}, {tts_audio.shape=}") | |
# break | |
# dtw_path = DTW_between_audio_files(t12_audio, tts_audio) | |
# # This is a placeholder for how you might align phoneme labels. | |
# # You might need a more specific way to accumulate phoneme labels based on your data. | |
# aligned_phonemes = [tts_phonemes[i] for i in dtw_path] | |
# # TODO: accumulate the aligned_phonemes for each timestep of your neural data. | |
print("Processing complete!") | |
## | |
## | |
# use torch.cdist and align_from_distances for DTW | |
## | |
## | |
# look at some audio envelopes for speakingMode 'speak' or 'mouthing'
mouthing_mat = None | |
speak_mat = None | |
for f in sentences_files: | |
# mat = scipy.io.loadmat(f) | |
mat = mat_files[f] | |
assert len(mat['speakingMode']) == 1, f"More than one speakingMode in file {f}" | |
if mat['speakingMode'][0] == 'attempted nonvocal speaking': | |
mouthing_mat = mat | |
elif mat['speakingMode'][0] == 'attempted speaking': | |
speak_mat = mat
else: | |
raise Exception(f"Unknown speakingMode: {mat['speakingMode'][0]} fro file {f}") | |
# if mouthing_mat is not None and speak_mat is not None: | |
# break | |
## | |
# plot audio envelopes for mouthing and speaking | |
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5)) | |
mouthed_audio = librosa.util.buf_to_float(mouthing_mat['audio'][0,0][0]) | |
spoken_audio = librosa.util.buf_to_float(speak_mat['audio'][0,0][0]) | |
ax1.plot(librosa.resample(mouthed_audio, orig_sr=30000, target_sr=1000)) | |
ax2.plot(librosa.resample(spoken_audio, orig_sr=30000, target_sr=1000)) | |
ax1.set_title('mouthing') | |
ax2.set_title('speaking') | |
plt.show() | |
## | |
# mouthed Audio | |
# write_wav(os.path.join(T12_dir, "mouthing.wav"), 30000, mouthed_audio) | |
Audio(mouthed_audio, rate=30000) | |
## | |
# vocalized Audio | |
# write_wav(os.path.join(T12_dir, "mouthing.wav"), 30000, mouthed_audio) | |
Audio(spoken_audio, rate=30000) | |
## | |
# ============== Explore the data format ============== | |
# blockList starts at 0 or 5 or ...? | |
print(f"{mouthing_mat['blockList']=}\n{mouthing_mat['blockTypes']=}") | |
num_blocks = len(mouthing_mat['blockTypes']) | |
T = mouthing_mat['spikePow'].shape[0] | |
neural_seconds = T * 20 / 1000 | |
print(f"{T=}\n{neural_seconds=}") | |
# audio: B x 1 vector of raw audio snippets (B=number of blocks) | |
# Audio data was recorded at 30 kHz and is aligned to the neural data | |
# (it begins at the first time step of neural data for that block). | |
nAudio = np.sum([m[0].shape[0] for m in mouthing_mat['audio'][0]]) | |
audio_seconds = nAudio / 30000 | |
# audioFeatures: T x 42 matrix of MFCC features (T=number of 20 ms time steps). | |
# Can be used as a control to attempt to decode speech from audio features. | |
print(f"{nAudio=}\n{audio_seconds=}") | |
audio_seconds = 20 * nAudio / T / 1000 | |
audio_seconds | |
## | |
mouthing_mat['xpcClock'] | |
## | |
mouthing_mat['nsp1Clock'] | |
## | |
mouthing_mat['nsp2Clock'] | |
## | |
mouthing_mat['redisClock'] | |
## | |
block_start_idx = np.concatenate([[0], 1 + np.where(np.diff(mouthing_mat['blockNum'][:,0]))[0]]) | |
block_start_idx | |
## | |
mouthing_mat['delayTrialEpochs'] # eg [1, 152] | |
mouthing_mat['goTrialEpochs'] # eg [153, 345] | |
assert np.all(mouthing_mat['delayTrialEpochs'][:,1] + 1 == mouthing_mat['goTrialEpochs'][:,0]) | |
assert len(mouthing_mat['sentences']) == len(mouthing_mat['goTrialEpochs']) | |
## | |
trial_start = mouthing_mat['delayTrialEpochs'][:,0] | |
trial_end = mouthing_mat['goTrialEpochs'][:,1] | |
trial_duration = trial_end - trial_start | |
tot_time = trial_duration.sum() | |
tot_time / 50 # seconds | |
## | |
trial_duration  # per-sentence durations in 20 ms bins
## | |
len(mouthing_mat['sentences']), mouthing_mat['audioFeatures'].shape, mouthing_mat['audioEnvelope'].shape, mouthing_mat['audio'].shape | |
## | |
mouthing_mat['spikePow'].shape | |
## | |
## | |
# For each sentence, read the audio envelope and run DTW. | |
## | |
# read all sentences into a list | |
@persist_to_file(os.path.join(os.path.dirname(datadir), "all_sentences.pkl")) | |
def read_all_sentences(sentences_files): | |
sentences = [] | |
for sentences_file in tqdm(sentences_files): | |
mat = scipy.io.loadmat(sentences_file) | |
for sentence in mat['sentences']: | |
sentences.append(str(sentence[0][0]).rstrip()) # strip whitespace at end | |
return sentences | |
all_sentences = read_all_sentences(sentences_files) | |
print('Number of sentences:', len(all_sentences)) | |
print('Unique sentences:', len(np.unique(all_sentences))) | |
## | |
def sentence_to_fn(sentence, directory=TTS_dir, ext=".wav"): | |
fn = re.sub(r'[^\w\s]', '', sentence) # remove punctuation | |
fn = fn.lower().replace(' ', '_') # lowercase with underscores | |
return os.path.join(directory, fn+ext) | |
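# Hedged example of the filename mapping (the exact path depends on TTS_dir):
# "Hello, world!" -> <TTS_dir>/hello_world.wav
print(sentence_to_fn("Hello, world!"))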
# number of unique filenames should be equal to number of unique sentences. Print sentence pair if not. | |
uniq_sentences = np.unique(all_sentences) | |
filenames = defaultdict(list) | |
# halfway = len(uniq_sentences) // 2 | |
halfway = 6291 | |
for sentence in tqdm(uniq_sentences[1245:halfway]): | |
fn = sentence_to_fn(sentence) | |
filenames[fn].append(sentence) | |
toggle_phones = True | |
# loading data into CPU; should work but takes ~6-10 minutes (TextTransform and SpeechDataset are assumed to come from this repo's local modules, which are not imported in this script):
datasets = list() | |
transform = TextTransform(toggle_phones=toggle_phones) | |
for files in [train_files, test_files]: | |
mat = [scipy.io.loadmat(x) for x in files] | |
X = [trl['tx1'] for trl in mat] | |
if toggle_phones: | |
y = [trl['phoneme_ints'][trl['phoneme_ints'] != 0] for trl in mat] | |
else: | |
y = [transform.text_to_int(trl['text'][0]) for trl in mat] | |
z = [session_mapping[f.split('sessions/')[1].split('/')[0]] for f in files] | |
datasets.append(SpeechDataset(X, y, z, toggle_phones=toggle_phones)) | |
del mat, X, y | |
trainset = datasets[0] | |
testset = datasets[1] | |
print(len(trainset)) | |
print(len(testset)) |