-
-
Save ingted/45e807bdef24fc2d1f6ec4be484db72c to your computer and use it in GitHub Desktop.
fsnn
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Pin both packages to the same release: TorchSharp-cuda-windows ships the native
// libtorch binaries built for one specific TorchSharp version, so an unpinned
// `TorchSharp` reference may resolve to a managed assembly that does not match them.
#r "nuget: TorchSharp, 0.105.1"
#r "nuget: TorchSharp-cuda-windows, 0.105.1"
open System
open TorchSharp
//open type TorchSharp.torch

// Use the GPU when CUDA is available, otherwise fall back to the CPU.
let device = if torch.cuda_is_available() then torch.CUDA else torch.CPU
printfn "Running on %A" device

// ======= Parameters =======
let g = 3            // number of groups
let n = 200          // mean number of cells per group
let m = 30           // ± variation of the cell count
let srange = 1000    // size of the synapse region (index range)
let j = 200          // baseline number of dendrites / axons
let tickCount = 50   // number of ticks (t1..tn)
let o = 10           // number of axons designated as outputs
let rnd = Random()
// ======= Randomly generated per-group structure =======
/// Random wiring of one cell group; every field is a tensor (see makeGroup).
type CellGroup =
    {
        /// Shape [cell, dendrite]: dendrite connection indices.
        dendriteConn: torch.Tensor
        /// Shape [cell, axon]: axon connection indices.
        axonConn: torch.Tensor
        /// Shape [axon, dendrite]: synapse connectivity map.
        synapseMap: torch.Tensor
        /// Indices of the axons designated as outputs.
        outputs: torch.Tensor
    }
/// Builds one randomly wired cell group on `device`.
/// idx: index of the group being built (not read by the body; kept for callers).
/// Reads the script-level parameters n, m, j, srange, o and the shared `rnd`.
let makeGroup (idx: int) =
    // Random.Next(lo, hi) excludes `hi`; +1 makes each "±" range inclusive on both ends.
    let cellCount = rnd.Next(n - m, n + m + 1)
    let dendriteCount = rnd.Next(j - 50, j + 50 + 1)
    let axonCount = rnd.Next(j - 50, j + 50 + 1)
    // Random wiring: dendrite/axon connections hold indices in [0, srange);
    // synapseMap is a random 0/1 matrix of shape [axonCount, dendriteCount].
    let dendriteConn = torch.randint(srange, [| cellCount; dendriteCount |], device = device)
    let axonConn = torch.randint(srange, [| cellCount; axonCount |], device = device)
    let synapseMap = torch.randint(2, [| axonCount; dendriteCount |], device = device)
    // Designate `o` distinct random axons as outputs (all of them if the group has fewer).
    // Take the FIRST `outputCount` entries of a random permutation; a Range(0, ^o)
    // slice would instead keep every axon except the last `o`.
    let outputCount = min o axonCount
    let perm = torch.randperm(axonCount, device = device)
    let outputs = perm.narrow(0, 0L, int64 outputCount)
    { dendriteConn = dendriteConn
      axonConn = axonConn
      synapseMap = synapseMap
      outputs = outputs }
// ======= Initialise all groups =======
/// One CellGroup per index 0 .. g-1, built in ascending index order.
let groups =
    [| 0 .. g - 1 |]
    |> Array.map makeGroup
// ======= Tick signal generator =======
/// Returns a fresh stimulus vector of length `srange` drawn with torch.randn on `device`;
/// called once per tick to emulate external input.
let signalGenerator () : torch.Tensor =
    let shape = [| int64 srange |]
    torch.randn(shape, device = device)
// ======= Tick simulation =======
// synapseMap never changes, so convert each group's map to float32 once
// instead of once per tick per group.
let synapseMapsF32 =
    groups |> Array.map (fun grp -> grp.synapseMap.to_type(torch.float32))

for t in 1 .. tickCount do
    // Every tensor created during this tick is disposed when the scope ends,
    // rather than waiting for the .NET GC to finalize the native handles.
    use tickScope = torch.NewDisposeScope()
    let signal = signalGenerator ()
    // Synaptic response of every group
    for gi in 0 .. g - 1 do
        let gcell = groups[gi]
        // Gather the signal value at each dendrite's connection index
        // -> [cellCount, dendriteCount]
        let gathered =
            torch.index_select(signal, 0, gcell.dendriteConn.flatten())
                .view(gcell.dendriteConn.shape)
        // Sum over the cell dimension (dim 0): combined input per dendrite -> [dendriteCount]
        let dendInput = gathered.sum(0, keepdim = false)
        // Matrix-vector product: [axonCount, dendriteCount] x [dendriteCount] -> [axonCount]
        let axonOut = torch.matmul(synapseMapsF32[gi], dendInput)
        // Keep only the designated output axons
        let outSel = torch.index_select(axonOut, 0, gcell.outputs)
        printfn "Tick %d Group %d Out(mean)=%.4f var=%.4f"
            t gi
            (outSel.mean().item<float32>())
            (outSel.var().item<float32>())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment