Last active
July 15, 2019 12:19
-
-
Save pocari/c6d295a3066b870c68e637d8225e3514 to your computer and use it in GitHub Desktop.
トロピカルruby
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
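# https://qiita.com/lotz/items/094bffd77b24e37bf20e
# 動的計画法を実現する代数〜トロピカル演算でグラフの最短経路を計算する〜
# ("An algebra for dynamic programming: computing shortest graph paths with tropical arithmetic")
#
# をrubyでそれっぽく書いてみる
# (A rough Ruby rendition of the article above.)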
# Abstract base for a semiring element: wraps a raw value and declares the
# two operations (oplus = "addition", otimes = "multiplication") and the two
# identities (zero, one) that every concrete semiring must supply.
class Semiring
  attr_reader :value

  # @param v raw value wrapped by this element
  def initialize(v)
    @value = v
  end

  # Semiring addition. Subclasses must override.
  # @raise [NotImplementedError]
  def oplus(other)
    raise NotImplementedError, "#{self.class} must implement #oplus"
  end

  # Semiring multiplication. Subclasses must override.
  # @raise [NotImplementedError]
  def otimes(other)
    raise NotImplementedError, "#{self.class} must implement #otimes"
  end

  # Additive identity of the semiring. Subclasses must override.
  # @raise [NotImplementedError]
  def self.zero
    raise NotImplementedError, "#{self} must implement .zero"
  end

  # Multiplicative identity of the semiring. Subclasses must override.
  # @raise [NotImplementedError]
  def self.one
    raise NotImplementedError, "#{self} must implement .one"
  end

  # Two elements are equal when they belong to the same semiring class and
  # wrap equal values. Comparing with anything else (nil, a bare number, an
  # element of another semiring) returns false instead of raising.
  def ==(other)
    other.class == self.class && value == other.value
  end
  alias eql? ==

  # Kept consistent with eql? so elements behave correctly as Hash keys
  # and with Array#uniq.
  def hash
    [self.class, value].hash
  end
end
# Min-plus ("tropical") semiring over numbers:
#   a oplus b  = the smaller of a and b
#   a otimes b = a + b
#   zero = +Infinity (INFTY), one = 0 (ONE)
class Tropical < Semiring
  # Additive identity; absorbing under otimes ("unreachable").
  INFTY = new(Float::INFINITY)
  # Multiplicative identity (a zero-length distance).
  ONE = new(0)

  def self.zero
    INFTY
  end

  def self.one
    ONE
  end

  # Tropical addition: returns whichever operand has the smaller value;
  # on a tie the receiver wins.
  def oplus(other)
    [self, other].min_by { |candidate| candidate.value }
  end

  # Tropical multiplication: sums the values, with INFTY absorbing.
  def otimes(other)
    return INFTY if self == INFTY || other == INFTY

    Tropical.new(value + other.value)
  end

  # "Inf" for infinity, otherwise "T <value>".
  def inspect
    return "Inf" if self == INFTY

    "T #{value}"
  end
end
# Shorthand constructor: wraps a plain number as a Tropical element.
def t(number)
  Tropical.new(number)
end
# Shorthand for the tropical zero (+Infinity), i.e. "no edge".
def inf
  Tropical.zero
end
# Builds the n x n identity matrix of a semiring: clazz.one on the
# diagonal and clazz.zero everywhere else.
# @param n [Integer] matrix dimension (a non-positive n yields [])
# @param clazz semiring class responding to .one and .zero
# @return [Array<Array>]
def ident(n, clazz)
  indices = n.times.to_a
  indices.map do |row|
    indices.map { |col| row == col ? clazz.one : clazz.zero }
  end
end
# Element-wise semiring sum of two matrices: result[i][j] = a[i][j].oplus(b[i][j]).
# @param a [Array<Array>] matrix of semiring elements
# @param b [Array<Array>] matrix with exactly the same shape as a
# @return [Array<Array>]
# @raise [ArgumentError] if the shapes differ. Without this check, zip would
#   silently truncate or fail with an unrelated TypeError/NoMethodError.
def matrix_plus(a, b)
  same_shape = a.size == b.size && a.zip(b).all? { |row_a, row_b| row_a.size == row_b.size }
  raise ArgumentError, "matrix_plus: operands must have the same shape" unless same_shape

  a.zip(b).map do |row_a, row_b|
    row_a.zip(row_b).map { |x, y| x.oplus(y) }
  end
end
# Semiring inner product of two vectors: the oplus-fold of the pairwise
# otimes products. Returns nil for empty vectors (no initial value is used).
def dot(a, b)
  products = a.zip(b).map { |(left, right)| left.otimes(right) }
  products.reduce { |acc, term| acc.oplus(term) }
end
# Semiring matrix product: result[i][j] = dot(row i of a, column j of b).
# Under the Tropical/Path semirings this is the min-plus "distance product".
# @param a [Array<Array>] r x k matrix
# @param b [Array<Array>] k x c matrix
# @return [Array<Array>] r x c matrix ([] when a has no rows)
# @raise [RuntimeError] when a's column count differs from b's row count
def distance_product(a, b)
  return [] if a.empty?

  a_cols = a[0].size
  # Guard b[0]: with an empty b, the old message interpolation itself raised NoMethodError.
  b_cols = b.empty? ? 0 : b[0].size
  raise "size missmatch #{a.size}x#{a_cols} #{b.size}x#{b_cols}" if a_cols != b.size

  columns = b.transpose
  a.map { |row| columns.map { |col| dot(row, col) } }
end
# Raises a square semiring matrix to the n-th power (m^0 is the identity).
# Products are formed as m * (m * (... * I)), the same order as the
# original recursive definition. This matters because Path#oplus breaks
# ties by operand order, so a different association could pick different
# (equally light) routes. Iteration avoids O(n) recursion depth.
# @param m [Array<Array>] square matrix of semiring elements
# @param n [Integer] exponent, must be >= 0
# @return [Array<Array>] ([] for an empty matrix)
# @raise [ArgumentError] if n is not a non-negative Integer (previously
#   this recursed until SystemStackError)
def power(m, n)
  unless n.is_a?(Integer) && n >= 0
    raise ArgumentError, "exponent must be a non-negative Integer, got #{n.inspect}"
  end
  return [] if m.empty?

  result = ident(m.size, m[0][0].class)
  n.times { result = distance_product(m, result) }
  result
end
# Prints each row of m in its inspect form (via p) on its own line,
# then a blank separator line. Returns nil.
def matrix_dump(m)
  m.each(&method(:p))
  puts
end
# A directed, weighted graph edge.
class Edge
  attr_reader :from, :to, :weight

  # @param from label of the source vertex
  # @param to label of the destination vertex
  # @param weight semiring element attached to the edge (Tropical in this script)
  def initialize(from, to, weight)
    @from = from
    @to = to
    @weight = weight
  end

  # Shows only the endpoints, e.g. "a -> b"; the weight is not printed.
  def inspect
    "#{from} -> #{to}"
  end
end
# Path semiring: an element is a route (an Array of Edge) ordered by its
# total Tropical weight.
#   a oplus b  = the lighter route (the receiver wins ties)
#   a otimes b = concatenation of the two edge lists, then normalized
#   zero = PROHIBITED (no route; absorbing under otimes)
#   one  = EMPTY (the route with no edges)
class Path < Semiring
  # "No route": wraps nil; its weight is Tropical::INFTY.
  PROHIBITED = new(nil)
  # The empty route. Its array is frozen so the shared identity cannot be
  # mutated through Path::EMPTY.value.
  EMPTY = new([].freeze)

  def self.zero
    PROHIBITED
  end

  def self.one
    EMPTY
  end

  # Returns whichever operand has the smaller path_weight (receiver on ties).
  def oplus(other)
    [self, other].min_by { |e| e.path_weight.value }
  end

  # Concatenates the two routes; PROHIBITED absorbs.
  def otimes(other)
    if self == PROHIBITED || other == PROHIBITED
      PROHIBITED
    else
      Path.new([*value, *other.value]).normalize
    end
  end

  # Total route weight as a Tropical element: Tropical.one folded with
  # otimes over the edge weights; Tropical::INFTY for PROHIBITED.
  def path_weight
    return Tropical::INFTY if self == PROHIBITED

    value.map(&:weight).inject(Tropical.one) { |acc, w| acc.otimes(w) }
  end

  def inspect
    self == PROHIBITED ? "Prohibited" : "[#{dump_route}]"
  end

  # Renders the route as "a -> b -> c (3)".
  # Returns "" for the empty route and nil for PROHIBITED.
  def dump_route
    return if self == PROHIBITED
    return "" if value.empty?

    vertices = [value.first.from, *value.map(&:to)]
    "#{vertices.map(&:to_s).join(' -> ')} (#{path_weight.value})"
  end

  # Removes zero-weight self-loops (e.g. an edge b -> b whose weight equals
  # the weight class's `one`) so they do not clutter routes. Edges between
  # distinct vertices are always kept, even when their weight is `one`;
  # dropping them would print a broken, non-contiguous route.
  def normalize
    return PROHIBITED if self == PROHIBITED

    kept = value.reject do |edge|
      edge.from == edge.to && edge.weight == edge.weight.class.one
    end
    Path.new(kept)
  end
end
# Wraps a single edge as a one-edge route in the Path semiring.
def path(single_edge)
  Path.new([single_edge])
end
# Shorthand for the Path semiring zero: "no route between these vertices".
def prohibited
  Path.zero
end
# Sample graph on vertices a..d (weights from the article).
ab = Edge.new('a', 'b', t(2))
ac = Edge.new('a', 'c', t(4))
bb = Edge.new('b', 'b', t(0))
bc = Edge.new('b', 'c', t(1))
bd = Edge.new('b', 'd', t(9))
cd = Edge.new('c', 'd', t(5))
da = Edge.new('d', 'a', t(3))

# Adjacency matrix in the Path semiring: entry [i][j] is the one-edge route
# from vertex i to vertex j, or prohibited when there is no such edge.
path_matrix = [
  [prohibited, path(ab), path(ac), prohibited],
  [prohibited, path(bb), path(bc), path(bd)],
  [prohibited, prohibited, prohibited, path(cd)],
  [path(da), prohibited, prohibited, prohibited],
]

# Adding the identity permits "stay in place" steps, so the k-th power holds
# the lightest route using at most k edges. With non-negative weights, a
# shortest route in a 4-vertex graph needs at most 3 edges.
vertex_count = path_matrix.size
reflexive = matrix_plus(path_matrix, ident(vertex_count, Path))
matrix_dump(power(reflexive, vertex_count - 1))
#=>
# [[], [a -> b (2)], [a -> b -> c (3)], [a -> b -> c -> d (8)]]
# [[b -> c -> d -> a (9)], [], [b -> c (1)], [b -> c -> d (6)]]
# [[c -> d -> a (8)], [c -> d -> a -> b (10)], [], [c -> d (5)]]
# [[d -> a (3)], [d -> a -> b (5)], [d -> a -> b -> c (6)], []]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment