Skip to content

Instantly share code, notes, and snippets.

@joacorapela
Created October 24, 2025 18:01
Show Gist options
  • Save joacorapela/b895f7232027efac2aba23b838056eea to your computer and use it in GitHub Desktop.
Save joacorapela/b895f7232027efac2aba23b838056eea to your computer and use it in GitHub Desktop.
import sys
import numpy as np
import plotly.graph_objects as go
def sample_HOs(t, sigma, freq_HO, HO0):
    """Sample a noisy head-orientation (HO) trajectory rotating at a constant rate.

    The noiseless trajectory is ``HO0 + 2*pi*freq_HO*t``. Independent
    zero-mean Gaussian noise with standard deviation ``sigma`` is added to
    every sample except the first, which is pinned to ``HO0``.

    Parameters
    ----------
    t : array_like
        Sample times (sec).
    sigma : float
        Standard deviation of the additive Gaussian noise (radians).
    freq_HO : float
        Rotation frequency of the head (cycles per sec).
    HO0 : float
        Initial head orientation (radians), returned as ``HOs[0]``.

    Returns
    -------
    numpy.ndarray
        Head orientations in radians (not wrapped), one per element of ``t``.
    """
    t = np.asarray(t, dtype=float)
    eta = np.random.normal(scale=sigma, size=len(t))
    # Sample k uses its own time t[k] and keeps the HO0 offset.
    # The previous loop used t[k-1] (a one-sample lag) and dropped HO0
    # after the first sample.
    HOs = HO0 + 2 * np.pi * freq_HO * t + eta
    if len(t) > 0:
        HOs[0] = HO0
    return HOs
def rereferenceHOs(orientation, HOs):
    """Express head orientations relative to a reference orientation.

    Each output angle is ``orientation - HO`` wrapped to ``(-pi, pi]``.
    It is computed on the unit circle as the angle of
    ``exp(1j*orientation) * conj(exp(1j*HO))``.

    Parameters
    ----------
    orientation : float
        Reference orientation (radians).
    HOs : array_like
        Head orientations (radians).

    Returns
    -------
    numpy.ndarray
        Re-referenced angles as float64, same shape as ``HOs``.
    """
    # Convert to float explicitly. The previous np.empty_like(HOs) buffer
    # inherited an integer dtype from integer input and truncated the angles.
    HOsV = np.exp(1j * np.asarray(HOs, dtype=float))
    orientationV = np.exp(1j * orientation)
    return np.angle(orientationV * np.conj(HOsV))
def calculateMRL(HOs):
    """Return the mean resultant length (MRL) of a set of angles.

    Each angle (radians) is mapped to the unit phasor ``exp(1j*angle)``.
    The phasors are averaged, and the modulus of that average is returned.
    The result is 1 when all angles coincide and 0 for, e.g., two opposite
    angles.
    """
    return np.abs(np.mean(np.exp(1j * HOs)))
def _plot_HOs_vs_time(t, HOs, spike_times):
    """Plot the HO trajectory with a vertical line at every spike time."""
    fig = go.Figure()
    fig.add_trace(go.Scatter(x=t, y=HOs))
    for spike_time in spike_times:
        fig.add_vline(x=spike_time)
    fig.update_xaxes(title="Time (sec)")
    fig.update_yaxes(title="Head Orientation (radians)")
    fig.show()


def _plot_HOs_on_unit_circle(HOs, rr0HOs, rr1HOs):
    """Plot original and re-referenced HOs as points on the unit circle."""
    # Use a dedicated angle grid so the reference circle is closed whatever
    # the simulation duration is.
    circle_angles = np.linspace(0, 2 * np.pi, 361)
    fig = go.Figure()
    fig.add_trace(go.Scatter(x=np.cos(circle_angles), y=np.sin(circle_angles),
                             line=dict(color="gray"), mode="lines",
                             showlegend=False))
    for angles, name, color in ((HOs, "HOs", "blue"),
                                (rr0HOs, "rr0 HOs", "orange"),
                                (rr1HOs, "rr1 HOs", "green")):
        fig.add_trace(go.Scatter(x=np.cos(angles), y=np.sin(angles),
                                 name=name, marker=dict(color=color),
                                 mode="markers"))
    # Tie the y-axis scale to the x-axis so the circle is not distorted.
    fig.update_layout(yaxis_scaleanchor="x")
    fig.show()


def _plot_MRLs(mrls):
    """Bar-plot the MRLs of the original and re-referenced HOs."""
    fig = go.Figure()
    fig.add_trace(go.Bar(x=["HOs", "rr0 HOs", "rr1 HOs"], y=mrls))
    fig.update_yaxes(title="Mean Resultant Length")
    fig.show()


def main(args):
    """Check that re-referencing HOs to a sink leaves their MRL unchanged.

    Simulates Poisson spike counts and a slowly rotating head orientation,
    then collects the HOs at spike times and re-references them to two sink
    orientations. It verifies that the three mean resultant lengths agree
    within ``tol`` and plots the results.

    Parameters
    ----------
    args : list of str
        Command-line arguments (currently unused).

    Raises
    ------
    RuntimeError
        If no spikes were sampled, so no MRL can be computed.
    AssertionError
        If re-referencing changed the MRL by ``tol`` or more.
    """
    T = 1  # simulation duration (sec)
    sample_rate = 1000  # samples per sec
    lam = 0.1  # expected spike count per time bin
    freq_HO = 0.1  # head rotation frequency (cycles per sec)
    sigma = 0.0  # std of the HO noise (radians)
    HO0 = 0.0  # initial head orientation
    orientation0 = np.pi / 2  # orientation of sink 0
    orientation1 = -np.pi / 2  # orientation of sink 1
    tol = 1e-6

    t = np.arange(0, T, 1.0 / sample_rate)
    # A low rate makes bins with more than one spike rare but does not rule
    # them out; any non-zero bin is treated as a single spike below.
    spike_counts = np.random.poisson(lam=lam, size=len(t))
    spike_samples = np.nonzero(spike_counts)[0]
    if spike_samples.size == 0:
        raise RuntimeError("no spikes were sampled; increase lam or T")
    spike_times = t[spike_samples]

    # The slow rotation (freq_HO * T = 0.1 cycles) keeps the HOs within a
    # small arc, for clear visualization.
    HOs = sample_HOs(t=t, sigma=sigma, freq_HO=freq_HO, HO0=HO0)
    HOsAtSpikeTimes = HOs[spike_samples]
    rr0HOsAtSpikeTimes = rereferenceHOs(orientation=orientation0,
                                        HOs=HOsAtSpikeTimes)
    rr1HOsAtSpikeTimes = rereferenceHOs(orientation=orientation1,
                                        HOs=HOsAtSpikeTimes)

    mrlHOsAtSpikeTimes = calculateMRL(HOsAtSpikeTimes)
    mrlRR0HOsAtSpikeTimes = calculateMRL(rr0HOsAtSpikeTimes)
    mrlRR1HOsAtSpikeTimes = calculateMRL(rr1HOsAtSpikeTimes)

    # Re-referencing conjugates each phasor and multiplies it by the same
    # unit phasor exp(1j*orientation); neither changes the modulus of the
    # mean, so the MRLs must agree. np.testing checks survive `python -O`,
    # unlike bare asserts.
    np.testing.assert_allclose(mrlRR0HOsAtSpikeTimes, mrlHOsAtSpikeTimes,
                               rtol=0, atol=tol)
    np.testing.assert_allclose(mrlRR1HOsAtSpikeTimes, mrlHOsAtSpikeTimes,
                               rtol=0, atol=tol)

    _plot_HOs_vs_time(t=t, HOs=HOs, spike_times=spike_times)
    _plot_HOs_on_unit_circle(HOs=HOsAtSpikeTimes,
                             rr0HOs=rr0HOsAtSpikeTimes,
                             rr1HOs=rr1HOsAtSpikeTimes)
    _plot_MRLs([mrlHOsAtSpikeTimes, mrlRR0HOsAtSpikeTimes,
                mrlRR1HOsAtSpikeTimes])
if __name__ == "__main__":
    # Run the demo only when executed as a script, not when imported.
    main(args=sys.argv)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment