Created
January 9, 2021 08:58
-
-
Save ochaton/b50bcb5af8911f666c4510dc6bd3ce95 to your computer and use it in GitHub Desktop.
Dijkstra's shortest-path algorithm implemented in Teal, backed by a binary-heap priority queue (heap.tl)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
-- Edge weight / accumulated path length.
local type Distance = number
-- Internal numeric handle the graph assigns to every vertex.
local type VertexId = number
-- Adjacency data of a single vertex. Both maps are keyed by the neighbour's
-- VertexId and store the weight of the connecting edge.
local record Node
    succs: {VertexId:Distance} -- outgoing edges: target id -> weight
    preds: {VertexId:Distance} -- incoming edges: source id -> weight
end

-- Metatable shared by every Node instance; missing fields resolve on Node.
local node_mt: metatable<Node> = { __index = Node }
-- Creates a vertex record with no incoming and no outgoing edges.
function Node:new(): Node
    local fresh = { preds = {}, succs = {} } as Node
    return setmetatable(fresh, node_mt)
end
-- Directed weighted graph whose vertices are identified by caller-supplied
-- keys of type T. Internally every key is mapped to a numeric VertexId.
local record Graph<T>
    nodes: {VertexId:Node} -- adjacency data per vertex id
    kvn: {T:VertexId}      -- caller key -> vertex id
    nvk: {VertexId:T}      -- vertex id -> caller key
    _id: VertexId          -- last id handed out by _gen_id
    _gen_id: function(self: Graph<T>): VertexId
end

-- Class-level default: a new instance reads 0 through __index until its first
-- _gen_id call stores a counter of its own.
Graph._id = 0

-- Metatable shared by every Graph instance; missing fields resolve on Graph.
local graph_mt: metatable<Graph> = { __index = Graph }
-- Creates an empty graph (no vertices, no edges).
function Graph:new<T>(): Graph<T>
    local instance = { nodes = {}, kvn = {}, nvk = {} } as Graph<T>
    return setmetatable(instance, graph_mt)
end
-- Advances this graph's id counter by one and returns the new value.
function Graph:_gen_id(): VertexId
    local next_id = self._id + 1
    self._id = next_id
    return next_id
end
-- Registers a new vertex identified by `node`.
--
-- Raises (at the caller's level) when `node` is nil or when a vertex with the
-- same key is already present. Nothing is modified in either case.
function Graph:addnode<T>(node: T)
    -- A nil key would slip past the duplicate check, consume an id and then
    -- fail with a cryptic "table index is nil"; reject it up front instead.
    if node == nil then
        error("Node must not be nil", 2)
    end
    if self.kvn[node] then
        error("Node already exists", 2)
    end

    local id = self:_gen_id()
    -- Keep both directions of the key <-> id mapping in sync.
    self.kvn[node] = id
    self.nvk[id] = node
    self.nodes[id] = Node:new()
end
-- Adds (or overwrites) the directed edge s -> t with weight `w`.
-- A missing weight counts as 1. Both endpoints must already be vertices of
-- the graph, otherwise an error is raised. Returns the graph for chaining.
function Graph:addedge<T>(s: T, t: T, w: Distance|nil): Graph<T>
    local ids = self.kvn
    local from = assert(ids[s], "Node "..tostring(s).." was not found")
    local to = assert(ids[t], "Node "..tostring(t).." was not found")

    local weight = w
    if not weight then
        weight = 1
    end

    -- Record the edge on both endpoints: as a successor of the source and as
    -- a predecessor of the target.
    local vertices = self.nodes
    vertices[from].succs[to] = weight
    vertices[to].preds[from] = weight
    return self
end
-- Adds a bidirectional edge: s -> t followed by t -> s, both with weight `w`.
-- Returns the graph for chaining.
function Graph:addbedge<T>(s: T, t: T, w: Distance|nil): Graph<T>
    local forward = self:addedge(s, t, w)
    return forward:addedge(t, s, w)
end
-- Result of Graph:dijkstra.
-- TODO: does this record really have to be global?
global record Dijkstra<T>
    -- Shortest distance from the source to every vertex, keyed by vertex key.
    dist: {T:Distance}
    -- Steps of the path towards the target, in travel order; each step is
    -- {vertex key, weight of the edge used to reach it}.
    path: {number:{T,Distance}}
end
| local Heap = require("heap") | |
-- Runs Dijkstra's algorithm from vertex `s`.
--
-- Returns a Dijkstra record where
--   * `dist` maps every vertex key to its shortest distance from `s`
--     (math.huge for vertices that cannot be reached), and
--   * `path` lists the steps from `s` to `t` in travel order, each step being
--     {vertex key, weight of the edge used to reach it}. The source itself is
--     not listed; the path is empty when `t` is unknown, unreachable or equal
--     to `s`.
--
-- Raises when `s` is not a vertex of the graph.
-- Edge weights are assumed to be non-negative, as Dijkstra's algorithm requires.
function Graph:dijkstra<T>(s: T, t: T): Dijkstra<T>
    -- Fail with a readable message instead of "table index is nil" below.
    local sid = assert(self.kvn[s], "Node "..tostring(s).." was not found")

    local p: {VertexId:VertexId} = {} -- predecessor on the best known path
    local d: {VertexId:Distance} = {} -- best known distance from the source
    for v in pairs(self.nodes) do
        d[v] = math.huge
    end

    -- Heap entries are immutable {distance, vertex} snapshots. Ordering the
    -- heap by the live `d` table would change the key of entries that are
    -- already stored whenever a distance improves, silently breaking the heap
    -- invariant. Outdated snapshots are simply skipped when popped.
    local bh = Heap:new(function(a: {number}, b: {number}): boolean
        return a[1] < b[1]
    end)

    local seen: {VertexId:boolean} = {}
    d[sid] = 0
    bh:push({ 0, sid })

    while not bh:empty() do
        local entry = bh:pop()
        local v = entry[2]
        -- A vertex may be queued several times; only its first (smallest)
        -- snapshot is processed.
        if not seen[v] then
            seen[v] = true
            for succ, w in pairs(self.nodes[v].succs) do
                local candidate = d[v] + w
                if not seen[succ] and candidate < d[succ] then
                    d[succ] = candidate
                    p[succ] = v
                    bh:push({ candidate, succ })
                end
            end
        end
    end

    local ret : Dijkstra<T> = {
        path = {},
        dist = {},
    }

    -- Translate internal ids back to the caller's vertex keys.
    for vid, len in pairs(d) do
        ret.dist[self.nvk[vid]] = len
    end

    -- Walk the predecessor chain backwards from the target, prepending each
    -- step so that the final path reads source -> target.
    local tid = self.kvn[t]
    while tid do
        local pid = p[tid]
        if not pid then break end
        table.insert(ret.path, 1, {
            self.nvk[tid],
            self.nodes[tid].preds[pid],
        })
        tid = pid
    end

    return ret
end
-- Self-test: runs only when the file is executed directly (no arguments are
-- passed to the chunk), not when it is loaded through require.
if not ... then
    local g = Graph:new() as Graph<string>
    for i = 1, 5 do
        g:addnode(tostring(i))
    end

    -- Every connection is symmetric, so add both directions at once.
    g:addbedge("1", "2", 100)
    g:addbedge("2", "4", 1)
    g:addbedge("4", "5", 1)
    g:addbedge("5", "3", 1)
    g:addbedge("3", "1", 1)

    local d = g:dijkstra("1", "2")

    -- Iterate by index: the traversal order of pairs() is unspecified, which
    -- could print the path steps out of sequence.
    for i = 1, #d.path do
        local step = d.path[i]
        print(step[1], step[2])
    end

    for v, len in pairs(d.dist) do
        print(("dist(1, %q) = %s"):format(v, len))
    end
end
| return Graph |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| -- file: heap.tl (teal language) | |
-- Index of the parent of slot `x` in a 1-based binary heap array.
local function _parent(x: number): number
    local half = x / 2
    return math.floor(half)
end
-- Index of the left child of slot `x` in a 1-based binary heap array.
local function _left(x: number): number
    local doubled = x * 2
    return doubled
end
-- Index of the right child of slot `x` in a 1-based binary heap array.
local function _right(x: number): number
    local left_child = 2 * x
    return left_child + 1
end
-- Binary heap stored in a 1-based array and ordered by a caller-supplied
-- comparison function.
local record Heap<TComparable>
    -- State:
    n: number             -- number of elements currently stored
    nodes: {TComparable}  -- heap array; slots 1..n are in use
    -- Returns true when its first argument must be popped before the second.
    comp: function<TComparable>(TComparable, TComparable): boolean

    -- Public interface:
    new: function<TComparable>(self: Heap, comp: function<TComparable>(TComparable, TComparable): boolean): Heap<TComparable>
    push: function<TComparable>(self: Heap<TComparable>, node: TComparable)
    pop: function<TComparable>(self: Heap<TComparable>): TComparable|nil

    -- Helper methods (operate on slot indices, not on elements):
    siftUp: function<TComparable>(self: Heap<TComparable>, p: number)
    siftDown: function<TComparable>(self: Heap<TComparable>, p: number)
    less: function<TComparable>(self: Heap<TComparable>, a: number, b: number): boolean
    swap: function<TComparable>(self: Heap<TComparable>, a: number, b: number)
end
-- Creates an empty heap ordered by `comp`; `comp(a, b)` must return true when
-- `a` has to leave the heap before `b`.
function Heap:new<TComparable>(comp: function<TComparable>(TComparable, TComparable):boolean): Heap<TComparable>
    local instance = { nodes = {}, n = 0, comp = comp } as Heap<TComparable>
    -- Methods are looked up on whatever table `new` was invoked on.
    return setmetatable(instance, { __index = self })
end
-- Inserts `node` into the heap.
function Heap:push<TComparable>(node: TComparable)
    -- Append at the first free slot, then let the element rise to its place.
    local slot = self.n + 1
    self.n = slot
    self.nodes[slot] = node
    self:siftUp(slot)
end
-- Removes and returns the element at the top of the heap (the one `comp`
-- orders before all others). Returns nil when the heap is empty.
function Heap:pop<TComparable>(): TComparable
    local last = self.n
    if last == 0 then
        return nil
    end

    local top = self.nodes[1]
    -- Move the last element to the root and release the vacated slot so the
    -- array does not keep a stale reference to a value that has left the heap
    -- (which would prevent it from being garbage-collected).
    self.nodes[1] = self.nodes[last]
    self.nodes[last] = nil
    self.n = last - 1

    -- Restore the heap order starting from the new root.
    self:siftDown(1)
    return top
end
-- Number of elements currently stored in the heap.
function Heap:len(): number
    local count = self.n
    return count
end
-- True when the heap holds no elements.
function Heap:empty(): boolean
    local count = self.n
    return count == 0
end
-- Moves the element in slot `p` towards the root for as long as it orders
-- before its parent.
function Heap:siftUp<TComparable>(p: number)
    local child = p
    while child > 1 do
        local above = _parent(child)
        if not self:less(child, above) then
            -- Parent already orders first: the element is in place.
            break
        end
        self:swap(child, above)
        child = above
    end
end
-- Moves the element in slot `p` towards the leaves until it orders before the
-- preferred one of its children, or until it has no children left.
function Heap:siftDown<TComparable>(p: number)
    local cur = p
    while true do
        local l = _left(cur)
        if l > self.n then
            -- No children: nothing left to compare against.
            return
        end

        -- Pick the child that orders first; the right one may not exist.
        local r = _right(cur)
        local best = l
        if r <= self.n and self:less(r, l) then
            best = r
        end

        if self:less(cur, best) then
            return
        end
        self:swap(cur, best)
        cur = best
    end
end
-- Applies the heap's comparison function to the elements stored in slots `a`
-- and `b` and returns its verdict.
function Heap:less<TComparable>(a: number, b:number): boolean
    local slots = self.nodes
    local first, second = slots[a], slots[b]
    return self.comp(first, second)
end
-- Exchanges the elements stored in slots `a` and `b`.
function Heap:swap<TComparable>(a: number, b: number)
    local slots = self.nodes
    local held = slots[a]
    slots[a] = slots[b]
    slots[b] = held
end
| return Heap |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment