Skip to content

Instantly share code, notes, and snippets.

@peace098beat
Last active November 24, 2015 08:30
Show Gist options
  • Save peace098beat/e4206f9034e3d439b168 to your computer and use it in GitHub Desktop.
Save peace098beat/e4206f9034e3d439b168 to your computer and use it in GitHub Desktop.
[機械学習] オーディオデータの読み込み
# -*- coding: utf-8 -*-
"""GaborWavelet.py
ガボール変換を使ったウェーブレット変換
Version:
1.0 (copy)
Reference:
http://hp.vector.co.jp/authors/VA046927/gabor_wavelet/gabor_wavelet.html
http://criticaldays2.blogspot.jp/2014/04/blog-post_23.html
http://www.softist.com/programming/gabor-wavelet/gabor-wavelet.htm
"""
__version__ = '1.0'
import numpy as np
def gwt(audio_data, Fs):
def psi(a, b, a_t, sigma):
""" ガボール関数"""
t = (a_t - b) / a
g = 1. / (2 * np.sqrt(np.pi * sigma)) * np.exp(-1. * t ** 2 / (4. * sigma ** 2))
e = np.exp(1j * 2. * np.pi * t)
return g * e
def utili_sample(a, sigma, Vc):
samp = a * sigma * np.sqrt(-2. * np.log(Vc))
return samp
import time
start = time.time()
Fs = 1. * Fs
adata = audio_data
N = len(adata)
X = np.array(adata) * 1.
# -------------------
# ウェーブレット変換処理
# -------------------
"""
解析パラメータ
"""
# 1. ガボールウェーブレットパラメータ
sigma = 5
# 2. 周波数分割数
a_N = 512
# 解析周波数(最低周波数)
f_min = 0
# 解析周波数(最高周波数)
f_max = 24000
# 有効計算幅(小さいほど精度高い)
Vc = 0.00001
# ループ準備
# ---------
# 時間幅
# t = np.arange(0, N) / float(Fs)
t = np.linspace(0, N/Fs, N)
# 解析周波数
_fn = np.linspace(f_min, f_max, a_N+1)
fn = _fn[1:]
# 解析結果格納バッファ
Anadata = np.empty(shape=(N, a_N), dtype=complex)
print "-----------------------------"
print '== Gabor Wavelet Analysing =='
print '== ...'
for a_n in range(0, a_N):
a = 1. / fn[a_n]
b = 1. * N / 2. / Fs
Psi = psi(a, b, t, sigma)
# 実用領域のみ畳み込み
us = np.floor(utili_sample(a, sigma, Vc) * Fs)
if us < N:
# ss = np.floor((N - us) / 2)
# Psi = Psi[ss:ss + us]
ss = np.floor(N/2-us/2)
se = np.floor(N/2+us/2)
Psi = Psi[ss:se]
# 畳み込み (convolve)
# Anadata[:, a_n] = (1. / np.sqrt(a)) * np.convolve(Psi, X, 'same')
Anadata[:, a_n] = (1. / np.sqrt(a)) * np.convolve(X, Psi, 'same')
# 描画用に配列の左右を入れ替え
# Anadata = np.fliplr(Anadata)
# 解析時間
print("== Finish Analys :{0}".format(time.time() - start))
print "-----------------------------"
return Anadata, t, fn
#! coding:utf-8
"""
main2.py
(2015/11/24)
Created by 0160929 on 2015/11/20 9:05
"""
__version__ = '0.2'
import sys, os, glob
import time
import scipy.io.wavfile
import numpy as np
import matplotlib.pyplot as plt
# logを保存
from datetime import datetime
import csv
def debuglog(s=None, data=None):
d = datetime.now().isoformat()
with open('log.csv', 'a') as f:
writer = csv.writer(f, lineterminator='\n')
if not s is None:
writer.writerow([d, s])
print d, s
if not data is None:
writer.writerow([d, data])
print d, data
def create_fft(fn, override=True):
"""
wavファイルに対しFFTを行い.fft.npyとして保存
:param fn: wavファイルパス
:param override: 上書きオブション{True:上書き, False:処理無し}
"""
# -- ファイルパスからファイル名を取得 ---------------
base_fn, ext = os.path.splitext(fn)
# -- 保存するファイル名 ------------- ---------------
data_fn = base_fn + ".fft"
# -- 上書き確認
if os.path.exists(data_fn + ".npy"):
if override:
# ファイルがある場合は削除
os.remove(data_fn + ".npy")
else:
# ファイルがある場合は処理しない
return
# -- FFT処理 ----------------------------------------
fs, X = scipy.io.wavfile.read(fn)
fft_features = np.abs(scipy.fft(X[:, 0]))
# -- ファイル保存 -----------------------------------
np.save(data_fn, fft_features)
print '>> save fft ', os.path.basename(fn)
def read_fft(base_dir):
"""
FFT(.fft.npy)データの読み込み
:param base_dir:
"""
# ディレクトリ以下のFFTファイルを読み込み
wild_dir = os.path.join(base_dir, "*.fft.npy")
file_list = glob.glob(wild_dir)
# FFTファイルの読み込み
X = []
Y = []
for fn in file_list:
fft_features = np.load(fn)
X.append(fft_features)
Y.append(os.path.basename(fn))
return np.array(X), np.array(Y)
def main_fft():
# -- 解析対象となる.wavファイルの取得 -----------------------
# カレントパスの取得
BASE_DIR = os.path.dirname(sys.argv[0])
# ディレクトリを全て取得
# SOURCE_DIR = [file for file in glob.glob('./*') if os.path.isdir(file)]
SOURCE_DIR = "recdata"
# wavファイルのワイルドカード
WILD_EXT = "*.wav"
# -- wavファイルリストを取得 --------------------------------
# パスの生成
_search_path = os.path.join(BASE_DIR, SOURCE_DIR, WILD_EXT)
# パスを整形
search_path = os.path.normpath(_search_path)
# .wavファイルの一括取得
file_list = glob.glob(search_path)
# -- FFTデータ(.fft.npy)の作成・保存 ------------------------
for fn in file_list:
create_fft(fn, override=False)
# -- FFTデータ(.fft.npy)の読み込み --------------------------
X, Y = read_fft(os.path.join(BASE_DIR, SOURCE_DIR))
# -- 以下解析およびグラフ出力 -------------------------------
# X: 71x24001x2ch
print 'X:', X.shape
print Y, Y.shape
# 平均の算出
features_label = Y
features = X[:, :12000]
features_mean = np.mean(features, axis=0)
features_std_n = features_mean - np.std(features, axis=0) * 1
features_std_p = features_mean + np.std(features, axis=0) * 1
features_max = np.max(features, axis=0)
features_min = np.min(features, axis=0)
# グラフプロット
fig = plt.figure()
ax_log = fig.add_subplot(211)
ax_lin = fig.add_subplot(212)
# ax_log.set_title(Y[70])
# -- スペクトル(ログ)の表示
for i, v in enumerate(features_label):
ax_log.plot(20 * np.log10(features[i, :]), c='gray', linewidth=0.2)
ax_log.plot(20 * np.log10(features_mean), c='b', linewidth=2.5)
ax_log.plot(20 * np.log10(features_std_p), c='r', linewidth=0.5)
ax_log.plot(20 * np.log10(features_std_n), c='r', linewidth=0.5)
# 最小値プロット
# ax_log.plot( 20*np.log10(features_min), c='k', linewidth=0.5)
# ax_log.plot( 20*np.log10(features_max), c='k', linewidth=0.5)
ax_log.set_ylim([60, 120])
ax_log.grid(True)
ax_log.set_title(SOURCE_DIR)
# -- スペクトル(リニア)の表示
for i, v in enumerate(features_label):
ax_lin.plot(features[i, :], c='gray', linewidth=1.0)
ax_lin.plot(features_mean, c='b', linewidth=2.5)
# ax_lin.set_ylim([0,100000])
ax_lin.grid(True)
# グラフの保存
fig.savefig(SOURCE_DIR + '_fft')
plt.show()
from FiSig.gwt import gwt
def create_gwt(fn, override=False):
"""
wavファイルに対しGWTを行い.fft.npyとして保存
:param fn: wavファイルパス
:param override: 上書きオブション{True:上書き, False:処理無し}
"""
print '>> create_gwt() ...'
# -- ファイルパスからファイル名を取得 ---------------
base_fn, ext = os.path.splitext(fn)
# -- 保存するファイル名 ------------- ---------------
data_fn = base_fn + ".gwt"
# -- 上書き確認
if os.path.exists(data_fn + ".npy"):
if override:
os.remove(data_fn + ".npy") # ファイルがある場合は削除
else:
return # ファイルがある場合は処理しない
# -- FFT処理 ----------------------------------------
fs, X = scipy.io.wavfile.read(fn)
audio_data = X[:, 0] + X[:, 1]
gwt_features = np.abs(gwt(audio_data, fs))
# -- ファイル保存 -----------------------------------
np.save(data_fn, gwt_features)
print '>> save gwt ', os.path.basename(fn)
def read_gwt(base_dir):
"""
GWT(.gwt.npy)データの読み込み
:param base_dir: .gwt.npyが保管されているフォルダ
"""
print '>> read_gwt() ...'
# ディレクトリ以下のFFTファイルを読み込み
wild_dir = os.path.join(base_dir, "*.gwt.npy")
file_list = glob.glob(wild_dir)
# FFTファイルの読み込み
X = []
Y = []
s = '*'
for fn in file_list:
print s,
gwt_features = np.load(fn)
X.append(gwt_features)
Y.append(os.path.basename(fn))
return np.array(X), np.array(Y)
def main_gwt():
"""
GWTを行いデータを集計解析する。
:return:
"""
# -- 解析対象となる.wavファイルの取得 -----------------------
# カレントパスの取得
BASE_DIR = os.path.dirname(sys.argv[0])
# ディレクトリを全て取得
# SOURCE_DIR = [file for file in glob.glob('./*') if os.path.isdir(file)]
SOURCE_DIR = "recdata"
# wavファイルのワイルドカード
WILD_EXT = "*.wav"
# -- wavファイルリストを取得 --------------------------------
_search_path = os.path.join(BASE_DIR, SOURCE_DIR, WILD_EXT) # パスの生成
search_path = os.path.normpath(_search_path) # パスを整形
# .wavファイルの一括取得
file_list = glob.glob(search_path)
# -- FFTデータ(.fft.npy)の作成・保存 ------------------------
# TODO: gwtの時間・周波数軸データの保存を追加
start = time.time()
for fn in file_list:
create_gwt(fn, override=False)
print(">>> ... create_gwt Fin! :{0}".format(time.time() - start))
# -- FFTデータ(.fft.npy)の読み込み --------------------------
start = time.time()
X, Y = read_gwt(os.path.join(BASE_DIR, SOURCE_DIR))
print(">>> ... read_gwt Fin! :{0}".format(time.time() - start))
# -- 以下解析およびグラフ出力 -------------------------------
# -- 定数 ------------------------
fs = 48000
start_ms = 80
end_ms = 190
ms2smp = 1 / 1000. * fs
start_smp = start_ms * ms2smp
end_smp = end_ms * ms2smp
deltaT = 3
delta_smp = deltaT * ms2smp
ss = 100 * ms2smp
# -- 時間平均化 ------------------
gwt_mean_series = []
for k in range(X.shape[0]):
gwt, t, f = X[k, :] # データの呼び出し
gwt_mean = np.mean(gwt[ss:ss + delta_smp, :], axis=0)
gwt_mean_series.append(gwt_mean)
gwt_mean_series = np.asarray(gwt_mean_series)
print gwt_mean_series.shape
gwt_mean_mean = np.mean(gwt_mean_series, axis=0)
fig1 = plt.figure()
ax1 = fig1.add_subplot(2, 1, 1)
ax2 = fig1.add_subplot(2, 1, 2)
gwt_series = []
for k in range(X.shape[0]):
plt.axes(ax2)
plt.plot(f, gwt_mean_series[k, :], c='gray')
plt.grid(True)
plt.axes(ax1)
plt.plot(f, 20 * np.log10(gwt_mean_series[k, :]), c='gray')
plt.grid(True)
plt.axes(ax2)
plt.plot(f, gwt_mean_mean, c='blue', linewidth=3)
plt.axes(ax1)
plt.plot(f, 20 * np.log10(gwt_mean_mean), c='blue', linewidth=3)
from scipy import signal
PEAK_FIND_ORDER = 5
Nx = gwt_mean_mean.shape[0]
Search_nx = np.round(Nx * 3. / 5.)
# 極大値検索
maxId, = signal.argrelmax(gwt_mean_mean[:Search_nx], order=PEAK_FIND_ORDER)
print maxId
for id in maxId:
s = '%0.1f' % (f[id]/1000.)
plt.axes(ax2)
plt.annotate(s, xy=(f[id], gwt_mean_mean[id]), fontsize=9,
horizontalalignment='center', verticalalignment='bottom')
# -- GWTのグラフ表示 OK
# fig = plt.figure()
# ax = fig.add_subplot(1, 1, 1)
#
# gdata = 20 * np.log10(gwt[start_smp:end_smp, :].T)
# extent = t[start_smp], t[end_smp], f[0], f[-1]
# plt.imshow(gdata, cmap='jet', extent=extent, origin='lower', interpolation='nearest')
# plt.axis('tight')
# plt.xlabel('time[s]')
# plt.ylabel('frequency[Hz]')
plt.axes(ax1)
plt.title(SOURCE_DIR)
plt.savefig(SOURCE_DIR)
plt.show()
if __name__ == '__main__':
main_gwt()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment