Last active
November 24, 2015 08:30
-
-
Save peace098beat/e4206f9034e3d439b168 to your computer and use it in GitHub Desktop.
[機械学習] オーディオデータの読み込み
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # -*- coding: utf-8 -*- | |
| """GaborWavelet.py | |
| ガボール変換を使ったウェーブレット変換 | |
| Version: | |
| 1.0 (copy) | |
| Reference: | |
| http://hp.vector.co.jp/authors/VA046927/gabor_wavelet/gabor_wavelet.html | |
| http://criticaldays2.blogspot.jp/2014/04/blog-post_23.html | |
| http://www.softist.com/programming/gabor-wavelet/gabor-wavelet.htm | |
| """ | |
| __version__ = '1.0' | |
| import numpy as np | |
def gwt(audio_data, Fs, sigma=5, a_N=512, f_min=0, f_max=24000, Vc=0.00001):
    """Gabor wavelet transform of a 1-D signal.

    The signal is convolved with a complex Gabor wavelet at ``a_N`` analysis
    frequencies evenly spaced on ``(f_min, f_max]``.  The grid point
    ``f_min`` itself is skipped so that the scale ``a = 1 / f`` stays finite
    when ``f_min == 0``.

    :param audio_data: 1-D sequence of samples.
    :param Fs: sampling rate [Hz].
    :param sigma: Gabor envelope width parameter (default 5).
    :param a_N: number of analysis frequencies (default 512).
    :param f_min: lower end of the frequency grid [Hz]; not itself analysed.
    :param f_max: highest analysis frequency [Hz] (default 24000; presumably
        the Nyquist frequency of 48 kHz recordings).
    :param Vc: cut-off parameter for the part of the wavelet kept in the
        convolution; smaller values keep a longer segment.
    :return: ``(Anadata, t, fn)``.  ``Anadata`` is a complex array of shape
        ``(N, a_N)`` (column ``k`` is the response at ``fn[k]``), ``t`` holds
        the sample times ``n / Fs`` and ``fn`` the analysis frequencies.
    :raises ValueError: if ``audio_data`` is empty.
    """
    import time

    def psi(a, b, a_t, sigma):
        """Complex Gabor wavelet at scale ``a`` centred on ``b``, sampled at ``a_t``.

        A Gaussian envelope times a complex exponential whose period is one
        unit of the scaled time ``(a_t - b) / a``.
        """
        t = (a_t - b) / a
        # NOTE(review): the normalisation uses sqrt(pi * sigma) while the
        # exponent uses sigma ** 2 -- confirm against the reference.
        g = 1. / (2 * np.sqrt(np.pi * sigma)) * np.exp(-1. * t ** 2 / (4. * sigma ** 2))
        e = np.exp(1j * 2. * np.pi * t)
        return g * e

    def utili_sample(a, sigma, Vc):
        """Length [s] of the wavelet segment kept for the convolution at scale ``a``."""
        return a * sigma * np.sqrt(-2. * np.log(Vc))

    start = time.time()
    Fs = float(Fs)
    X = np.array(audio_data) * 1.
    N = len(X)
    if N == 0:
        raise ValueError("gwt: audio_data must not be empty")

    # Sample times n / Fs.  np.linspace(0, N / Fs, N) would stretch the step
    # to N / ((N - 1) * Fs) and put the last sample one period too late.
    t = np.arange(N) / Fs
    # Analysis frequencies; the first grid point (f_min) is dropped.
    fn = np.linspace(f_min, f_max, a_N + 1)[1:]

    Anadata = np.empty(shape=(N, a_N), dtype=complex)
    # Every wavelet is centred on the middle of the signal (loop invariant).
    b = N / 2. / Fs
    centre = N // 2

    print("-----------------------------")
    print('== Gabor Wavelet Analysing ==')
    print('== ...')
    for a_n, f in enumerate(fn):
        a = 1. / f
        Psi = psi(a, b, t, sigma)
        # Convolve only with the central, useful part of the wavelet.
        us = np.floor(utili_sample(a, sigma, Vc) * Fs)
        if us < N:
            # Slice bounds must be ints: float indices raise TypeError in
            # NumPy >= 1.12.
            ss = int(np.floor(centre - us / 2.))
            se = int(np.floor(centre + us / 2.))
            # Keep at least one sample; np.convolve rejects empty input.
            se = max(se, ss + 1)
            Psi = Psi[ss:se]
        Anadata[:, a_n] = (1. / np.sqrt(a)) * np.convolve(X, Psi, 'same')

    print("== Finish Analys :{0}".format(time.time() - start))
    print("-----------------------------")
    return Anadata, t, fn
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #! coding:utf-8 | |
| """ | |
| main2.py | |
| (2015/11/24) | |
| Created by 0160929 on 2015/11/20 9:05 | |
| """ | |
| __version__ = '0.2' | |
| import sys, os, glob | |
| import time | |
| import scipy.io.wavfile | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| # logを保存 | |
| from datetime import datetime | |
| import csv | |
def debuglog(s=None, data=None):
    """Append a timestamped message and/or data row to ``log.csv`` and echo it.

    Each argument that is not ``None`` is written as its own CSV row
    ``[timestamp, value]`` (``s`` first, then ``data``).  The timestamp comes
    from ``datetime.now().isoformat()``.  Each row is also printed to stdout as
    ``"<timestamp> <value>"``.

    :param s: message to log; skipped when ``None``.
    :param data: data to log; skipped when ``None``.
    """
    d = datetime.now().isoformat()
    with open('log.csv', 'a') as f:
        writer = csv.writer(f, lineterminator='\n')
        for value in (s, data):
            if value is not None:
                writer.writerow([d, value])
                # Single-argument print behaves the same on Python 2 and 3.
                print('%s %s' % (d, value))
def create_fft(fn, override=True):
    """Compute the FFT magnitude of a wav file and save it as ``<name>.fft.npy``.

    Only the first channel of a multi-channel file is analysed; mono files
    are used as they are.

    :param fn: path of the wav file.
    :param override: if True, an existing ``.fft.npy`` is replaced; if False,
        the existing file is kept and nothing is computed.
    """
    base_fn = os.path.splitext(fn)[0]
    data_fn = base_fn + ".fft"
    npy_fn = data_fn + ".npy"  # np.save appends ".npy" to data_fn
    # -- Overwrite check
    if os.path.exists(npy_fn):
        if not override:
            return
        os.remove(npy_fn)

    # -- FFT
    fs, X = scipy.io.wavfile.read(fn)
    # wavfile.read returns a 1-D array for mono files, (samples, channels) otherwise.
    channel = X if X.ndim == 1 else X[:, 0]
    # The old callable ``scipy.fft(x)`` was an alias of numpy.fft.fft and is
    # gone in SciPy >= 1.5, where ``scipy.fft`` is a module.
    fft_features = np.abs(np.fft.fft(channel))

    # -- Save
    np.save(data_fn, fft_features)
    print('>> save fft  %s' % os.path.basename(fn))
def read_fft(base_dir):
    """Load every ``*.fft.npy`` file in ``base_dir``.

    Files are read in sorted name order, so row ``i`` of the features always
    corresponds to ``names[i]``.  ``glob`` itself returns files in arbitrary
    order.

    :param base_dir: directory containing the ``.fft.npy`` files.
    :return: ``(X, Y)``.  ``X`` is an ``np.array`` of the loaded arrays and
        ``Y`` an ``np.array`` of their base file names.
    """
    file_list = sorted(glob.glob(os.path.join(base_dir, "*.fft.npy")))
    features = [np.load(path) for path in file_list]
    names = [os.path.basename(path) for path in file_list]
    return np.array(features), np.array(names)
def main_fft():
    """Compute/load FFT magnitudes of ``recdata/*.wav`` and plot their spread.

    The FFT of each wav file is cached next to it as ``.fft.npy``; existing
    caches are kept.  The first 12000 bins of every spectrum are drawn:

    * Top panel, in dB: each file in grey, the mean in blue and
      mean +/- one standard deviation in red.
    * Bottom panel, linear: each file in grey and the mean in blue.

    The figure is saved as ``recdata_fft`` and shown.
    """
    # -- Locate the wav files to analyse ----------------------------------
    BASE_DIR = os.path.dirname(sys.argv[0])  # directory of the running script
    SOURCE_DIR = "recdata"
    WILD_EXT = "*.wav"
    search_path = os.path.normpath(os.path.join(BASE_DIR, SOURCE_DIR, WILD_EXT))
    file_list = glob.glob(search_path)

    # -- Create (if missing) and load the .fft.npy caches ------------------
    for fn in file_list:
        create_fft(fn, override=False)
    X, Y = read_fft(os.path.join(BASE_DIR, SOURCE_DIR))

    # X: (n_files, n_fft_bins) -- create_fft keeps only the first channel.
    print('X: %s' % (X.shape,))
    print('%s %s' % (Y, Y.shape))
    if len(Y) == 0:
        print('>> no .fft.npy data found in %s' % SOURCE_DIR)
        return

    # -- Statistics over files ----------------------------------------------
    features = X[:, :12000]
    features_mean = np.mean(features, axis=0)
    features_std = np.std(features, axis=0)
    features_std_n = features_mean - features_std
    features_std_p = features_mean + features_std

    fig = plt.figure()
    ax_log = fig.add_subplot(211)
    ax_lin = fig.add_subplot(212)

    # -- Spectra in dB -------------------------------------------------------
    for spectrum in features:
        ax_log.plot(20 * np.log10(spectrum), c='gray', linewidth=0.2)
    ax_log.plot(20 * np.log10(features_mean), c='b', linewidth=2.5)
    ax_log.plot(20 * np.log10(features_std_p), c='r', linewidth=0.5)
    ax_log.plot(20 * np.log10(features_std_n), c='r', linewidth=0.5)
    ax_log.set_ylim([60, 120])
    ax_log.grid(True)
    ax_log.set_title(SOURCE_DIR)

    # -- Linear spectra ------------------------------------------------------
    for spectrum in features:
        ax_lin.plot(spectrum, c='gray', linewidth=1.0)
    ax_lin.plot(features_mean, c='b', linewidth=2.5)
    ax_lin.grid(True)

    fig.savefig(SOURCE_DIR + '_fft')
    plt.show()
| from FiSig.gwt import gwt | |
def create_gwt(fn, override=False):
    """Run the Gabor wavelet transform on a wav file and save it as ``<name>.gwt.npy``.

    The first two channels are summed (in floating point) into one signal;
    mono files are used as they are.  ``gwt`` returns the tuple
    ``(Anadata, t, fn)``.  It is saved as a 1-D object array holding the
    element-wise absolute value of each part, and ``main_gwt`` unpacks it in
    that order.

    :param fn: path of the wav file.
    :param override: if True, an existing ``.gwt.npy`` is replaced; if False
        (default), the existing file is kept and nothing is computed.
    """
    print('>> create_gwt() ...')
    base_fn = os.path.splitext(fn)[0]
    data_fn = base_fn + ".gwt"
    npy_fn = data_fn + ".npy"  # np.save appends ".npy" to data_fn
    # -- Overwrite check
    if os.path.exists(npy_fn):
        if not override:
            return
        os.remove(npy_fn)

    # -- GWT
    fs, X = scipy.io.wavfile.read(fn)
    # Sum in float64: adding integer channels directly (e.g. int16) wraps
    # around on loud samples.
    samples = X.astype(np.float64)
    audio_data = samples if samples.ndim == 1 else samples[:, :2].sum(axis=1)
    # The parts of the result have different shapes, so store them in an
    # explicit object array; np.abs() on the ragged tuple fails on NumPy >= 1.24.
    result = gwt(audio_data, fs)
    gwt_features = np.empty(len(result), dtype=object)
    for i, part in enumerate(result):
        gwt_features[i] = np.abs(part)

    # -- Save
    np.save(data_fn, gwt_features)
    print('>> save gwt  %s' % os.path.basename(fn))
def read_gwt(base_dir):
    """Load every ``*.gwt.npy`` file in ``base_dir``.

    Files are read in sorted name order, and a ``*`` is printed per file as a
    progress indicator.  The files hold object arrays, so they are loaded
    with ``allow_pickle=True``.  Only point this at trusted data, because
    unpickling can execute arbitrary code.

    :param base_dir: directory holding the ``.gwt.npy`` files.
    :return: ``(X, Y)``.  ``X`` is an ``np.array`` of the loaded arrays and
        ``Y`` an ``np.array`` of their base file names.
    """
    print('>> read_gwt() ...')
    file_list = sorted(glob.glob(os.path.join(base_dir, "*.gwt.npy")))
    X = []
    Y = []
    for fn in file_list:
        sys.stdout.write('* ')
        sys.stdout.flush()
        # Object arrays need allow_pickle=True since NumPy 1.16.3.
        X.append(np.load(fn, allow_pickle=True))
        Y.append(os.path.basename(fn))
    if file_list:
        sys.stdout.write('\n')
    return np.array(X), np.array(Y)
def main_gwt():
    """Compute/load GWT data for ``recdata/*.wav`` and plot averaged spectra.

    For every file, the GWT magnitude is averaged over a ``deltaT`` ms window
    starting at 100 ms.  The averaged spectra are drawn in grey and their mean
    over all files in blue, in dB (top panel) and linear (bottom panel).
    Local maxima of the mean in the lower 3/5 of the frequency bins are
    annotated with their frequency in kHz.  The figure is saved as
    ``recdata`` and shown.
    """
    from scipy import signal

    # -- Locate the wav files to analyse ------------------------------------
    BASE_DIR = os.path.dirname(sys.argv[0])  # directory of the running script
    SOURCE_DIR = "recdata"
    WILD_EXT = "*.wav"
    search_path = os.path.normpath(os.path.join(BASE_DIR, SOURCE_DIR, WILD_EXT))
    file_list = glob.glob(search_path)

    # -- Create missing .gwt.npy caches ---------------------------------------
    # TODO: save the GWT time/frequency axis data as well.
    # NOTE(review): create_gwt already stores (|Anadata|, |t|, |fn|) in the
    # cache -- confirm whether this TODO is still open.
    start = time.time()
    for fn in file_list:
        create_gwt(fn, override=False)
    print(">>> ... create_gwt Fin! :{0}".format(time.time() - start))

    # -- Load the caches ---------------------------------------------------------
    start = time.time()
    X, Y = read_gwt(os.path.join(BASE_DIR, SOURCE_DIR))
    print(">>> ... read_gwt Fin! :{0}".format(time.time() - start))
    if len(X) == 0:
        print('>> no .gwt.npy data found in %s' % SOURCE_DIR)
        return

    # -- Constants ---------------------------------------------------------------
    fs = 48000  # sampling rate assumed for every recording [Hz]
    ms2smp = fs / 1000.  # samples per millisecond
    deltaT = 3  # averaging window length [ms]
    # Slice bounds must be ints: float indices raise TypeError in NumPy >= 1.12.
    ss = int(100 * ms2smp)  # window start: 100 ms
    delta_smp = int(deltaT * ms2smp)

    # -- Time-average each file's GWT magnitude ---------------------------------
    gwt_mean_series = []
    for record in X:
        spectrogram, t, f = record  # (|Anadata|, |t|, |fn|) as saved by create_gwt
        gwt_mean_series.append(np.mean(spectrogram[ss:ss + delta_smp, :], axis=0))
    gwt_mean_series = np.asarray(gwt_mean_series)
    print(gwt_mean_series.shape)
    # The frequency axis f of the last file is used for all plots.
    gwt_mean_mean = np.mean(gwt_mean_series, axis=0)

    # -- Plot -------------------------------------------------------------------
    fig1 = plt.figure()
    ax1 = fig1.add_subplot(2, 1, 1)  # dB
    ax2 = fig1.add_subplot(2, 1, 2)  # linear
    for spectrum in gwt_mean_series:
        ax2.plot(f, spectrum, c='gray')
        ax1.plot(f, 20 * np.log10(spectrum), c='gray')
    ax1.grid(True)
    ax2.grid(True)
    ax2.plot(f, gwt_mean_mean, c='blue', linewidth=3)
    ax1.plot(f, 20 * np.log10(gwt_mean_mean), c='blue', linewidth=3)

    # -- Annotate local maxima of the mean spectrum --------------------------------
    PEAK_FIND_ORDER = 5
    search_nx = int(np.round(len(gwt_mean_mean) * 3. / 5.))
    max_ids, = signal.argrelmax(gwt_mean_mean[:search_nx], order=PEAK_FIND_ORDER)
    print(max_ids)
    for idx in max_ids:
        label = '%0.1f' % (f[idx] / 1000.)
        ax2.annotate(label, xy=(f[idx], gwt_mean_mean[idx]), fontsize=9,
                     horizontalalignment='center', verticalalignment='bottom')

    ax1.set_title(SOURCE_DIR)
    fig1.savefig(SOURCE_DIR)
    plt.show()
if __name__ == '__main__':
    # Script entry point: run the GWT analysis pipeline.
    main_gwt()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment