Created
June 10, 2022 05:47
-
-
Save rseydam/7b141c52bbcecf0811426b6a54bf3658 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Two-population spiking model
# two coupled populations (excitatory and inhibitory) with convolutional coupling inh <-> exc
# structures, stepper, plotting
###
using Setfield | |
using Parameters | |
using Random | |
using SparseArrays | |
using LinearAlgebra | |
using NNlib #gives a bunch of activation function also used by flux.jl | |
using FFTW | |
using JLD2 | |
## | |
# Heaviside step functions; the two variants differ only in the value at x == 0.
# hs:  0 for x < 0, otherwise 1   (hs(0) == 1; NaN fails the comparison and maps to 1)
hs(x::AbstractFloat) = x < zero(x) ? zero(x) : one(x)
# hs2: 0 for x <= 0, otherwise 1  (hs2(0) == 0; NaN fails the comparison and maps to 1)
hs2(x::AbstractFloat) = x <= zero(x) ? zero(x) : one(x)
# connectivity matrix in shifted coordinates for convolutions | |
function connectivity(n, l, w) | |
conn = zeros(n*n) # initialize weights (here as a long vector) | |
shifted = (fftfreq(n)*n) # get shifted coordinates | |
## | |
for i in 1:n | |
for j in 1:n | |
c = n * (i-1) + j # coordinate relation matrix to vector | |
r = sqrt(shifted[i]*shifted[i] + shifted[j]*shifted[j]) # compute distance | |
if r <= l # check whether r below given coupling range l | |
conn[c] = w # at the moment the sign is according to the w_ex w_in values | |
#here we could also use different kernel for grated synapses! | |
else | |
conn[c] = 0.0 | |
end | |
end | |
end | |
return reshape(conn,(n,n)) # return as nxn matrix | |
end | |
## | |
# Two-population spiking model with (mostly) concretely typed fields.
# Population index 1 is inhibitory and index 2 excitatory (see `psi` and `s` at the end).
# Defaults may refer to fields declared earlier, so the field order matters.
@with_kw struct SPnet2
    # --- grid and time
    nsq::Int64 = 64                 # side length of the square 2d grid
    n::Int64 = nsq*nsq              # number of units per population (nsq^2)
    dt::Float64 = 1.0               # time step size
    t::Vector{UInt64} = [0]         # step counter, a 1-element vector so it can be mutated; time = t[1]*dt
    # --- neuron and coupling parameters
    τex::Float64 = 40.0             # excitatory membrane time constant
    τin::Float64 = 20.0             # inhibitory membrane time constant
    l_in::Float64 = 12.0            # range (radius) of the inhibitory coupling kernel
    w_in::Float64 = -1.0            # strength and sign of the inhibitory coupling
    l_ex::Float64 = 4.0             # range (radius) of the excitatory coupling kernel
    w_ex::Float64 = 1.0             # strength of the excitatory coupling
    delay_in::Float64 = 2.0         # synaptic delay inh -> exc
    delay_etoi::Float64 = 2.0       # synaptic delay exc -> inh
    delay_etoe::Float64 = 5.0       # synaptic delay exc -> exc
    a_ex::Float64 = 1.1             # constant external drive of the excitatory population
    a_in::Float64 = 0.8             # constant external drive of the inhibitory population
    a_m::Float64 = 0.0              # amplitude of the global periodic drive a_m*sin(ω_m*t)
    ω_m::Float64 = 0.0              # angular frequency of that periodic drive
    d_0::Float64 = 0.0              # amplitude of the spatially heterogeneous drive
    λ_B::Float64 = 0.0              # its wavelength: d(x,t) = d_0*cos(x/λ_B *2π + ϕ_t + Θ)
    Θ::Float64 = 0.0                # its constant phase offset (later take 2π/λ_B)
    # --- random number generators and noise
    rngseed::UInt64 = rand(UInt64)                       # seed of the noise RNG (random by default)
    rng::MersenneTwister = MersenneTwister(rngseed)      # noise RNG
    rngseedIC::UInt64 = rand(UInt64)                     # seed of the initial-condition RNG (random by default)
    rngIC::MersenneTwister = MersenneTwister(rngseedIC)  # RNG used only for the initial conditions
    σ_ex::Float64 = 0.01            # noise amplitude, excitatory population
    σ_in::Float64 = 0.01            # noise amplitude, inhibitory population
    # ns_* are derived from σ_*, dt and τ_* only at construction: changing σ_* afterwards
    # (e.g. with @set) does NOT update them, so ns_* must be recomputed as well.
    ns_ex::Float64 = σ_ex * sqrt(dt/τex)   # noise prefactor of the Euler step
    ns_in::Float64 = σ_in * sqrt(dt/τin)
    # --- Fourier transforms of the coupling kernels (computed once)
    # note that in FFTW.jl the normalization is ifft(fft(x)) == x
    w_four_ex::Matrix{ComplexF64} = fft( connectivity(nsq, l_ex, w_ex) )
    w_four_in::Matrix{ComplexF64} = fft( connectivity(nsq, l_in, w_in) )
    # --- delays in units of time steps (at least one step each)
    hist_in::Int64 = max(round(delay_in / dt), 1)
    hist_etoi::Int64 = max(round(delay_etoi / dt), 1)
    hist_etoe::Int64 = max(round(delay_etoe / dt), 1)
    # the longest delay sets the number of stored history slots
    hist_size::Int64 = max(hist_in, hist_etoi, hist_etoe)
    # --- in-place FFT plans and their scratch buffer
    s_temp::Matrix{ComplexF64} = fft(zeros(nsq,nsq))  # complex scratch matrix for the in-place transforms
    # The plan fields are untyped (stored as Any). FFTW may overwrite s_temp while
    # planning with PATIENT; that is harmless because s_temp is only scratch space.
    fw_plan! = plan_fft!( s_temp; flags=FFTW.PATIENT, timelimit=Inf)   # forward FFT, in place on s_temp
    rv_plan! = plan_ifft!( s_temp ; flags=FFTW.PATIENT, timelimit=Inf) # inverse FFT, in place on s_temp
    # A real-to-complex transform would need only half the storage, but there is no in-place
    # rfft!; reusing one preallocated complex buffer is still much faster than allocating.
    # --- history registers of kernel-convolved spikes, size nsq × nsq × hist_size
    sw_i::Array{Float64, 3} = zeros(nsq, nsq, hist_size)  # inhibitory spikes convolved with the w_in kernel
    sw_e::Array{Float64, 3} = zeros(nsq, nsq, hist_size)  # excitatory spikes convolved with the w_ex kernel
    # history 'pointers': [current slot, past_in, past_etoi, past_etoe]
    hist_idx::Vector{Int64} = [1,1,1,1]
    # --- coupling input gathered from the history registers each step
    sw_to_ex::Matrix{Float64} = zeros(nsq,nsq)  # total coupling input to the excitatory population
    sw_to_in::Matrix{Float64} = zeros(nsq,nsq)  # total coupling input to the inhibitory population
    # --- state: index 1 = inhibitory, index 2 = excitatory
    psi::Vector{Matrix{Float64}} = [ rand(rngIC,nsq,nsq), rand(rngIC,nsq,nsq) ]  # membrane potentials, initially uniform in [0,1)
    s::Vector{Matrix{Float64}} = [ zeros(nsq,nsq), zeros(nsq,nsq) ]              # spikes emitted in the latest step (1.0 = spike)
end
## | |
## | |
# One iteration step of the two-population spiking model.
"""
    SPnet2Step!(net, ϕ_t)

Advance `net` (an `SPnet2`) by one time step of length `net.dt`, mutating its state in place.
Population index 1 is inhibitory and index 2 excitatory.

`ϕ_t` is the phase entering the spatially heterogeneous drive
`d_0*cos(j*2π/λ_B + Θ + ϕ_t)` applied to column `j`. It is only used when `d_0 != 0`.

Order of operations:
1. Gather the delayed, kernel-convolved spikes from the history registers.
2. Euler step of both membrane potentials: leak, constant drive `a_*`, global drive
   `a_m*sin(ω_m*t)`, noise, and coupling input. Then add the optional column-dependent drive.
3. Threshold: `psi >= 1` emits a spike and resets to 0; `psi` is clamped below at -2.
4. Convolve the new spikes with the coupling kernels via FFT and store the result in the
   current history slot.
5. Advance the step counter and the history pointer.

Returns `nothing`.
"""
function SPnet2Step!(net, ϕ_t)
    @unpack psi, s, τex, τin, nsq, t,
            dt, w_four_ex, w_four_in, a_ex, a_in,
            fw_plan!, rv_plan!,
            hist_in, hist_etoi, hist_etoe, hist_size, hist_idx,
            s_temp, sw_to_in, sw_to_ex, sw_i, sw_e,
            rng, ns_ex, ns_in,
            a_m, ω_m, d_0, λ_B, Θ = net

    # Slot h receives this step's spikes in step 4, so the spikes emitted d steps ago sit in
    # slot mod1(h - d, hist_size). Every delay satisfies 1 <= d <= hist_size. For
    # d == hist_size this reads slot h itself, which still holds the oldest entry until it is
    # overwritten below. (The former mod(h - d, hist_size) + 1 gave a lag of d - 1 steps, and
    # for d == 1 with hist_size > 1 it wrapped around to a lag of hist_size steps.)
    h = hist_idx[1]
    hist_idx[2] = mod1(h - hist_in, hist_size)    # past_in
    hist_idx[3] = mod1(h - hist_etoi, hist_size)  # past_etoi
    hist_idx[4] = mod1(h - hist_etoe, hist_size)  # past_etoe

    @views begin
        sw_to_in .= sw_e[:, :, hist_idx[3]]                             # exc -> inh
        sw_to_ex .= sw_i[:, :, hist_idx[2]] .+ sw_e[:, :, hist_idx[4]]  # inh -> exc plus exc -> exc
    end

    # global periodic drive, identical for both populations
    drive_t = a_m * sin(ω_m * t[1] * dt)

    # Euler step; the noise is drawn for population 1 first, then population 2
    psi[1] .+= (.-psi[1] .+ a_in .+ drive_t) .* (dt / τin) .+
               randn.(rng) .* ns_in .+ sw_to_in ./ τin
    psi[2] .+= (.-psi[2] .+ a_ex .+ drive_t) .* (dt / τex) .+
               randn.(rng) .* ns_ex .+ sw_to_ex ./ τex

    # Spatially heterogeneous drive along the columns j ('vertical' forcing). It is added after
    # both Euler updates, so each population receives exactly +d_j*dt/τ. The excitatory forcing
    # used to be applied before its update and was then damped by the leak term.
    if d_0 != 0.0
        @views for j in 1:nsq
            d_j = d_0 * cos(j * 2π / λ_B + Θ + ϕ_t)
            psi[1][:, j] .+= d_j * (dt / τin)
            psi[2][:, j] .+= d_j * (dt / τex)
        end
    end

    for p in eachindex(psi)
        v, sp = psi[p], s[p]
        # Threshold and spike detection in a single pass. Flagging spikes here avoids the old
        # `psi == 0.0` test, which also flagged units sitting at exactly 0 without a crossing.
        for k in eachindex(v, sp)
            x = v[k]
            if x >= 1.0
                v[k] = 0.0    # reset after a spike
                sp[k] = 1.0
            else
                v[k] = ifelse(x < -2.0, -2.0, x)  # lower bound
                sp[k] = 0.0
            end
        end

        # Convolve the spikes with the population's kernel in Fourier space:
        # inhibitory -> sw_i with w_four_in, excitatory -> sw_e with w_four_ex.
        kernel, register = p == 1 ? (w_four_in, sw_i) : (w_four_ex, sw_e)
        s_temp .= sp          # real spikes into the complex scratch buffer
        fw_plan! * s_temp     # forward FFT, in place
        s_temp .*= kernel     # elementwise product with the kernel's transform
        rv_plan! * s_temp     # inverse FFT, in place
        register[:, :, h] .= real.(s_temp)
    end

    # advance the time counter and the history 'pointer'
    t[1] += 1
    hist_idx[1] = mod1(h + 1, hist_size)
    return nothing
end
## | |
## | |
# Reinitialize the system.
"""
    reinit_IC_same_rng(spnet)

Reset `spnet` in place to its state right after construction, so that a new simulation can
repeat the original run:

- re-seed the noise RNG with `rngseed` and the initial-condition RNG with `rngseedIC`;
- redraw the membrane potentials from the freshly seeded `rngIC`, in the same order as at
  construction, which gives the same initial conditions;
- clear the spikes and the delayed-spike history, and reset the history pointers to 1;
- rewind the step counter `t`.

Returns `nothing`.
"""
function reinit_IC_same_rng(spnet)
    # Restart both random streams from their stored seeds. The IC stream must use rngseedIC;
    # it used to be re-seeded with rngseed, which produced different initial conditions.
    Random.seed!(spnet.rng, spnet.rngseed)
    Random.seed!(spnet.rngIC, spnet.rngseedIC)
    # redraw the initial membrane potentials in place (population 1 first, then 2)
    for P in spnet.psi
        rand!(spnet.rngIC, P)
    end
    # clear spikes and the history registers, and reset the history 'pointers'
    foreach(S -> fill!(S, 0.0), spnet.s)
    fill!(spnet.sw_i, 0.0)
    fill!(spnet.sw_e, 0.0)
    fill!(spnet.hist_idx, 1)
    # rewind time so that the time-dependent drive a_m*sin(ω_m*t) also restarts
    spnet.t[1] = 0
    return nothing
end
## | |
# Loading a stored system from disk (JLD2 file) needs special care with the FFT plans:
# a structure stored together with its FFT plans breaks Julia when it is loaded back, so
# load_system loads the stored structure and re-creates the plans.
"""
    load_system(fn, sn)

Load the object stored under the name `sn` in the JLD2 file `fn` (presumably an `SPnet2`) and
return a copy with freshly created FFT plans `fw_plan!` and `rv_plan!`.

The plans read back from the file are unusable; in practice they crash Julia. This is
presumably because an FFTW plan wraps a pointer to C-side FFTW state that does not survive
serialization. The plans are therefore rebuilt on the loaded `s_temp` buffer with the
`FFTW.PATIENT` flag and no time limit.
"""
function load_system(fn, sn)
    spnet = load(fn, sn)  # JLD2 (FileIO-style) load of the object stored under the name `sn`
    # @set returns an updated copy of the immutable struct; all other fields are shared
    spnet = @set spnet.fw_plan! = plan_fft!(spnet.s_temp; flags=FFTW.PATIENT, timelimit=Inf)
    spnet = @set spnet.rv_plan! = plan_ifft!(spnet.s_temp; flags=FFTW.PATIENT, timelimit=Inf)
    return spnet
end
# For storing and loading I use JLD2.jl, which is a great package.
# I also use it to store spiking data of the system to disk in JLD2 format;
# for instance, it makes it easy to store sparse matrices.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
syntax highlighted version: https://gist.github.com/tfiers/19d15bdb00d2e7beaa0fb94be149fe5c