Skip to content

Instantly share code, notes, and snippets.

@Yu212
Last active June 25, 2025 04:22
Show Gist options
  • Save Yu212/c00d5e6554f4b32934546947f46a3a06 to your computer and use it in GitHub Desktop.
Save Yu212/c00d5e6554f4b32934546947f46a3a06 to your computer and use it in GitHub Desktop.
01整数計画問題のソルバ
from typing import List, Tuple, Optional
import numpy as np
from numpy import linalg
import heapq
import time
# Numerical tolerance shared by the LP and rounding code: the optimality and
# ratio tests in simplex, the [0, 1] bound sanity check in solve_relaxed, and
# the "is this entry fractional?" tests in diving and solve.
EPS = 1e-8
class Instance:
    """One 0-1 program (maximize c·x subject to a·x <= b) read from a test file.

    Attributes:
        n: number of variables.
        m: number of constraints.
        ans: reference objective value stored with the instance (the solver
            code in this file does not read it).
        a: constraint matrix.
        b: constraint right-hand side.
        c: objective coefficients.
    """

    def __init__(self, n, m, ans, a, b, c):
        # Dimensions first, then the reference value, then the problem data.
        self.n, self.m = n, m
        self.ans = ans
        self.a, self.b, self.c = a, b, c

    def params(self) -> Tuple[int, int, np.ndarray, np.ndarray, np.ndarray]:
        """Return ``(n, m, a, b, c)``, i.e. everything except ``ans``."""
        dims = (self.n, self.m)
        data = (self.a, self.b, self.c)
        return dims + data
def load_instances(path: str = "tests/case_all.txt") -> List[Instance]:
    """Read every test instance from a text file.

    Layout, as parsed below:
      * first line: the number of instances ``t``;
      * per instance: a header line ``n m ans``, one line with the objective
        coefficients ``c``, ``m`` lines holding the rows of ``a``, one line
        with ``b``, and finally one line that is read and discarded
        (a separator).

    Args:
        path: file to read; defaults to the previously hard-coded location.

    Returns:
        The parsed instances, in file order.
    """
    instances = []
    with open(path, "r", encoding="utf-8") as f:
        t = int(f.readline())
        for _ in range(t):
            n_str, m_str, ans_str = f.readline().split()
            n, m = int(n_str), int(m_str)
            ans = float(ans_str)
            c = np.array(list(map(float, f.readline().split())))
            rows = [list(map(int, f.readline().split())) for _ in range(m)]
            # reshape keeps `a` two-dimensional (m, n) even when m == 0, so
            # downstream `m, n = a.shape` unpacking always works.
            a = np.array(rows).reshape(m, n)
            b = np.array(list(map(int, f.readline().split())))
            instances.append(Instance(n, m, ans, a, b, c))
            f.readline()  # separator line, ignored
    return instances
def simplex(a: np.ndarray, b: np.ndarray, c: np.ndarray, eps: Optional[float] = None) -> np.ndarray:
    """Maximize ``c @ x`` subject to ``a @ x <= b`` and ``x >= 0``.

    Revised primal simplex. One slack variable per row is appended, and the
    slacks form the starting basis. That start (x = 0) is only feasible when
    ``b`` is entry-wise non-negative; this is assumed, not checked.
    The entering variable is chosen by largest reduced cost.
    NOTE(review): that rule can in principle cycle on degenerate LPs.

    Args:
        a: (m, n) constraint matrix.
        b: length-m right-hand side.
        c: length-n objective coefficients.
        eps: tolerance for the optimality and ratio tests; ``None`` uses the
            module-level ``EPS``.

    Returns:
        The length-n optimal ``x`` (slack values are dropped).

    Raises:
        ValueError: if the LP is unbounded (no positive entry in the
            entering column).
    """
    tol = EPS if eps is None else eps
    m, n = a.shape
    c_ext = np.hstack((c, np.zeros(m)))
    a_ext = np.hstack((a, np.identity(m)))
    basis = list(range(n, n + m))
    non_basis = list(range(n))
    while True:
        basis_mat = a_ext[:, basis]
        # Simplex multipliers: B^T y = c_B; reduced costs of non-basic columns.
        y = linalg.solve(basis_mat.T, c_ext[basis])
        reduced = c_ext[non_basis] - y @ a_ext[:, non_basis]
        # Current basic values x_B = B^{-1} b (needed by both branches below).
        x_basic = linalg.solve(basis_mat, b)
        if np.all(reduced < tol):
            x = np.zeros(n + m)
            x[basis] = x_basic
            return x[:n]
        s = int(np.argmax(reduced))
        d = linalg.solve(basis_mat, a_ext[:, non_basis[s]])
        valid = d > tol
        if not np.any(valid):
            raise ValueError("simplex: LP is unbounded")
        # Minimum-ratio test picks the leaving row.
        r = np.flatnonzero(valid)[np.argmin(x_basic[valid] / d[valid])]
        non_basis[s], basis[r] = basis[r], non_basis[s]
def solve_relaxed(instance: Instance, fixed: List[int], fixed_x: List[int]) -> Optional[np.ndarray]:
    """Solve the LP relaxation with some variables pinned to given values.

    Each index in ``fixed`` is set to the matching entry of ``fixed_x``. Every
    other variable is relaxed from {0, 1} to [0, 1], and the remaining LP
    (maximize c·x subject to a·x <= b) is solved with ``simplex``.

    Returns:
        The full length-n solution vector, or ``None`` when subtracting the
        pinned columns leaves a negative right-hand-side entry. That also
        keeps the b >= 0 precondition of simplex's all-slack start.
    """
    n, m, a, b, c = instance.params()
    # Right-hand side left over once the pinned variables are accounted for.
    residual = b - a[:, fixed] @ fixed_x
    if (residual < 0).any():
        return None
    free = np.logical_not(np.isin(np.arange(n), fixed))
    sub_a = a[:, free]
    sub_c = c[free]
    k = sub_a.shape[1]
    # Extra identity rows encode the upper bounds x_j <= 1 for free variables.
    lp_a = np.concatenate([sub_a, np.eye(k)], axis=0)
    lp_b = np.concatenate([residual, np.ones(k)])
    result = np.zeros(n)
    result[free] = simplex(lp_a, lp_b, sub_c)
    result[fixed] = fixed_x
    assert np.all(-EPS <= result) and np.all(result <= 1 + EPS)
    return result
def diving(instance: Instance, first_x: Optional[np.ndarray], fixed: List[int], fixed_x: List[int], limit: int) -> Tuple[Optional[np.ndarray], Optional[float]]:
    """Primal heuristic: repeatedly round the LP solution and fix one variable.

    Each step starts from an LP-relaxation solution ``x``. That is ``first_x``
    on the first step when given, otherwise ``solve_relaxed`` under the
    current fixings. The step rounds ``x`` to the nearest 0/1 vector and
    keeps it if it satisfies ``a @ x <= b`` and beats the best objective so
    far. While ``x`` still has fractional entries, one of them is fixed to
    its rounded value and the dive continues, for at most ``limit`` steps.

    Args:
        instance: problem to dive on.
        first_x: LP solution for the given fixings, reused on the first step
            to save an LP solve; ``None`` to compute it.
        fixed: indices already pinned. Copied; the caller's list is not modified.
        fixed_x: values for ``fixed``. Copied; the caller's list is not modified.
        limit: maximum number of steps.

    Returns:
        ``(best_x, best_z)`` for the best feasible rounded vector seen, or
        ``(None, None)`` if none was feasible.
    """
    n, m, a, b, c = instance.params()
    # Private copies: the dive appends to these as it fixes variables.
    fixed = list(fixed)
    fixed_x = list(fixed_x)
    best_x = None
    best_z = None
    for step in range(limit):
        if step == 0 and first_x is not None:
            x = first_x
        else:
            x = solve_relaxed(instance, fixed, fixed_x)
        if x is None:
            break
        rounded_x = np.round(x)
        rounded_z = rounded_x @ c
        rounded_feasible = np.all(a @ rounded_x <= b)
        if rounded_feasible and (best_z is None or best_z < rounded_z):
            best_z = rounded_z
            best_x = rounded_x
        fractional = (EPS < x) & (x < 1 - EPS)
        # Stop when nothing is left to fix. Testing the same mask used for
        # selection below (rather than |round(x) - x| < EPS) avoids an argmin
        # over an empty array when an entry sits exactly EPS from 0 or 1.
        if not np.any(fractional):
            break
        # Fix the fractional entry with the smallest |x - 0.5| / c score
        # (close to 0.5 and with a large objective coefficient).
        i = np.flatnonzero(fractional)[(np.abs(x - 0.5) / (c + EPS))[fractional].argmin()]
        fixed.append(i)
        fixed_x.append(rounded_x[i])
    return best_x, best_z
def greedy(instance: "Instance") -> Tuple[np.ndarray, float]:
    """Build a quick 0/1 solution by taking the most valuable variables first.

    Variables are visited in decreasing order of objective coefficient. Each
    one is set to 1 when its column ``a[:, i]`` still fits in the remaining
    right-hand side. Variables with a negative coefficient are never taken.
    They are visited last, so taking one cannot make room for a better
    variable and can only lower the objective.

    Returns:
        ``(x, z)``: the 0/1 vector ``x`` and its objective value ``x @ c``.
    """
    n, m, a, b, c = instance.params()
    # Float copy so the in-place subtraction works whether a/b are int or float.
    b_rem = np.array(b, dtype=float)
    x = np.zeros(n)
    for i in np.argsort(c)[::-1]:
        if c[i] < 0:
            break  # descending order: every remaining coefficient is negative
        col = a[:, i]
        if np.all(col <= b_rem):
            x[i] = 1
            b_rem -= col
    return x, x @ c
