Skip to content

Instantly share code, notes, and snippets.

@ingted
Created November 5, 2025 15:22
Show Gist options
  • Save ingted/45e807bdef24fc2d1f6ec4be484db72c to your computer and use it in GitHub Desktop.
Save ingted/45e807bdef24fc2d1f6ec4be484db72c to your computer and use it in GitHub Desktop.
fsnn
#r "nuget: TorchSharp"
#r "nuget: TorchSharp-cuda-windows, 0.105.1"
open System
open TorchSharp
//open type TorchSharp.torch
// Prefer the CUDA device when TorchSharp reports that CUDA is available;
// otherwise fall back to the CPU.
let device =
    if torch.cuda_is_available() then
        torch.CUDA
    else
        torch.CPU

// Report the selected device once at startup.
device |> printfn "Running on %A"
// ======= Parameter settings =======
// Simulation-wide constants read by makeGroup, signalGenerator and the tick loop.
let g = 3 // number of cell groups to create and simulate
let n = 200 // centre of the per-group cell count range
let m = 30 // half-width of the cell count range: count is drawn from [n - m, n + m)
let srange = 1000 // synapse region size: exclusive upper bound of random connection indices, and length of the signal vector
let j = 200 // centre of the dendrite/axon count range: each count is drawn from [j - 50, j + 50)
let tickCount = 50 // number of simulation ticks (t1 .. tn)
let o = 10 // intended number of axons designated as outputs per group (applied in makeGroup)
let rnd = Random() // shared random source for group sizes; created without a seed
// ======= Randomly generated structure of each group =======
/// Tensors describing one randomly wired group of cells (built by makeGroup).
type CellGroup =
    {
        /// Dendrite connections, laid out as [cell, dendrite]; each entry is an
        /// index into the synapse region.
        dendriteConn: torch.Tensor
        /// Axon connections, laid out as [cell, axon].
        axonConn: torch.Tensor
        /// Synapse connection map, laid out as [axon, dendrite].
        synapseMap: torch.Tensor
        /// Indices of the axons designated as outputs.
        outputs: torch.Tensor
    }
/// Builds one randomly wired cell group on the selected device.
///
/// The cell count is drawn from [n - m, n + m) and the dendrite and axon counts
/// from [j - 50, j + 50) (Random.Next excludes its upper bound).
///
/// idx is the group's position in the caller's sequence; it is currently unused
/// and kept only so the signature stays compatible with existing callers.
///
/// Returns a CellGroup whose `outputs` holds min(o, axonCount) distinct axon
/// indices chosen uniformly at random.
let makeGroup (idx: int) =
    let cellCount = rnd.Next(n - m, n + m)
    let dendriteCount = rnd.Next(j - 50, j + 50)
    let axonCount = rnd.Next(j - 50, j + 50)
    // Random dendrite / axon wiring: every entry is an index in [0, srange).
    let dendriteConn = torch.randint(srange, [| cellCount; dendriteCount |], device=device)
    let axonConn = torch.randint(srange, [| cellCount; axonCount |], device=device)
    // Binary (0/1) axon-to-dendrite synapse map.
    let synapseMap = torch.randint(2, [| axonCount; dendriteCount |], device=device)
    // Designate a random subset of axons as outputs: shuffle all axon indices and
    // keep the FIRST `outputCount` of them. The previous slice ended at
    // Index.FromEnd(o), i.e. 0..^o, which kept everything EXCEPT the last `o`
    // entries and therefore selected axonCount - o outputs instead of o.
    let rp = torch.randperm(axonCount, device=device)
    let outputCount = Math.Min(o, axonCount)
    let rangeIdx : torch.TensorIndex = System.Range(Index.FromStart 0, Index.FromStart outputCount)
    let outputs = rp[rangeIdx]
    { dendriteConn = dendriteConn
      axonConn = axonConn
      synapseMap = synapseMap
      outputs = outputs }
// ======= Initialise all groups =======
// One CellGroup per group index 0 .. g-1, created in index order.
let groups =
    [| 0 .. g - 1 |]
    |> Array.map makeGroup
// ======= Tick signal generator =======
/// Produces the stimulus vector for one tick, simulating external input:
/// a 1-D tensor of `srange` samples from torch.randn, allocated on `device`.
let signalGenerator () =
    let signalShape = [| int64 srange |]
    torch.randn(signalShape, device=device)
// NOTE(review): these top-level values are shadowed by the `for t` / `for gi`
// loop variables below and are never read by that loop. Presumably they were
// left in so individual loop-body lines can be evaluated in F# Interactive —
// confirm before removing.
let t = 1
let gi = 0
// ======= Tick update simulation =======
// The float32 view of each group's synapse map never changes between ticks, so
// convert it once here rather than on every (tick, group) iteration.
let synapseWeights =
    groups |> Array.map (fun grp -> grp.synapseMap.to_type(torch.float32))

// For every tick: draw one external stimulus, push it through each group's
// dendrites and synapse map, and print the mean and variance of the selected
// output axons.
for t in 1 .. tickCount do
    // Every tensor allocated during this tick is released when the iteration
    // ends; without the scope the native buffers linger until finalizers run.
    use _tickScope = torch.NewDisposeScope()
    let signal = signalGenerator()
    // Synaptic response of each group.
    for gi in 0 .. g-1 do
        let gcell = groups[gi]
        let connShape = gcell.dendriteConn.shape
        let cellCount, dendriteCount = connShape.[0], connShape.[1]
        // Look up the stimulus value at every dendrite connection index and
        // restore the [cell, dendrite] layout.
        let gathered =
            torch.index_select(signal, 0, gcell.dendriteConn.flatten()).view([| cellCount; dendriteCount |])
        // Sum over the cell dimension (dim 0) to get one merged input per
        // dendrite, then reshape to a [dendriteCount, 1] column for the product.
        let dendInput = gathered.sum(0, keepdim=false).view([| dendriteCount; 1L |])
        // [axonCount, dendriteCount] x [dendriteCount, 1] -> [axonCount, 1];
        // drop only the trailing singleton dimension so the result stays 1-D.
        let axonOut =
            torch.mm(synapseWeights[gi], dendInput.to_type(torch.float32)).squeeze(1L)
        // Keep only the axons designated as outputs.
        let outSel = torch.index_select(axonOut, 0, gcell.outputs)
        printfn "Tick %d Group %d Out(mean)=%.4f var=%.4f"
            t gi
            (outSel.mean().item<float32>())
            (outSel.var().item<float32>())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment