Skip to content

Instantly share code, notes, and snippets.

@VelizarHristov
Last active May 12, 2017 14:54
Show Gist options
  • Save VelizarHristov/7c3c6449fc0282f9cc6d3b6cc75a6a9f to your computer and use it in GitHub Desktop.
Save VelizarHristov/7c3c6449fc0282f9cc6d3b6cc75a6a9f to your computer and use it in GitHub Desktop.
import sys
from copy import deepcopy
from itertools import chain
sys.setrecursionlimit(100000000)


def _read_graph(text):
    """Parse the input text and build the adjacency lists.

    Expected layout (whitespace-separated integers): ``n k``, then ``m``,
    then ``m`` triples ``u v c`` describing a directed edge u -> v of cost c
    with 1-based endpoints.

    Returns ``(n, k, m, adj)`` where ``adj`` has ``n + 1`` entries:
    ``adj[u]`` lists ``(v, cost)`` pairs with 0-based targets. Self-loops are
    dropped, and parallel edges keep only their minimum cost, in first-seen
    order. ``adj[n]`` is a virtual start vertex with a 0-cost edge to every
    real vertex.
    """
    # Tokenising the whole input (instead of splitting on '\n' and dropping
    # the last piece) keeps the final edge when there is no trailing newline,
    # and tolerates blank lines or repeated spaces between numbers.
    tokens = [int(tok) for tok in text.split()]
    n, k, m = tokens[:3]
    cheapest = [{} for _ in range(n + 1)]
    for i in range(m):
        u, v, c = tokens[3 + 3 * i:6 + 3 * i]
        if u == v:
            continue
        row = cheapest[u - 1]
        # dicts keep first-insertion order, and updating an existing key does
        # not move it, so edge order matches the original list-scan version.
        row[v - 1] = min(row.get(v - 1, c), c)
    adj = [list(row.items()) for row in cheapest]
    adj[n] = [(v, 0) for v in range(n)]
    return n, k, m, adj


n, k, m, edges = _read_graph(sys.stdin.read())
n_aug = n + 1
"""
lower_bound = [[] for _ in range(n)]
upper_bound = deepcopy(lower_bound)
# input: [2, 4, 5], 0, 13
# output: [2, 2, 2, 4, 4, 5, end...]
# input: [-11, -9, -7], -13, 0
# output: [-11, -11, -11, -9, -9, -7, -7, end...]
# input: [5], 4, 6
# output: [5, 5, end]
def scan_bound(ls, start, end):
if start > end:
return []
elif ls == []:
return [end] + scan_bound([], start + 1, end)
elif start > ls[0]:
return scan_bound(ls[1:], start, end)
else:
return [ls[0]] + scan_bound(ls, start + 1, end)
def neg_sorted_list(ls):
return list(map(lambda x: -x, reversed(ls)))
for pos in range(n):
sorted_edges = sorted(map(lambda x: x[0], edges[pos]))
smaller = []
larger = []
for x in sorted_edges:
(smaller if x < pos else larger).append(x)
lb_scan = scan_bound(smaller, 0, n - 1)
lower_bound[pos] = lb_scan
ub_scan = neg_sorted_list(scan_bound(neg_sorted_list(larger), -(n - 1), 0))
upper_bound[pos] = ub_scan
def tight_solve(lb, ub, pos, k):
if k == 0:
return 0
tight_lb = lower_bound[pos][lb]
tight_ub = upper_bound[pos][ub]
return solve(tight_lb, tight_ub, pos, k)
"""
def _edges_between(src, lo, hi):
    """Return the (target, cost) pairs leaving ``src`` with lo <= target <= hi."""
    return [(dst, cost) for dst, cost in edges[src] if lo <= dst <= hi]


# Pre-filtered outgoing edges, indexed for every pos and bound in range(n_aug):
#   smaller_edges[pos][lb] -> edges pos -> t with lb <= t < pos   (moves left)
#   larger_edges[pos][ub]  -> edges pos -> t with pos < t <= ub   (moves right)
smaller_edges = [[_edges_between(pos, lb, pos - 1) for lb in range(n_aug)]
                 for pos in range(n_aug)]
larger_edges = [[_edges_between(pos, pos + 1, ub) for ub in range(n_aug)]
                for pos in range(n_aug)]
# Cache used by solve(): (lb, ub, pos, k) -> minimum total cost.
memo = {}
# input: lower_bound, upper_bound, current_position, remaining_moves
def solve(lb, ub, pos, k):
    """Return the minimum total edge cost of making ``k`` more moves.

    State: standing at vertex ``pos``, with every next target required to lie
    in the closed interval [lb, ub], and ``k`` moves remaining.
    - Moving left to a target t with lb <= t < pos continues in [lb, pos - 1].
    - Moving right to a target t with pos < t <= ub continues in [pos + 1, ub].
    The interval therefore only shrinks and never contains ``pos`` again.

    Reads the module-level tables ``smaller_edges[pos][lb]`` and
    ``larger_edges[pos][ub]`` (pre-filtered ``(target, cost)`` lists) and
    caches results in the module-level dict ``memo``.

    Returns 0 when ``k == 0``, and ``float('inf')`` when no valid sequence of
    ``k`` moves exists from this state.
    """
    if k == 0:
        return 0
    key = (lb, ub, pos, k)
    # Check membership rather than truthiness: a cached cost of 0 is a valid
    # result and must be reused, not recomputed.
    if key in memo:
        return memo[key]
    go_left = (cost + solve(lb, pos - 1, nxt, k - 1)
               for nxt, cost in smaller_edges[pos][lb])
    go_right = (cost + solve(pos + 1, ub, nxt, k - 1)
                for nxt, cost in larger_edges[pos][ub])
    # The trailing inf makes min() well-defined when no edge is available.
    best = min(chain(go_left, go_right, [float('inf')]))
    memo[key] = best
    return best
# Entry point: start at the virtual vertex n with every real vertex 0..n-1
# still allowed and k moves to make. An infinite cost means no valid
# sequence exists, which is reported as -1. Output has no trailing newline.
ans = solve(0, n - 1, n, k)
print('-1' if ans == float('inf') else ans, end='')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment