Created October 24, 2025 18:01
-
-
Save joacorapela/b895f7232027efac2aba23b838056eea to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import sys | |
| import numpy as np | |
| import plotly.graph_objects as go | |
def sample_HOs(t, sigma, freq_HO, HO0):
    """Sample a head-orientation (HO) trajectory rotating at a constant rate.

    The deterministic part of the trajectory equals ``HO0`` at ``t[0]`` and
    advances by ``2*pi*freq_HO`` radians per unit of ``t``. Independent
    zero-mean Gaussian noise with standard deviation ``sigma`` is added to
    every sample after the first.

    Parameters
    ----------
    t : array-like of float
        Sample times.
    sigma : float
        Standard deviation of the additive Gaussian noise (radians); must be
        non-negative.
    freq_HO : float
        Rotation frequency, in cycles per unit of ``t``.
    HO0 : float
        Head orientation at ``t[0]`` (radians).

    Returns
    -------
    numpy.ndarray
        Head orientations in radians (not wrapped), one per element of ``t``.
        An empty ``t`` yields an empty array.
    """
    t = np.asarray(t, dtype=float)
    # One noise draw per sample, so the random stream is consumed exactly as
    # before; the last draw is not used.
    eta = np.random.normal(scale=sigma, size=len(t))
    HOs = np.empty(shape=len(t))
    if len(t) == 0:
        return HOs
    HOs[0] = HO0
    # Offset by HO0 and evaluate the phase at each sample's own time, measured
    # from t[0], so the trajectory is continuous with HOs[0] == HO0.
    HOs[1:] = HO0 + 2 * np.pi * freq_HO * (t[1:] - t[0]) + eta[:-1]
    return HOs
def rereferenceHOs(orientation, HOs):
    """Express head orientations relative to a reference orientation.

    For each head orientation ``h`` in ``HOs`` this returns
    ``orientation - h`` wrapped to ``[-pi, pi]``.

    Parameters
    ----------
    orientation : float
        Reference orientation in radians (e.g. the orientation of a sink).
    HOs : array-like of float
        Head orientations in radians.

    Returns
    -------
    numpy.ndarray of float
        Re-referenced orientations, same shape as ``HOs``.
    """
    HOs = np.asarray(HOs, dtype=float)
    # angle(e^{i*o} * conj(e^{i*h})) is o - h wrapped into [-pi, pi], so
    # no manual modular arithmetic is needed. Computing it in one vectorized
    # step always yields a float array. A buffer preallocated with
    # np.empty_like would inherit an integer dtype from integer input and
    # truncate the angles.
    return np.angle(np.exp(1j * orientation) * np.conj(np.exp(1j * HOs)))
def calculateMRL(HOs):
    """Return the mean resultant length (MRL) of a set of angles.

    Every angle is mapped to a unit vector in the complex plane. The MRL is
    the modulus of the average of those vectors. It lies between 0 and 1 and
    equals 1 only when all angles coincide (modulo 2*pi).

    Parameters
    ----------
    HOs : numpy.ndarray
        Angles in radians.

    Returns
    -------
    float
        The mean resultant length.
    """
    return np.abs(np.mean(np.exp(1j * HOs)))
def _plot_HOs_vs_time(t, HOs, spike_times):
    """Plot head orientation against time with a vertical line at each spike."""
    fig = go.Figure()
    fig.add_trace(go.Scatter(x=t, y=HOs))
    for spike_time in spike_times:
        fig.add_vline(x=spike_time)
    fig.update_xaxes(title="Time (sec)")
    fig.update_yaxes(title="Head Orientation (radians)")
    fig.show()


def _plot_HOs_on_unit_circle(HOs, rr0HOs, rr1HOs):
    """Plot original and re-referenced HOs as points on the unit circle."""
    fig = go.Figure()
    # Full reference circle, independent of the simulation duration.
    circle_angles = np.linspace(0, 2 * np.pi, 361)
    fig.add_trace(go.Scatter(x=np.cos(circle_angles), y=np.sin(circle_angles),
                             line=dict(color="gray"), mode="lines",
                             showlegend=False))
    for angles, name, color in ((HOs, "HOs", "blue"),
                                (rr0HOs, "rr0 HOs", "orange"),
                                (rr1HOs, "rr1 HOs", "green")):
        # These traces draw markers only, so colour the markers directly.
        fig.add_trace(go.Scatter(x=np.cos(angles), y=np.sin(angles),
                                 name=name, mode="markers",
                                 marker=dict(color=color)))
    fig.update_layout(yaxis_scaleanchor="x")
    fig.show()


def _plot_MRLs(mrlHOs, mrlRR0HOs, mrlRR1HOs):
    """Bar plot of the MRLs of the original and re-referenced HOs."""
    fig = go.Figure()
    fig.add_trace(go.Bar(x=["HOs", "rr0 HOs", "rr1 HOs"],
                         y=[mrlHOs, mrlRR0HOs, mrlRR1HOs]))
    fig.update_yaxes(title="Mean Resultant Length")
    fig.show()


def main(args):
    """Check that re-referencing spike-related head orientations to a sink
    preserves their mean resultant length (MRL), and plot the results.

    Re-referencing maps each angle ``h`` to ``orientation - h``. That is a
    reflection followed by a rotation, so the modulus of the mean unit vector
    (the MRL) must be unchanged. The script does the following:

    1. Simulates Poisson spike counts.
    2. Samples head orientations (HOs).
    3. Re-references the HOs at spike times with respect to two sinks.
    4. Verifies that the three MRLs agree.
    5. Shows three figures.

    Parameters
    ----------
    args : list of str
        Command-line arguments (currently unused).

    Raises
    ------
    AssertionError
        If a re-referenced MRL differs from the original MRL by ``tol`` or
        more (or is NaN).
    """
    T = 1  # simulation duration (sec)
    sample_rate = 1000  # samples per second
    lam = 0.1  # Poisson mean spike count per time bin
    freq_HO = 0.1  # HO rotation frequency (cycles per second)
    sigma = 0.0  # standard deviation of the HO noise (radians)
    HO0 = 0.0  # initial head orientation
    orientation0 = np.pi / 2  # orientation of sink 0
    orientation1 = -np.pi / 2  # orientation of sink 1
    tol = 1e-6  # allowed absolute difference between MRLs

    t = np.arange(0, T, 1.0 / sample_rate)
    # Sample spike counts with a low rate so that more than one spike per bin
    # is unlikely. A bin with several spikes still contributes a single spike
    # time below.
    spike_counts = np.random.poisson(lam=lam, size=len(t))
    spike_samples = np.nonzero(spike_counts)
    spike_times = t[spike_samples]  # extract spike times

    # sample HOs restricted to a small angle (for clear visualization)
    HOs = sample_HOs(t=t, sigma=sigma, freq_HO=freq_HO, HO0=HO0)
    # extract the spike related HOs
    HOsAtSpikeTimes = HOs[spike_samples]
    # re reference HOs wrt sink 0 and wrt sink 1
    rr0HOsAtSpikeTimes = rereferenceHOs(orientation=orientation0,
                                        HOs=HOsAtSpikeTimes)
    rr1HOsAtSpikeTimes = rereferenceHOs(orientation=orientation1,
                                        HOs=HOsAtSpikeTimes)

    # compute MRLs with the shared helper instead of re-deriving the formula
    mrlHOsAtSpikeTimes = calculateMRL(HOsAtSpikeTimes)
    mrlRR0HOsAtSpikeTimes = calculateMRL(rr0HOsAtSpikeTimes)
    mrlRR1HOsAtSpikeTimes = calculateMRL(rr1HOsAtSpikeTimes)

    # Re-referencing must not change the MRL. Check explicitly rather than
    # with `assert` so the check still runs under `python -O`.
    for label, mrlRR in (("sink 0", mrlRR0HOsAtSpikeTimes),
                         ("sink 1", mrlRR1HOsAtSpikeTimes)):
        # `not (... < tol)` also rejects a NaN MRL (e.g. when there are no spikes)
        if not np.abs(mrlHOsAtSpikeTimes - mrlRR) < tol:
            raise AssertionError(
                f"MRL of HOs re-referenced wrt {label} ({mrlRR}) differs "
                f"from MRL of the original HOs ({mrlHOsAtSpikeTimes})")

    # plot results
    _plot_HOs_vs_time(t=t, HOs=HOs, spike_times=spike_times)
    _plot_HOs_on_unit_circle(HOs=HOsAtSpikeTimes,
                             rr0HOs=rr0HOsAtSpikeTimes,
                             rr1HOs=rr1HOsAtSpikeTimes)
    _plot_MRLs(mrlHOs=mrlHOsAtSpikeTimes,
               mrlRR0HOs=mrlRR0HOsAtSpikeTimes,
               mrlRR1HOs=mrlRR1HOsAtSpikeTimes)
if __name__ == "__main__":
    # Run the demonstration only when executed as a script, not on import.
    main(args=sys.argv)
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.