Created
March 11, 2017 07:01
general steps for EMG data loading & parsing
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Data collection scheme\n",
    "\n",
    "We planned to collect 20 actions of controlling the Gyro-roller. In the file names they are numbered 1-20 (`Rep_1` to `Rep_20`); in the parsed data they are labelled as `rep` with values 0-19 (file number minus one). The 20 actions, by file number, are:\n",
    "\n",
    "- rep 1: **Baseline recording 1**\n",
    "- rep 2: **Baseline recording 2**\n",
    "- rep 3: **CENTER YAW OFF**\n",
    "- rep 4: **CENTER YAW ON**\n",
    "- rep 5: **CENTER ROLL OFF**\n",
    "- rep 6: **CENTER ROLL ON**\n",
    "- rep 7: **CENTER BOTH OFF**\n",
    "- rep 8: **CENTER BOTH ON**\n",
    "- rep 9: **LEFT YAW OFF**\n",
    "- rep 10: **LEFT YAW ON**\n",
    "- rep 11: **LEFT ROLL OFF**\n",
    "- rep 12: **LEFT ROLL ON**\n",
    "- rep 13: **LEFT BOTH OFF**\n",
    "- rep 14: **LEFT BOTH ON**\n",
    "- rep 15: **RIGHT YAW OFF**\n",
    "- rep 16: **RIGHT YAW ON**\n",
    "- rep 17: **RIGHT ROLL OFF**\n",
    "- rep 18: **RIGHT ROLL ON**\n",
    "- rep 19: **RIGHT BOTH OFF**\n",
    "- rep 20: **RIGHT BOTH ON**\n",
    "\n",
    "Only 4 EMG sensor channels were available, so recording five muscles on both sides could not be done in one pass. We split the collection into 3 rounds:\n",
    "\n",
    "1. **EXTENSOR & FLEXOR**\n",
    "2. **BICEPS & TRICEPS**\n",
    "3. **DELTOID**\n",
    "\n",
    "As a result, there are 60 files in total (3 rounds of 20 actions), and each file name can be broken down like so:\n",
    "\n",
    "`kai1_Plot_and_Store_Rep_6.27.csv`\n",
    "\n",
    "- kai1 = kai round 1 = **ex & flex**\n",
    "- Rep_6 = **CENTER ROLL ON** action\n",
    "\n",
    "`kai3_Plot_and_Store_Rep_8.22.csv`\n",
    "\n",
    "- kai3 = kai round 3 = **del**\n",
    "- Rep_8 = **CENTER BOTH ON** action"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Classes to manipulate data\n",
    "\n",
    "There are three classes here, as follows:\n",
    "- EMGfiles\n",
    "- EMGRound\n",
    "- EMGData\n",
    "\n",
    "----\n",
    "\n",
    "### EMGfiles\n",
    "The `EMGfiles` class takes a path and collects all data files available under it. This list of data files is then used to load the data.\n",
    "\n",
    "----\n",
    "\n",
    "### EMGRound\n",
    "This class holds the EMG data of one round together with some methods to pre-process it. It does not need to be created directly, because the `EMGData` class already creates one `EMGRound` per round.\n",
    "\n",
    "#### Methods\n",
    "- `getAvg(wheel_pos)` - get average values of the data at a wheel position (`'CENTER'`, `'LEFT'` or `'RIGHT'`)\n",
    "- `getColumnNames()` - get column names\n",
    "- `getDataRootPath()` - get data root path\n",
    "- `getRMS()` - get RMS values of the data\n",
    "- `getSamplingRate()` - get sampling rate of the collection\n",
    "- `plotBig()` - plot raw versus band-pass-filtered data, and the RMS of each, as big figures\n",
    "- `plotRMS()` - plot the RMS values of all channels, one figure per action\n",
    "- `plotDiff()` - plot the difference in average values between gyro ON and OFF (CENTER only for now)\n",
    "\n",
    "----\n",
    "\n",
    "### EMGData\n",
    "Since the data for each action is collected separately over multiple rounds, `EMGData` stores those 3 rounds together in one object."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Usage\n",
    "- Run the Import, Signal Classes and EMG Classes cells\n",
    "- Define the data path"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Import"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import itertools\n",
    "from IPython.display import display\n",
    "from scipy.signal import butter, freqz, filtfilt, decimate, detrend\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Signal Classes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class sigfilter():\n",
    "    # All helpers are static: call them through the class, e.g. sigfilter.bbf(data)\n",
    "\n",
    "    # function to use with map\n",
    "    @staticmethod\n",
    "    def bbf(data, lowcut=10, highcut=500, fs=1024, order=7):\n",
    "        return sigfilter.butter_bandpass_filter(data, lowcut=lowcut, highcut=highcut, fs=fs, order=order)\n",
    "\n",
    "    @staticmethod\n",
    "    def bnf(data, cutoff=50, fs=1024):\n",
    "        return sigfilter.butter_notch_filter(data, cutoff=cutoff, fs=fs, order=2)\n",
    "\n",
    "    @staticmethod\n",
    "    def lpf(data, cutoff=230, fs=1024):\n",
    "        return sigfilter.butter_lowpass_filter(data, cutoff=cutoff, fs=fs, order=2)\n",
    "\n",
    "    @staticmethod\n",
    "    def butter_notch(lowcut, highcut, fs, order=1):\n",
    "        nyq = 0.5 * fs\n",
    "        low = lowcut / nyq\n",
    "        high = highcut / nyq\n",
    "        b, a = butter(order, [low, high], btype='bandstop')\n",
    "        return b, a\n",
    "\n",
    "    @staticmethod\n",
    "    def butter_notch_filter(data, cutoff=None, fs=None, order=1):\n",
    "        # stop band is cutoff +/- 1 Hz\n",
    "        lowcut = cutoff-1\n",
    "        highcut = cutoff+1\n",
    "        b, a = sigfilter.butter_notch(lowcut, highcut, fs, order=order)\n",
    "        y = filtfilt(b, a, data)\n",
    "        return y\n",
    "\n",
    "    @staticmethod\n",
    "    def butter_lowpass(cutoff, fs, order=5):\n",
    "        nyq = 0.5 * fs\n",
    "        normal_cutoff = cutoff / nyq\n",
    "        b, a = butter(order, normal_cutoff, btype='low', analog=False)\n",
    "        return b, a\n",
    "\n",
    "    @staticmethod\n",
    "    def butter_lowpass_filter(data, cutoff=None, fs=None, order=5):\n",
    "        b, a = sigfilter.butter_lowpass(cutoff, fs, order=order)\n",
    "        y = filtfilt(b, a, data)\n",
    "        return y\n",
    "\n",
    "    @staticmethod\n",
    "    def butter_bandpass(lowcut, highcut, fs, order=5):\n",
    "        nyq = 0.5 * fs\n",
    "        low = lowcut / nyq\n",
    "        high = highcut / nyq\n",
    "        b, a = butter(order, [low, high], btype='bandpass')\n",
    "        return b, a\n",
    "\n",
    "    @staticmethod\n",
    "    def butter_bandpass_filter(data, lowcut=None, highcut=None, fs=None, order=5):\n",
    "        b, a = sigfilter.butter_bandpass(lowcut, highcut, fs, order=order)\n",
    "        y = filtfilt(b, a, data)\n",
    "        return y\n",
    "\n",
    "    @staticmethod\n",
    "    def clean_signal(emg, filtype='nb', q=10, cutoff=30, lowcut=10, highcut=500, fs=1024, order=4):\n",
    "        \"\"\"\n",
    "        clean signal does\n",
    "        DETREND - remove the linear trend, then rectify (absolute value)\n",
    "        DECIMATE - down sampling by a factor of q\n",
    "        FILTER - by notch, lowpass or bandpass (pass the first letter of each\n",
    "        filter to apply as one string, e.g. 'nlb')\n",
    "        \"\"\"\n",
    "        emg_detrend = np.abs(detrend(emg))\n",
    "        emg_decimate = decimate(emg_detrend, q)  # downsample by a factor of q\n",
    "        # NOTE: after decimation the effective sampling rate is fs / q,\n",
    "        # but the filters below are still designed with fs as given\n",
    "        emg_filtered = emg_decimate\n",
    "        if 'l' in filtype:\n",
    "            emg_filtered = sigfilter.butter_lowpass_filter(emg_filtered, cutoff=highcut, fs=fs, order=order)  # low-pass filter\n",
    "        if 'b' in filtype:\n",
    "            emg_filtered = sigfilter.butter_bandpass_filter(emg_filtered, lowcut=lowcut, highcut=highcut, fs=fs, order=order)\n",
    "        if 'n' in filtype:\n",
    "            emg_filtered = sigfilter.butter_notch_filter(emg_filtered, cutoff=cutoff, fs=fs, order=order)\n",
    "        return emg_filtered[100:]  # drop the first 100 samples\n",
    "\n",
    "    @staticmethod\n",
    "    def compute_power_spectrum(emg, n=512):\n",
    "        return np.abs(np.fft.fft(emg, n=n))**2\n",
    "\n",
    "    @staticmethod\n",
    "    def rescale(emg):\n",
    "        # min-max rescale to the range [0, 1]\n",
    "        emg = emg - np.min(emg)\n",
    "        emg_rescale = emg/np.max(emg)\n",
    "        return emg_rescale\n",
    "\n",
    "    @staticmethod\n",
    "    def check_response(ftype, low, order=1, high=500, fs=1024):\n",
    "        \"\"\"\n",
    "        require freqz, filtfilt, butter from scipy.signal\n",
    "        for ftype \"notch\", low is the centre frequency of the notch\n",
    "        \"\"\"\n",
    "        if ftype == \"notch\":\n",
    "            # stop band is centre frequency +/- 1 Hz\n",
    "            centre = low\n",
    "            b, a = sigfilter.butter_notch(centre - 1, centre + 1, fs, order=1)\n",
    "        elif ftype == \"bandpass\":\n",
    "            b, a = sigfilter.butter_bandpass(low, high, fs, order=order)\n",
    "        else:\n",
    "            raise ValueError(\"ftype must be 'notch' or 'bandpass'\")\n",
    "\n",
    "        w, h = freqz(b, a, worN=512)\n",
    "\n",
    "        # freqz returns w in rad/sample; convert it to Hz\n",
    "        return w * fs / (2 * np.pi), np.abs(h)\n",
    "\n",
    "    @staticmethod\n",
    "    def plotfft(data, fs=1024):\n",
    "        \"\"\"\n",
    "        require np\n",
    "        returns the positive-frequency axis and the amplitude spectrum\n",
    "        \"\"\"\n",
    "        N = data.size\n",
    "        T = 1.0 / fs\n",
    "        x = np.linspace(0.0, N*T, N)\n",
    "\n",
    "        yf = np.fft.fft(data)\n",
    "        xf = np.fft.fftfreq(N, T)\n",
    "        xf = np.fft.fftshift(xf)\n",
    "\n",
    "        yplot = np.fft.fftshift(yf)\n",
    "\n",
    "        # integer division: slice indices must be ints in Python 3\n",
    "        return xf[N//2:-1], (1.0/N * np.abs(yplot))[N//2:-1]\n",
    "\n",
    "    @staticmethod\n",
    "    def rms_o(data, win_size, overlapping):\n",
    "        \"\"\"\n",
    "        calculate root mean square of the provided array\n",
    "        \"\"\"\n",
    "        # if data is not an np.ndarray, convert it\n",
    "        if isinstance(data, (np.ndarray, np.generic)):\n",
    "            pass\n",
    "        else:\n",
    "            data = np.array(data)\n",
    "\n",
    "        # if data is only 1d, convert it to a column\n",
    "        if data.ndim == 1:\n",
    "            data = np.atleast_2d(data).T\n",
    "\n",
    "        data_size = data.shape[0]\n",
    "        step = overlapping\n",
    "        fa = []\n",
    "\n",
    "        for i in range(0, data_size, int(win_size*step)):\n",
    "            pw = np.power(data[i:i+win_size, :], 2)\n",
    "            # check if chopped is the same size as window size\n",
    "            if pw.shape[0] == win_size:\n",
    "                sq = np.sqrt(np.sum(pw, 0)/win_size)\n",
    "                fa.append(sq)\n",
    "\n",
    "        fa = np.vstack(fa)\n",
    "\n",
    "        return fa\n",
    "\n",
    "    @staticmethod\n",
    "    def waveform_o(data, win_size, overlapping):\n",
    "        \"\"\"\n",
    "        calculate wave form length of the provided array\n",
    "        \"\"\"\n",
    "        # if data is not an np.ndarray, convert it\n",
    "        if isinstance(data, (np.ndarray, np.generic)):\n",
    "            pass\n",
    "        else:\n",
    "            data = np.array(data)\n",
    "\n",
    "        # if data is only 1d, convert it to a column\n",
    "        if data.ndim == 1:\n",
    "            data = np.atleast_2d(data).T\n",
    "\n",
    "        data_size = data.shape[0]\n",
    "        step = overlapping\n",
    "        fa = []\n",
    "\n",
    "        for i in range(0, data_size, int(win_size*step)):\n",
    "            if data[i+1:i+1+win_size, :].shape[0] == data[i:i+win_size, :].shape[0]:\n",
    "                temp = np.sum(np.abs(data[i+1:i+1+win_size, :] - data[i:i+win_size, :]), axis=0)\n",
    "                fa.append(temp)\n",
    "\n",
    "        fa = np.vstack(fa)\n",
    "\n",
    "        return fa\n",
    "\n",
    "    @staticmethod\n",
    "    def rms(data, win_size, step):\n",
    "        \"\"\"\n",
    "        calculate root mean square of the provided array\n",
    "        \"\"\"\n",
    "        data_size = data.shape[0]\n",
    "        # range instead of xrange: the notebook runs on Python 3\n",
    "        idxs = range(0, data_size, int(win_size*(1-step)))\n",
    "\n",
    "        def rms_window(idx):\n",
    "            pw = np.power(data[idx:idx+win_size, :], 2)\n",
    "            sq = np.sqrt(np.sum(pw, 0)/win_size)\n",
    "            return sq\n",
    "\n",
    "        # map is lazy in Python 3, so build the list explicitly\n",
    "        return list(map(rms_window, idxs))\n",
    "\n",
    "    @staticmethod\n",
    "    def waveform(data, win_size, step):\n",
    "        \"\"\"\n",
    "        calculate wave form length of the provided array\n",
    "        \"\"\"\n",
    "        data_size = data.shape[0]\n",
    "        # range instead of xrange: the notebook runs on Python 3\n",
    "        idxs = range(0, data_size, int(win_size*(1-step)))\n",
    "\n",
    "        def waveform_window(idx):\n",
    "            temp = np.sum(np.abs(data[idx+1:idx+1+win_size, :] - data[idx:idx+win_size, :]), axis=0)\n",
    "            return temp\n",
    "\n",
    "        # map is lazy in Python 3, so build the list explicitly\n",
    "        return list(map(waveform_window, idxs))\n",
    "\n",
    "    @staticmethod\n",
    "    def rms_vectorized(data, win_size=100, step=0.5):\n",
    "        \"\"\"\n",
    "        calculate root mean square of the provided array\n",
    "        \"\"\"\n",
    "        data_size = data.shape[0]\n",
    "        step_size = int(win_size * step)\n",
    "        nframes = np.floor((data_size-step_size)/(1.0*step_size)).astype(np.int32)  # Consider padding\n",
    "        # one row of sample indices per window\n",
    "        idxs = np.tile(np.arange(0, win_size), (nframes, 1)) + np.arange(0, nframes*step_size, step_size)[:, None]\n",
    "        return np.sqrt(np.sum(data[idxs]**2, 1)/win_size)\n",
    "\n",
    "    @staticmethod\n",
    "    def rms_vectorized_zeromean(data, win_size=100, step=0.5):\n",
    "        \"\"\"\n",
    "        calculate root mean square of the provided array after removing its mean\n",
    "        \"\"\"\n",
    "        data_rect = np.abs(data-np.mean(data, axis=0))\n",
    "        data_size = data.shape[0]\n",
    "        step_size = int(win_size * step)\n",
    "        nframes = np.floor((data_size-step_size)/(1.0*step_size)).astype(np.int32)  # Consider padding\n",
    "        # one row of sample indices per window\n",
    "        idxs = np.tile(np.arange(0, win_size), (nframes, 1)) + np.arange(0, nframes*step_size, step_size)[:, None]\n",
    "        return np.sqrt(np.sum(data_rect[idxs]**2, 1)/win_size)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## EMG Classes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class EMGfiles():\n",
    "    '''\n",
    "    Class to manage emg data file paths\n",
    "    When the object is instantiated, it collects all available file paths\n",
    "    as a dictionary with each subject as key.\n",
    "    '''\n",
    "\n",
    "    def __init__(self, data_path):\n",
    "        self.all_files = {}\n",
    "\n",
    "        # in case use data in parent dir\n",
    "        # p = os.path.join('.', os.pardir)\n",
    "\n",
    "        ## data in the same dir\n",
    "        # p = os.curdir\n",
    "        ## data in OneDrive (lab pc vaio)\n",
    "        # p = '/Users/tulakan/OneDrive/BME/Rehabilitation Project/EMG data/EMG Gyro DelSys 2016 (V3)/'\n",
    "        ## data in OneDrive (lab pc OTALBS1)\n",
    "        # p = 'D:/OneDrive/BME/Rehabilitation Project/EMG data/EMG Gyro DelSys 2016 (V3)/'\n",
    "        ## data in OneDrive (macbook)\n",
    "        # p = '/Users/tulakan/OneDrive/BME/Rehabilitation Project/EMG data/EMG Gyro DelSys 2016 (V3)'\n",
    "        p = data_path\n",
    "\n",
    "        dataDir = 'data'\n",
    "        for a, b, c in os.walk(p, topdown=False):\n",
    "            if dataDir in a:\n",
    "                # {name : filenames}\n",
    "                # path > p (parent), dataDir (data dir), subject name, file\n",
    "                # name\n",
    "                self.all_files.update({os.path.split(a)[1]: [\n",
    "                    os.path.join(p, dataDir, os.path.split(a)[1], i) for i in c if 'DS_Store' not in i]})\n",
    "#                 self.all_files.update({os.path.split(a)[1] : c})\n",
    "\n",
    "    # return round [1(ef),2(bt),3(d)] and rep [1:20] of the collection\n",
    "    def rndrep(self, filename):\n",
    "        # parse the base name only, so underscores in the directory path cannot interfere\n",
    "        name = os.path.basename(filename)\n",
    "        return int(name.split('_')[0][-1]), int(name.split('_')[-1].split('.')[0])\n",
    "\n",
    "    def getSortedFileOf(self, who):\n",
    "        sortedFiles = []\n",
    "        rnd = [1, 2, 3]\n",
    "\n",
    "        for i in rnd:\n",
    "            try:\n",
    "                sortedFiles.append(self.getSortedFileAtRnd(who, i))\n",
    "            except Exception:\n",
    "                print(\"no round %i for %s\" % (i, who))\n",
    "\n",
    "        # flatten the per-round lists into one list\n",
    "        return [item for sublist in sortedFiles for item in sublist]\n",
    "\n",
    "    def getSortedFileAtRnd(self, who, rnd):\n",
    "        FilesOfWho = self.all_files[who]\n",
    "        tupleOfRndRep = np.array(list(map(self.rndrep, FilesOfWho)))\n",
    "        indexOfRnd = np.where(tupleOfRndRep[:, 0] == rnd)\n",
    "        sortedOfTuple = np.argsort(tupleOfRndRep[indexOfRnd][:, 1])\n",
    "\n",
    "        return [FilesOfWho[indexOfRnd[0][i]] for i in sortedOfTuple]\n",
    "\n",
    "    def getSubjectNames(self):\n",
    "        return list(self.all_files.keys())\n",
    "\n",
    "    @staticmethod\n",
    "    def getRound(filename):\n",
    "        return int(os.path.basename(filename).split('_')[0][-1])\n",
    "\n",
    "    @staticmethod\n",
    "    def getRep(filename):\n",
    "        return int(os.path.basename(filename).split('_')[-1].split('.')[0])\n",
    "\n",
    "\n",
    "class EMGRound():\n",
    "    '''\n",
    "    Class to manipulate EMG data in each round,\n",
    "    initiation requires list of files to be read as pandas dataframe.\n",
    "    round1 = ex flex\n",
    "    round2 = bi tri\n",
    "    round3 = del\n",
    "    '''\n",
    "\n",
    "    def __init__(self, listOfFiles):\n",
    "        self.listOfFiles = listOfFiles\n",
    "        self.dataroot = os.path.split(listOfFiles[0])[0]\n",
    "\n",
    "        # keys are the 0-based rep values used in the parsed data\n",
    "        self.labelplot = {0: \"0. Baseline recording 1\",\n",
    "                          1: \"1. Baseline recording 2\",\n",
    "                          2: \"2. CENTER YAW OFF\",\n",
    "                          3: \"3. CENTER YAW ON\",\n",
    "                          4: \"4. CENTER ROLL OFF\",\n",
    "                          5: \"5. CENTER ROLL ON\",\n",
    "                          6: \"6. CENTER BOTH OFF\",\n",
    "                          7: \"7. CENTER BOTH ON\",\n",
    "                          8: \"8. LEFT YAW OFF\",\n",
    "                          9: \"9. LEFT YAW ON\",\n",
    "                          10: \"10. LEFT ROLL OFF\",\n",
    "                          11: \"11. LEFT ROLL ON\",\n",
    "                          12: \"12. LEFT BOTH OFF\",\n",
    "                          13: \"13. LEFT BOTH ON\",\n",
    "                          14: \"14. RIGHT YAW OFF\",\n",
    "                          15: \"15. RIGHT YAW ON\",\n",
    "                          16: \"16. RIGHT ROLL OFF\",\n",
    "                          17: \"17. RIGHT ROLL ON\",\n",
    "                          18: \"18. RIGHT BOTH OFF\",\n",
    "                          19: \"19. RIGHT BOTH ON\"}\n",
    "\n",
    "        self.df = pd.DataFrame()\n",
    "\n",
    "        for no, file in enumerate(self.listOfFiles):\n",
    "            # print percent progress as finished read files\n",
    "            pc = ((no + 1) / len(listOfFiles)) * 100\n",
    "            print(\"\\r{0}% done..\".format(int(pc)), end=\"\")\n",
    "\n",
    "            rnd = EMGfiles.getRound(file)\n",
    "            rep = EMGfiles.getRep(file)\n",
    "\n",
    "            # round 1: ex-flex\n",
    "            # round 2: bi-tri\n",
    "            # round 3: del\n",
    "#             self.df = self.df.append(pd.read_csv(file, skiprows=skr).assign(rep=rep-1), ignore_index=True)\n",
    "            # the header is 11 lines long in some files and 19 in others\n",
    "            try:\n",
    "                part = pd.read_csv(file, skiprows=11)\n",
    "            except Exception:\n",
    "                part = pd.read_csv(file, skiprows=19)\n",
    "\n",
    "            # pd.concat replaces DataFrame.append, which newer pandas no longer provides\n",
    "            self.df = pd.concat([self.df, part.assign(rep=rep - 1)], ignore_index=True)\n",
    "\n",
    "        self.df = self._getOnlyEMG()\n",
    "\n",
    "    def getDataRootPath(self):\n",
    "        return self.dataroot\n",
    "\n",
    "    def getSamplingRate(self):\n",
    "        # the sampling rate is read from the first line of the first file\n",
    "        with open(self.listOfFiles[0], 'r') as toread:\n",
    "            temp = toread.readline()\n",
    "            temp2 = []\n",
    "\n",
    "            # keep only the tokens that parse as numbers\n",
    "            for i in temp.split():\n",
    "                try:\n",
    "                    temp2.append(float(i))\n",
    "                except ValueError:\n",
    "                    pass\n",
    "\n",
    "            self.fs = temp2[1]\n",
    "\n",
    "        return self.fs\n",
    "\n",
    "    def _getOnlyEMG(self):\n",
    "        return self.df[[i for i in list(self.df.columns) if 'EMG' in i or 'rep' in i]]\n",
    "\n",
    "    def getColumnNames(self):\n",
    "        return [i for i in list(self.df.columns) if 'EMG' in i]\n",
    "\n",
    "    def plotBig(self):\n",
    "        col = [i for i in self.getColumnNames() if 'rep' not in i]\n",
    "        rep = self.labelplot.keys()\n",
    "        q = itertools.product(col, rep)\n",
    "\n",
    "        for i in q:\n",
    "            temp = np.array(self.df.query(\"rep == {0}\".format(i[1]))[i[0]])\n",
    "            # cut head and tail out for 1000 each (~ 1 sec)\n",
    "            # then scale up by multiplying with 1000 (V -> mV)\n",
    "            temp = temp[1000:-1000] * 1000\n",
    "\n",
    "            # new figure in each round plot\n",
    "            plt.figure(figsize=[18, 6])\n",
    "\n",
    "            # plot raw data on the left\n",
    "            plt.subplot(1, 2, 1)\n",
    "\n",
    "            plt.title(self.labelplot[i[1]] + \" of {0}\".format(i[0]))\n",
    "            # plot very raw data!\n",
    "            plt.plot(temp)\n",
    "\n",
    "            # plot raw data pass bandpass filter 20-500 Hz\n",
    "            temp2 = sigfilter.bbf(\n",
    "                temp, lowcut=20, highcut=500, fs=self.getSamplingRate())\n",
    "            plt.plot(np.array(temp2))\n",
    "\n",
    "            # plot rms data on the right\n",
    "            plt.subplot(1, 2, 2)\n",
    "\n",
    "            temp3 = sigfilter.rms_vectorized(temp, win_size=250)\n",
    "            temp4 = sigfilter.rms_vectorized(temp2, win_size=250)\n",
    "\n",
    "            plt.plot(temp3)\n",
    "            plt.plot(temp4)\n",
    "\n",
    "            # remember the two baseline means to draw them on every plot\n",
    "            if i[1] == 0:\n",
    "                bs1, bs2 = 0, 0\n",
    "                bs1 = temp4.mean()\n",
    "            elif i[1] == 1:\n",
    "                bs2 = temp4.mean()\n",
    "\n",
    "            plt.plot(list(itertools.repeat(bs1, len(temp4))), label=\"bs1\")\n",
    "            plt.legend()\n",
    "            plt.plot(list(itertools.repeat(bs2, len(temp4))), label=\"bs2\")\n",
    "            plt.legend()\n",
    "            plt.plot(list(itertools.repeat(temp4.mean(), len(temp4))),\n",
    "                     label=\"data mean\")\n",
    "            plt.legend()\n",
    "            plt.title('avg: ' + str(temp3.mean()) + ' mV vs ' +\n",
    "                      str(temp4.mean()) + ' mV [bandpass 20-500 Hz]')\n",
    "\n",
    "    def getRMS(self):\n",
    "        def _trim(df, head, tail):\n",
    "            return df[head:-tail]\n",
    "\n",
    "        def _scaleUp(df, scaleLvl):\n",
    "            repCol = df[\"rep\"]\n",
    "            dat = df.drop(\"rep\", axis=1) * scaleLvl\n",
    "            dat = dat.assign(rep=repCol)\n",
    "            return dat\n",
    "\n",
    "        # bandpass filter of 20-500 Hz\n",
    "        def _bandpassFilter(df):\n",
    "            return sigfilter.bbf(df, lowcut=20, highcut=500, fs=self.getSamplingRate())\n",
    "\n",
    "        def _rmsFilter(df):\n",
    "            return sigfilter.rms_vectorized(np.array(df), win_size=250)\n",
    "\n",
    "        # scaling up V -> mV\n",
    "        dat = _scaleUp(self.df, 1000)\n",
    "        out_rms = pd.DataFrame()\n",
    "\n",
    "        for i in range(20):\n",
    "            temp2 = dat.query(\"rep == %s\" % i).drop(\"rep\", axis=1)\n",
    "            temp = _trim(temp2, 1000, 1000)\n",
    "\n",
    "            # get columns name\n",
    "            col = temp.columns\n",
    "\n",
    "            # apply bandpass filter\n",
    "            filtered = pd.DataFrame(np.apply_along_axis(\n",
    "                _bandpassFilter, 0, temp), columns=col)\n",
    "\n",
    "            # get rms data\n",
    "            rmsed = pd.DataFrame(np.apply_along_axis(\n",
    "                _rmsFilter, 0, filtered), columns=col)\n",
    "            rmsed['rep'] = i\n",
    "\n",
    "            # append to out_rms to return\n",
    "            out_rms = pd.concat([out_rms, rmsed], ignore_index=True)\n",
    "\n",
    "        return out_rms\n",
    "\n",
    "    def plotRMS(self):\n",
    "        temp = self.getRMS()\n",
    "\n",
    "        for i in range(20):\n",
    "            temp1 = temp.query(\"rep == %s\" % i).drop(\"rep\", axis=1)\n",
    "\n",
    "            # plotting\n",
    "            plt.figure(figsize=[18, 4])\n",
    "\n",
    "            for col in temp1:\n",
    "                plt.title(self.labelplot[i])\n",
    "                # legend label is the first two words of the column name\n",
    "                plt.plot(temp1[col], label=col.split()\n",
    "                         [0] + \" \" + col.split()[1])\n",
    "                plt.legend()\n",
    "\n",
    "            plt.tight_layout()\n",
    "\n",
    "    def plotDiff(self):\n",
    "        ''' Only available for CENTER for now '''\n",
    "\n",
    "        axes = ['YAW', 'ROLL', 'BOTH']\n",
    "        t = self.getRMS()\n",
    "\n",
    "        for a in axes:\n",
    "            plt.figure()\n",
    "            plt.title(a + ' CENTER')\n",
    "\n",
    "            for i in range(20):\n",
    "                temp = t[t['rep'] == i].drop('rep', axis=1).mean()\n",
    "\n",
    "                if a in self.labelplot[i] and 'CENTER' in self.labelplot[i]:\n",
    "                    if 'ON' in self.labelplot[i]:\n",
    "                        plt.bar(range(len(temp)), temp, alpha=0.5,\n",
    "                                color='r', label=\"ON\", align='center')\n",
    "                    else:\n",
    "                        plt.bar(range(len(temp)), temp, alpha=0.5,\n",
    "                                label=\"OFF\", align='center')\n",
    "                    xlabel = [i.split()[0] + ' ' + i.split()[1]\n",
    "                              for i in list(t.drop('rep', axis=1).columns)]\n",
    "                    plt.xticks(range(len(temp)), xlabel)\n",
    "                    plt.legend()\n",
    "\n",
    "    def getAvg(self, wheel_pos):\n",
    "        axes = ['YAW', 'ROLL', 'BOTH']\n",
    "        t = self.getRMS()\n",
    "        avg_df = pd.DataFrame()\n",
    "\n",
    "        for a in axes:\n",
    "            for i in range(20):\n",
    "                temp = t[t['rep'] == i].drop('rep', axis=1).mean()\n",
    "\n",
    "                if a in self.labelplot[i] and wheel_pos in self.labelplot[i]:\n",
    "                    if 'ON' in self.labelplot[i]:\n",
    "                        avg_df = pd.concat([avg_df, pd.DataFrame([temp]).assign(\n",
    "                            action=a + \" ON\")], ignore_index=True)\n",
    "\n",
    "                    else:\n",
    "                        avg_df = pd.concat([avg_df, pd.DataFrame([temp]).assign(\n",
    "                            action=a + \" OFF\")], ignore_index=True)\n",
    "\n",
    "        return avg_df\n",
    "\n",
    "\n",
    "class EMGData():\n",
    "    '''\n",
    "    Class to collect the EMG data of all three rounds\n",
    "    dc = dataclass\n",
    "    each argument is the sorted list of files of one round\n",
    "    '''\n",
    "\n",
    "    def __init__(self, filename1, filename2, filename3):\n",
    "        self.dc1 = EMGRound(filename1)  # round 1: ex flex\n",
    "        self.dc2 = EMGRound(filename2)  # round 2: bi tri\n",
    "        self.dc3 = EMGRound(filename3)  # round 3: del\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data Path\n",
    "\n",
    "### dir structure\n",
    "root (p) > data > subject (kai)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "## data in the same dir\n",
    "# p = os.curdir\n",
    "## data in OneDrive (windows)\n",
    "# p = 'D:/OneDrive/BME/Rehabilitation Project/EMG data/EMG Gyro DelSys 2016 (V3)/'\n",
    "## data in OneDrive (mac)\n",
    "# p = '/Users/tulakan/OneDrive/BME/Rehabilitation Project/EMG data/EMG Gyro DelSys 2016 (V3)'\n",
    "\n",
    "p = '/Users/tulakan/OneDrive/BME/Rehabilitation Project/EMG data/EMG Gyro DelSys 2016 (V3)'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Get and parse data then collect in emgClass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "f = EMGfiles(p)\n",
    "subject = 'kai'\n",
    "\n",
    "f1 = f.getSortedFileAtRnd(subject, 1)\n",
    "f2 = f.getSortedFileAtRnd(subject, 2)\n",
    "f3 = f.getSortedFileAtRnd(subject, 3)\n",
    "\n",
    "emgClass = EMGData(f1, f2, f3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Try out some methods"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# get RMSed values\n",
    "# emgClass.dc1.getRMS()\n",
    "# get average RMSed values of wheel position center\n",
    "# emgClass.dc2.getAvg('CENTER')\n",
    "# try plot\n",
    "# emgClass.dc3.plotRMS()\n",
    "emgClass.dc3.plotBig()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
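
The file naming scheme described in the notebook's "Data collection scheme" cell can be decoded without loading any data. The following is a minimal sketch, not part of the notebook: `describe_file`, `FILE_PATTERN`, `ROUND_MUSCLES` and `ACTIONS` are hypothetical names introduced for illustration, and the pattern assumes every file follows the `<subject><round>_Plot_and_Store_Rep_<rep>.<suffix>.csv` shape of the two example names. The meaning of the trailing number (`27`, `22`) is not stated in the notebook, so it is ignored here.

```python
import os
import re

# Muscles recorded in each round, as listed in the notebook.
ROUND_MUSCLES = {1: "EXTENSOR & FLEXOR", 2: "BICEPS & TRICEPS", 3: "DELTOID"}

# Action names in file-name order (Rep_1 .. Rep_20).
ACTIONS = ["Baseline recording 1", "Baseline recording 2"] + [
    "{} {} {}".format(pos, axis, state)
    for pos in ("CENTER", "LEFT", "RIGHT")
    for axis in ("YAW", "ROLL", "BOTH")
    for state in ("OFF", "ON")
]

# Assumed file-name shape, inferred from the two examples in the notebook.
FILE_PATTERN = re.compile(
    r"^(?P<subject>.+?)(?P<round>\d)_Plot_and_Store_Rep_(?P<rep>\d+)\."
)


def describe_file(path):
    """Return subject, round, muscles and action encoded in a data file name."""
    name = os.path.basename(path)
    match = FILE_PATTERN.match(name)
    if match is None:
        raise ValueError("unexpected file name: {!r}".format(name))
    rnd = int(match.group("round"))
    rep = int(match.group("rep"))
    return {
        "subject": match.group("subject"),
        "round": rnd,
        "muscles": ROUND_MUSCLES[rnd],
        "rep_in_name": rep,       # 1-20, as written in the file name
        "rep_in_data": rep - 1,   # 0-19, as stored in the `rep` column
        "action": ACTIONS[rep - 1],
    }


info = describe_file("kai1_Plot_and_Store_Rep_6.27.csv")
print(info["muscles"], "|", info["action"])
# prints: EXTENSOR & FLEXOR | CENTER ROLL ON
```

Keeping both `rep_in_name` and `rep_in_data` in the result makes the off-by-one between file names (1-20) and the parsed `rep` column (0-19) explicit.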