from gurobipy import Model, GRB, quicksum
def solve_with_gurobi(instance: Instance) -> list[int]:
    """Solve the same 0-1 program with Gurobi.

    Not called elsewhere in this file; presumably used to cross-check
    ``solve``. Builds "maximize c·x subject to a·x <= b, x binary", optimizes
    silently and returns the rounded values of ``x``.

    Raises:
        RuntimeError: if Gurobi finishes without any feasible solution
            (e.g. an infeasible model). Reading ``x[j].X`` would otherwise
            fail with an obscure attribute error.
    """
    _, _, a, b, c = instance.params()
    # Dimensions come from the matrix itself.
    m, n = a.shape
    model = Model()
    model.setParam("OutputFlag", 0)
    x = model.addVars(n, vtype=GRB.BINARY, name="x")
    obj = quicksum(c[j] * x[j] for j in range(n))
    model.setObjective(obj, GRB.MAXIMIZE)
    for i in range(m):
        expr = quicksum(a[i, j] * x[j] for j in range(n))
        model.addConstr(expr <= b[i], name=f"c_{i}")
    model.optimize()
    if model.SolCount == 0:
        raise RuntimeError(f"Gurobi found no feasible solution (status {model.Status})")
    return [round(x[j].X) for j in range(n)]
def solve(instance: Instance, logging: bool = False) -> List[int]:
    """Solve the 0-1 program (maximize c·x s.t. a·x <= b, x in {0, 1}^n).

    Best-first branch and bound over the LP relaxation from
    ``solve_relaxed``. The incumbent starts as the better of ``greedy`` and
    a 5-step ``diving`` run. Every node that survives pruning also runs a
    1-step dive from its LP solution to look for a better incumbent.

    Args:
        instance: problem to solve.
        logging: print a line for each pruned, infeasible or integral node
            and for each incumbent update.

    Returns:
        The best 0/1 assignment found, as a list of ints.
    """
    *_, c = instance.params()
    incumbent_x, incumbent_z = greedy(instance)
    dive_x, dive_z = diving(instance, None, [], [], 5)
    if dive_z is not None and incumbent_z < dive_z:
        incumbent_x, incumbent_z = dive_x, dive_z
    # Heap entries: (-bound, fixed indices, fixed values). A node's bound is
    # its parent's LP value; the root has no parent, so its bound is +inf.
    que = [(-np.inf, [], [])]
    while que:
        neg_bound, fixed, fixed_x = heapq.heappop(que)
        # Fixing more variables only shrinks the feasible region, so a node
        # cannot beat its parent's LP value. Prune before paying for its LP.
        if -neg_bound <= incumbent_z:
            if logging:
                print("< incumbent", fixed)
            continue
        x = solve_relaxed(instance, fixed, fixed_x)
        if x is None:
            if logging:
                print("infeasible", fixed)
            continue
        z = x @ c
        # Ties are pruned too: an LP value equal to the incumbent cannot lead
        # to a strictly better solution.
        if z <= incumbent_z:
            if logging:
                print("< incumbent", fixed)
            continue
        dive_x, dive_z = diving(instance, x, fixed.copy(), fixed_x.copy(), 1)
        if dive_z is not None and incumbent_z < dive_z:
            if logging:
                print("incumbent update", dive_z)
            incumbent_x, incumbent_z = dive_x, dive_z
        fractional = (EPS < x) & (x < 1 - EPS)
        if not np.any(fractional):
            if logging:
                print("feasible", fixed)
            continue
        # Branch on the fractional entry with the smallest |x - 0.5| / c score.
        i = np.flatnonzero(fractional)[(np.abs(x - 0.5) / (c + EPS))[fractional].argmin()]
        heapq.heappush(que, (-z, fixed + [i], fixed_x + [0]))
        heapq.heappush(que, (-z, fixed + [i], fixed_x + [1]))
    return np.round(incumbent_x).astype(int).tolist()
if __name__ == "__main__":
    # Benchmark: solve each instance `repeats` times and print the mean wall
    # time, the objective value and the chosen 0/1 vector. Guarded so that
    # importing this module does not run the benchmark.
    repeats = 10
    instances = load_instances()
    for instance in instances:
        start = time.perf_counter()
        for _ in range(repeats):
            x = solve(instance, logging=False)
        elapsed = (time.perf_counter() - start) / repeats
        print(f"{elapsed:.4f} sec, {x @ instance.c}, {x}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment