Skip to content

Instantly share code, notes, and snippets.

@jakab922
Created May 14, 2018 14:19
Show Gist options
  • Save jakab922/fc3d23861b5f723ded2c3c279056f7e3 to your computer and use it in GitHub Desktop.
Save jakab922/fc3d23861b5f723ded2c3c279056f7e3 to your computer and use it in GitHub Desktop.
from collections import defaultdict as dd
# 2SAT linear solver START
def strongly_connected(graph):
    """ From Cormen: Strongly connected components chapter.

    Kosaraju's two-pass algorithm, written with explicit stacks instead of
    recursion.

    :param graph: adjacency list; ``graph[v]`` is an iterable of successor
        indices of node ``v``, nodes being ``0 .. len(graph) - 1``.
    :returns: list ``components`` where ``components[v]`` is the id of the
        strongly connected component containing ``v``.  Ids are contiguous
        from 0 and are handed out in decreasing first-pass finishing time,
        which (per CLRS) is a topological order of the component graph.
    """
    l = len(graph)
    normal = [list(graph[i]) for i in range(l)]
    reverse = [[] for _ in range(l)]
    for fr, tos in enumerate(normal):
        for to in tos:
            reverse[to].append(fr)

    # Pass 1: DFS over the original graph.  A stack entry is
    # (node, index of the next successor to try); a node finishes once all
    # of its successors have been tried.  A freshly discovered node is pushed
    # on top and handled next, so this mirrors a recursive DFS.
    finished = []  # nodes in increasing order of finishing time
    was = [False] * l
    for i in range(l):
        if was[i]:
            continue
        was[i] = True
        stack = [(i, 0)]
        while stack:
            curr, index = stack.pop()
            if index == len(normal[curr]):
                finished.append(curr)
                continue
            stack.append((curr, index + 1))
            nxt = normal[curr][index]
            if not was[nxt]:
                was[nxt] = True
                stack.append((nxt, 0))

    # Pass 2: DFS over the transposed graph, starting roots in decreasing
    # finishing time (reversing the finish list replaces an O(n log n) sort);
    # each search tree found is exactly one component.
    cc = 0
    components = [-1] * l
    was = [False] * l
    for el in reversed(finished):
        if was[el]:
            continue
        was[el] = True
        components[el] = cc
        stack = [el]
        while stack:
            fr = stack.pop()
            for to in reverse[fr]:
                if was[to]:
                    continue
                was[to] = True
                components[to] = cc
                stack.append(to)
        cc += 1
    return components
def calc_strong_graph(graph, components):
    """Build the component (condensation) graph of ``graph``.

    :param graph: adjacency list; ``graph[v]`` iterates the successors of ``v``.
    :param components: ``components[v]`` is the component id of node ``v``,
        ids being contiguous from 0 (as produced by ``strongly_connected``).
    :returns: adjacency list indexed by component id.  One edge is emitted
        for every original edge that crosses components, so parallel edges
        may repeat.  Edges inside a component are dropped.  A graph with no
        nodes yields ``[]``.
    """
    # max() of an empty sequence raises ValueError, so an empty graph
    # (e.g. from an empty 2-SAT formula) needs an explicit zero.
    n = max(components) + 1 if components else 0
    ret = [[] for _ in range(n)]
    for fr, tos in enumerate(graph):
        for to in tos:
            cfr = components[fr]
            cto = components[to]
            if cfr != cto:
                ret[cfr].append(cto)
    return ret
def topological_order(graph):
    """Topologically order the nodes of ``graph`` using Kahn's algorithm.

    :param graph: adjacency list; ``graph[v]`` lists the successors of ``v``.
        Parallel edges are allowed and counted individually.
    :returns: list ``ret`` where ``ret[v]`` is the position of ``v`` in the
        order (0 comes first).  Every edge ``u -> v`` satisfies
        ``ret[u] < ret[v]``.  Nodes on a cycle, or downstream of one, never
        reach in-degree 0 and keep the value ``None``.
    """
    n = len(graph)
    need = [0] * n  # remaining unprocessed incoming edges per node
    for tos in graph:
        for to in tos:
            need[to] += 1
    # A LIFO stack works as well as a queue; any ready node may go next.
    stack = [i for i in range(n) if need[i] == 0]
    order = 0
    ret = [None] * n
    while stack:
        fr = stack.pop()
        ret[fr] = order
        order += 1
        for to in graph[fr]:
            need[to] -= 1
            if need[to] == 0:
                stack.append(to)
    return ret
def two_sat_solve(conj_normal_form):
    """ Solve a 2-SAT instance given in conjunctive normal form.

    Expects the input as an iterable of clauses, each a pair (i, j) of
    non-zero ints.  If an element is i then we assume x_i and if it's -i
    we assume ~x_i.  For example (i, -j) means x_i V ~x_j.

    Returns a tuple (satisfiable, solution):
      * (False, {}) if the formula cannot be satisfied;
      * (True, solution) otherwise, where solution is a dict mapping every
        variable index that occurs in the input to True or False.
        An empty formula gives (True, {}).

    Raises ValueError if a literal is 0, because 0 cannot be negated.
    """
    # Materialize once: the clauses are walked twice below, which would
    # silently drop everything on the second pass for a generator.
    clauses = list(conj_normal_form)
    if not clauses:
        return True, {}

    # Transform the variables: x_v is node fmap[v] in [0, n) and ~x_v is
    # node fmap[v] + n, so a node's negation is (node + n) % (2 * n).
    fmap, rmap = {}, {}
    c = 0
    for els in clauses:
        for el in els:
            if el == 0:
                raise ValueError("literal 0 is not allowed; use non-zero ints")
            if abs(el) not in fmap:
                fmap[abs(el)] = c
                rmap[c] = abs(el)
                c += 1
    n = len(fmap)
    # Snapshot the items: rmap grows inside the loop.
    for key, val in list(rmap.items()):
        rmap[key + n] = -val
        fmap[-val] = key + n

    # Create the implication graph: (i V j) == (~i -> j) and (~j -> i).
    graph = [[] for _ in range(2 * n)]
    for i, j in clauses:
        for fr, to in ((-i, j), (-j, i)):
            graph[fmap[fr]].append(fmap[to])

    # Compute the strongly connected components.
    components = strongly_connected(graph)
    # Reverse map from component id to the set of its nodes.
    rev_map = dd(set)
    for node, comp in enumerate(components):
        rev_map[comp].add(node)

    # Unsatisfiable iff some variable shares a component with its negation.
    for v in range(n):
        if components[v] == components[v + n]:
            return False, {}

    # Calculate the graph consisting of the strong components and order it.
    strong_graph = calc_strong_graph(graph, components)
    ns = len(strong_graph)
    top_order = topological_order(strong_graph)
    # The component graph is a DAG, so top_order is a permutation of
    # 0..ns-1; inverting it avoids an O(ns log ns) sort.
    by_position = [0] * ns
    for comp, pos in enumerate(top_order):
        by_position[pos] = comp

    # Traverse the components in reverse topological order; the first time
    # a component is seen, make it True and its negated component False.
    values = [None] * 2 * n
    for comp in reversed(by_position):
        first = next(iter(rev_map[comp]))
        if values[first] is not None:
            continue
        other_comp = components[(first + n) % (2 * n)]
        for el in rev_map[comp]:
            values[el] = True
        for el in rev_map[other_comp]:
            values[el] = False

    # Calculate the return value over the positive literals only.
    ret = {rmap[i]: values[i] for i in range(n)}
    return True, ret
# 2SAT linear solver END
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment