Last active
May 30, 2020 13:59
-
-
Save kraftwerk28/d273b81319ff9a16717ea4a5d94a9ff3 to your computer and use it in GitHub Desktop.
Jordan-Gauss method
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
data/
test*/
iter*/
*.txt
__pycache__/
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import sys | |
from fractions import Fraction | |
from tabulate import tabulate | |
import math | |
from utils import * | |
class JG:
    """Simplex-tableau solver driven by Gauss-Jordan pivot steps.

    Every tableau produced is appended to ``self.iters`` (initial one first),
    so the whole run can be printed or dumped afterwards.  Row
    ``target_row_idx`` is the objective (Z) row and the last column of every
    row is the right-hand side.
    """

    def __init__(self, input_file: str = None, is_maximizing=True):
        """Create a solver, optionally loading the initial tableau from a file.

        Args:
            input_file: path of the tableau file, or None to start empty
                (tableaux can then be supplied with ``add_mtx``).
            is_maximizing: objective direction; when *input_file* is given it
                is overridden by the file's first line.

        File format: the first non-blank line's first token is ``max`` for
        maximisation (anything else means minimisation).  If the next line
        starts with a letter it is a header naming the columns; header names
        that neither start with ``x`` nor equal ``P`` become the labels of the
        constraint rows.  Every remaining non-blank line is one tableau row
        of whitespace-separated numbers, parsed with ``fr``.

        Raises:
            ValueError: if the file lacks the direction line or the rows.
        """
        self.iters = []  # every tableau so far, initial one first
        self.headers = []  # column labels from the optional header line
        self.free_variables = []  # constraint-row labels: s1, s2, s3 etc
        self.target_row_idx = 0  # Z-equation row index (usually first row)
        self.pivots = []  # Dumped pivot points just for logging
        self.is_maximizing = is_maximizing  # False -> min; True -> max
        if input_file is None:
            return
        with open(input_file, 'rt') as f:
            # Plain split() (not s[:-1].split()): a last line that has no
            # trailing newline must not lose its final character.
            lines = [s.split() for s in f if s.strip()]
        if len(lines) < 2:
            raise ValueError(
                f'{input_file}: expected a max/min line followed by rows'
            )
        max_min, *mtx = lines
        self.is_maximizing = max_min[0] == 'max'
        has_header = mtx[0][0][0].isalpha()
        if has_header:
            self.headers = mtx[0]
            self.free_variables = [
                h for h in mtx[0]
                if not h.startswith('x') and h != 'P'
            ]
            tdata = mtx[1:]
        else:
            tdata = mtx
        self.iters.append([[fr(c) for c in row] for row in tdata])

    def iterate(self):
        """Apply one pivot step and append the resulting tableau.

        The pivot comes from ``_choose_pivot``.  When row and column labels
        are known, the pivot column's label and pivot row's label are swapped.

        Raises:
            ValueError: from ``_choose_pivot`` if the problem is unbounded.
        """
        tableau = self.iters[-1]
        row, col = self._choose_pivot()
        self.pivots.append(
            f'pivot: {str(tableau[row][col])} ({row}; {col})'
        )
        self.iters.append(iter_mtx(tableau, (row, col)))
        if self.free_variables and self.headers:  # Swap variable labels
            # Constraint-row labels start at tableau row 1, i.e. this assumes
            # the Z row is row 0 (target_row_idx == 0).
            self.free_variables[row - 1], self.headers[col] = (
                self.headers[col], self.free_variables[row - 1]
            )

    def iterate_full(self):
        """Pivot until the tableau is optimal, printing every tableau.

        An already-optimal tableau is printed and left untouched.

        Raises:
            ValueError: from ``_choose_pivot`` if the problem is unbounded.
        """
        self.print_iter(-1)
        while self._check_iter_ended():
            self.iterate()
            self.print_iter(-1)

    def add_mtx(self, mtx):
        """Append *mtx* as the newest tableau (no pivot is recorded)."""
        self.iters.append(mtx)

    def _check_iter_ended(self):
        """Return True while the last tableau is NOT yet optimal.

        Despite its name, this is True when some Z-row coefficient (the
        right-hand side excluded) can still improve the objective: negative
        when maximising, positive when minimising.
        """
        return any(
            c < 0 if self.is_maximizing else c > 0
            for c in self.iters[-1][self.target_row_idx][:-1]
        )

    def print_iter(self, i=0):
        """Print tableau *i* (Python indexing, so -1 is the latest) as a table.

        Tableaux are numbered from 0, matching the ``iter<N>.txt`` names used
        by ``dump_matrices``.  The pivot that produced the tableau is printed
        above it when one was recorded.
        """
        tableau = self.iters[i]  # raises IndexError for an invalid i
        idx = i if i >= 0 else len(self.iters) + i
        print(f'\nIteration #{idx}:')
        # pivots[k] is the pivot that turned iters[k] into iters[k + 1].
        if 1 <= idx <= len(self.pivots):
            print(self.pivots[idx - 1])
        if self.free_variables:
            # An extra leading column holds the row labels, so it needs an
            # (empty) header cell of its own.
            headers = [''] + self.headers
            sider = ['Z', *self.free_variables]
            tdata = [
                [var] + [str(c) for c in rest]
                for var, rest in zip(sider, tableau)
            ]
        else:
            headers = list(self.headers)
            tdata = [[str(c) for c in row] for row in tableau]
        print(tabulate(tdata, headers=headers, tablefmt='psql'))

    def _choose_pivot(self) -> 'tuple[int, int]':
        """Return the ``(row, column)`` of the next pivot element.

        The entering column has the most negative Z-row coefficient when
        maximising (most positive when minimising).  The leaving row is
        chosen by the minimum-ratio test over constraint rows with a
        positive entry in that column.

        Raises:
            ValueError: if no constraint row has a positive entry in the
                entering column, i.e. the objective is unbounded.
        """
        tableau = self.iters[-1]
        top_row = tableau[self.target_row_idx][:-1]
        best = min(top_row) if self.is_maximizing else max(top_row)
        col_idx = top_row.index(best)
        ratios = []
        for i, row in enumerate(tableau):
            entry = row[col_idx]
            if entry > 0 and i != self.target_row_idx:
                ratios.append(row[-1] / entry)
            else:
                ratios.append(math.inf)
        best_ratio = min(ratios)
        if best_ratio == math.inf:
            # Without this check the Z row itself would be returned as the
            # pivot row, producing a meaningless tableau.
            raise ValueError(
                f'Objective is unbounded: column {col_idx} has no positive '
                'constraint entry'
            )
        return ratios.index(best_ratio), col_idx

    def dump_matrices(self, dirpath):
        """Write every tableau to ``<dirpath>/iter<N>.txt``, one row per line.

        Cells are right-aligned to width 6 and separated by a space, so wide
        values (e.g. ``-13/15``) never run into their neighbours.  The
        directory is created if needed.
        """
        os.makedirs(dirpath, exist_ok=True)
        for n, tableau in enumerate(self.iters):
            path = os.path.join(dirpath, f'iter{n}.txt')
            data = '\n'.join(
                ' '.join(str(cell).rjust(6) for cell in row)
                for row in tableau
            )
            with open(path, 'w') as f:
                f.write(data + '\n')
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python
# Gauss-Jordan (simplex tableau) solver.
# Usage: ./main.py input.txt output_dir
# The input file's first line is "max" or "min"; an optional header line of
# column names may follow; every further line is one tableau row of
# whitespace-separated numbers.  Each tableau of the run is written to
# output_dir/iter<N>.txt.
import sys
from JG import JG

if __name__ == '__main__':
    args = sys.argv[1:]
    if len(args) != 2:
        # Previously: silent exit for <2 args and an unpacking crash for >2.
        print(f'Usage: {sys.argv[0]} input.txt output_dir', file=sys.stderr)
        sys.exit(1)
    inp, outdir = args
    # is_maximizing is overridden by the first line of the input file.
    jg = JG(inp, is_maximizing=True)
    jg.iterate_full()
    jg.dump_matrices(outdir)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from fractions import Fraction


def fr(value, max_denominator=200):
    """Convert *value* to a Fraction, optionally snapping it to a simpler one.

    By default the result is the closest fraction whose denominator is at
    most 200 (``Fraction.limit_denominator``).  This cleans up inputs such as
    ``0.333333`` or binary-float noise, but it is lossy: an exact value whose
    denominator exceeds the limit (e.g. ``1/201``) is silently approximated.
    Pass ``max_denominator=None`` to keep the exact value.
    """
    result = Fraction(value)
    if max_denominator is None:
        return result
    return result.limit_denominator(max_denominator)


def iter_mtx(mtx, pivot, max_denominator=200):
    """Return a new matrix produced by one Gauss-Jordan pivot step on *mtx*.

    The pivot row is divided by the pivot value; every other row has the
    pivot row, scaled by that row's pivot-column entry, subtracted from it.
    Afterwards the pivot column is 1 in the pivot row and 0 elsewhere.  *mtx*
    itself is not modified.  Every cell is passed through ``fr`` with
    *max_denominator* (``None`` keeps exact fractions).

    Args:
        mtx: rectangular list of rows of numbers.
        pivot: ``(row, column)`` of the pivot; negative indices are allowed.
        max_denominator: forwarded to ``fr``.

    Raises:
        IndexError: if *mtx* is empty or the pivot lies outside it.
        ValueError: if the pivot element is zero.
    """
    rows, cols = len(mtx), len(mtx[0])
    py, px = pivot  # Pivot location
    if not (-rows <= py < rows and -cols <= px < cols):
        raise IndexError(f'Pivot {pivot} is outside a {rows}x{cols} matrix')
    # Normalise negative indices: rows are matched against py below, and an
    # un-normalised -1 would never match, zeroing the whole pivot row.
    py %= rows
    px %= cols
    pv = mtx[py][px]  # Pivot value
    if pv == 0:
        raise ValueError('Pivot cannot be zero')
    pivot_row = mtx[py]
    res = []
    for y, row in enumerate(mtx):
        if y == py:
            new_row = [fr(row[x] / pv, max_denominator) for x in range(cols)]
            new_row[px] = fr(1)  # exact 1 at the pivot itself
        else:
            factor = row[px]
            new_row = [
                fr((row[x] * pv - factor * pivot_row[x]) / pv,
                   max_denominator)
                for x in range(cols)
            ]
            new_row[px] = fr(0)  # exact 0 in the pivot column
        res.append(new_row)
    return res
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment