Skip to content

Instantly share code, notes, and snippets.

@ochaton
Created January 9, 2021 08:58
Show Gist options
  • Select an option

  • Save ochaton/b50bcb5af8911f666c4510dc6bd3ce95 to your computer and use it in GitHub Desktop.

Select an option

Save ochaton/b50bcb5af8911f666c4510dc6bd3ce95 to your computer and use it in GitHub Desktop.
Dijkstra implementation
-- Edge weight / accumulated path length.
local type Distance = number
-- Integer handle a Graph assigns to each vertex value (see Graph:_gen_id).
local type VertexId = number

--- Adjacency record for one vertex.
-- `preds` maps the id of each vertex with an edge INTO this one to that
-- edge's weight; `succs` maps the id of each vertex this one has an edge
-- TO to that edge's weight.
local record Node
   preds: {VertexId:Distance}
   succs: {VertexId:Distance}
end

local node_mt: metatable<Node> = { __index = Node }

--- Create an adjacency record with no incoming or outgoing edges.
function Node:new(): Node
   local fresh: Node = { preds = {}, succs = {} }
   return setmetatable(fresh, node_mt)
end
--- Directed, weighted graph whose vertices are arbitrary values of type T.
-- Every vertex value is given an integer id:
--   kvn   : vertex value -> id
--   nvk   : id -> vertex value
--   nodes : id -> Node (incoming/outgoing edge weights)
local record Graph<T>
   _id: VertexId
   _gen_id: function(self: Graph<T>): VertexId
   kvn: {T:VertexId}
   nvk: {VertexId:T}
   nodes: {VertexId:Node}
end

-- Class-level default for the id counter. `new` does not set `_id`, so the
-- first `_gen_id` on an instance reads this 0 through `__index` and then
-- stores the counter on the instance itself; each graph numbers from 1.
Graph._id = 0

local graph_mt: metatable<Graph> = { __index = Graph }

--- Create an empty graph (no vertices, no edges).
function Graph:new<T>(): Graph<T>
   local g = { nodes = {}, kvn = {}, nvk = {} } as Graph<T>
   return setmetatable(g, graph_mt)
end

--- Allocate and return the next vertex id for this graph.
function Graph:_gen_id(): VertexId
   local next_id = self._id + 1
   self._id = next_id
   return next_id
end
--- Register `node` as a new vertex with no edges.
-- Raises "Node already exists" (blamed on the caller) when `node` is
-- already a vertex of this graph.
function Graph:addnode<T>(node: T)
   local already = self.kvn[node] ~= nil
   if already then
      error("Node already exists", 2)
   end

   local vid = self:_gen_id()
   self.kvn[node] = vid
   self.nvk[vid] = node
   self.nodes[vid] = Node:new()
end
--- Add (or overwrite) the directed edge s -> t.
-- @param s  source vertex; must already be registered with addnode
-- @param t  target vertex; must already be registered with addnode
-- @param w  edge weight, defaults to 1 when nil
-- @return   the graph itself, so calls can be chained
-- Fails via `assert` with "Node <x> was not found" for an unknown endpoint.
function Graph:addedge<T>(s: T, t: T, w: Distance|nil): Graph<T>
   local src_id = assert(self.kvn[s], "Node "..tostring(s).." was not found")
   local dst_id = assert(self.kvn[t], "Node "..tostring(t).." was not found")
   local weight = w or 1

   -- Record the edge on both ends so it can be walked forwards and backwards.
   self.nodes[dst_id].preds[src_id] = weight
   self.nodes[src_id].succs[dst_id] = weight

   return self
end
--- Add an undirected ("bidirectional") edge: s -> t and t -> s, both with
-- weight `w` (default 1). Returns the graph for chaining.
function Graph:addbedge<T>(s: T, t: T, w: Distance|nil): Graph<T>
   self:addedge(s, t, w)
   self:addedge(t, s, w)
   return self
end
--- Result of Graph:dijkstra.
-- TODO: this must be global?
-- NOTE(review): declared `global`, so the type name leaks outside this
-- module; confirm whether any other module relies on it before making it
-- `local`.
global record Dijkstra<T>
-- Sequence of {vertex, weight of the edge entering that vertex} pairs
-- leading from the source (exclusive) to the target (inclusive); empty
-- when the target is the source or is not reachable.
path: {number:{T,Distance}}
-- Shortest distance from the source to every vertex; math.huge when a
-- vertex is unreachable.
dist: {T:Distance}
end
local Heap = require("heap")
--- Single-source shortest paths (Dijkstra) from `s`, plus the s -> t path.
--
-- Edge weights must be non-negative (Dijkstra's algorithm is incorrect
-- otherwise); this is not validated here.
--
-- @param s  source vertex; must be registered (fails via `assert` with
--           "Node <s> was not found" otherwise)
-- @param t  target vertex; if unknown or unreachable, `path` is empty
-- @return   Dijkstra<T>:
--             dist[v] shortest distance s -> v (math.huge if unreachable)
--             path    sequence of {vertex, weight of edge entering vertex}
--                     from the vertex after `s` up to and including `t`
function Graph:dijkstra<T>(s: T, t: T): Dijkstra<T>
   local sid = assert(self.kvn[s], "Node "..tostring(s).." was not found")

   local p: {VertexId:VertexId} = {}
   local d: {VertexId:Distance} = {}
   for v in pairs(self.nodes) do
      d[v] = math.huge
   end

   -- Queue entries carry the distance they were pushed with. A heap whose
   -- comparator reads the mutable `d` table is corrupted as soon as a queued
   -- vertex's distance decreases (its entry is never re-sifted), which can
   -- pop a vertex before its distance is final. Instead, re-push a vertex
   -- whenever its distance improves and skip stale entries on pop
   -- ("lazy deletion").
   local bh = Heap:new(function(a: {VertexId, Distance}, b: {VertexId, Distance}): boolean
      return a[2] < b[2]
   end)
   local seen: {VertexId:boolean} = {}

   d[sid] = 0
   bh:push({sid, 0})
   while not bh:empty() do
      local entry = bh:pop()
      local v = entry[1]
      -- The first pop of a vertex carries its final distance; later pops of
      -- the same vertex are stale duplicates.
      if not seen[v] then
         seen[v] = true
         for succ, w in pairs(self.nodes[v].succs) do
            local alt = d[v] + w
            if not seen[succ] and alt < d[succ] then
               d[succ] = alt
               p[succ] = v
               bh:push({succ, alt})
            end
         end
      end
   end

   local ret: Dijkstra<T> = {
      path = {},
      dist = {},
   }
   for vid, len in pairs(d) do
      ret.dist[self.nvk[vid]] = len
   end

   -- Walk predecessor links back from t, prepending so the path reads s -> t.
   local tid = self.kvn[t]
   while tid do
      local pid = p[tid]
      if not pid then break end
      table.insert(ret.path, 1, {
         self.nvk[tid],
         self.nodes[tid].preds[pid],
      })
      tid = pid
   end
   return ret
end
-- Demo: runs only when this file is executed directly. When loaded via
-- `require`, `...` holds the module name and the block is skipped.
if not ... then
   local g = Graph:new() as Graph<string>
   for i = 1, 5 do
      g:addnode(tostring(i))
   end

   -- Every edge in the demo graph is symmetric; addbedge adds both directions.
   g:addbedge("1", "2", 100)
   g:addbedge("2", "4", 1)
   g:addbedge("4", "5", 1)
   g:addbedge("5", "3", 1)
   g:addbedge("3", "1", 1)

   local d = g:dijkstra("1", "2")

   -- `path` is a sequence: ipairs guarantees source -> target order,
   -- pairs does not.
   for _, step in ipairs(d.path) do
      print(step[1], step[2])
   end

   -- Sort vertex names so the distance report is deterministic.
   local names: {string} = {}
   for v in pairs(d.dist) do
      names[#names + 1] = v
   end
   table.sort(names)
   for _, v in ipairs(names) do
      print(("dist(1, %q) = %s"):format(v, d.dist[v]))
   end
end
return Graph
-- file: heap.tl (teal language)

-- Index arithmetic for a 1-based implicit binary tree stored in an array:
-- node i has children 2i and 2i+1 and parent floor(i/2).

--- Index of the parent of node `x` (0 for the root, x == 1).
local function _parent(x: number): number
   local half = x / 2
   return math.floor(half)
end

--- Index of the left child of node `x`.
local function _left(x: number): number
   return x * 2
end

--- Index of the right child of node `x`.
local function _right(x: number): number
   return x * 2 + 1
end
--- Binary heap ordered by a user-supplied comparator.
-- comp(a, b) == true means `a` should be popped before `b`; the element
-- ranked first is kept at nodes[1] and returned by `pop`.
-- Storage: nodes[1..n] is an implicit binary tree (children of i are 2i
-- and 2i+1, parent is floor(i/2)); `n` is the current element count.
local record Heap<TComparable>
new: function<TComparable>(self: Heap, comp: function<TComparable>(TComparable, TComparable): boolean): Heap<TComparable>
push: function<TComparable>(self: Heap<TComparable>, node: TComparable)
-- Removes and returns the first element; nil when the heap is empty.
pop: function<TComparable>(self: Heap<TComparable>): TComparable|nil
-- Helper methods:
-- Restore heap order by moving the element at index p towards the root.
siftUp: function<TComparable>(self: Heap<TComparable>, p: number)
-- Restore heap order by moving the element at index p towards the leaves.
siftDown: function<TComparable>(self: Heap<TComparable>, p: number)
-- comp applied to the elements stored at indices a and b.
less: function<TComparable>(self: Heap<TComparable>, a: number, b: number): boolean
-- Exchange the elements stored at indices a and b.
swap: function<TComparable>(self: Heap<TComparable>, a: number, b: number)
-- Number of elements currently stored.
n: number
-- Element storage; only indices 1..n are meaningful.
nodes: {TComparable}
-- Ordering predicate supplied to `new`.
comp: function<TComparable>(TComparable, TComparable): boolean
end
--- Create an empty heap ordered by `comp` (comp(a, b) true => a pops first).
-- Each heap gets its own metatable whose __index is the receiver, so
-- methods resolve through the table `new` was called on.
function Heap:new<TComparable>(comp: function<TComparable>(TComparable, TComparable):boolean): Heap<TComparable>
   local instance = { nodes = {}, n = 0, comp = comp } as Heap<TComparable>
   local mt: metatable<Heap<TComparable>> = { __index = self }
   return setmetatable(instance, mt)
end
--- Insert `node`: append it at the next free slot, then sift it up into place.
function Heap:push<TComparable>(node: TComparable)
   local slot = self.n + 1
   self.nodes[slot] = node
   self.n = slot
   self:siftUp(slot)
end
--- Remove and return the first element (per `comp`); nil when empty.
function Heap:pop<TComparable>(): TComparable
   if self.n == 0 then
      return nil
   end
   local nodes = self.nodes
   local top = nodes[1]
   -- Move the last element to the root and clear its old slot. Without the
   -- clear, nodes[n] keeps a stale reference (keeping popped values alive)
   -- until the heap grows back to that size.
   nodes[1] = nodes[self.n]
   nodes[self.n] = nil
   self.n = self.n - 1
   -- With n == 0 the loop in siftDown does not run.
   self:siftDown(1)
   return top
end
--- Number of elements currently in the heap.
function Heap:len(): number
   local count = self.n
   return count
end

--- True when the heap holds no elements.
function Heap:empty(): boolean
   return self:len() == 0
end
--- Move the element at index `p` towards the root while it ranks before
-- its parent.
function Heap:siftUp<TComparable>(p: number)
   local child = p
   while child > 1 do
      local up = _parent(child)
      if not self:less(child, up) then
         break
      end
      self:swap(child, up)
      child = up
   end
end
--- Move the element at index `p` towards the leaves, swapping it with its
-- first-ranked child until it ranks before that child or has no children.
function Heap:siftDown<TComparable>(p: number)
   local i = p
   while true do
      local l = _left(i)
      if l > self.n then
         break
      end
      local r = _right(i)
      -- Pick the child that should come first; ties favour the left child.
      local best = l
      if r <= self.n and self:less(r, l) then
         best = r
      end
      if self:less(i, best) then
         break
      end
      self:swap(i, best)
      i = best
   end
end
--- Whether the element at index `a` ranks before the element at index `b`.
function Heap:less<TComparable>(a: number, b:number): boolean
   local cmp = self.comp
   local items = self.nodes
   return cmp(items[a], items[b])
end
--- Exchange the elements stored at indices `a` and `b`.
function Heap:swap<TComparable>(a: number, b: number)
   local items = self.nodes
   local tmp = items[a]
   items[a] = items[b]
   items[b] = tmp
end
return Heap
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment