Created
December 1, 2018 17:55
-
-
Save matthewfl/c7c80dd366c9e3dd79a2079380c7d505 to your computer and use it in GitHub Desktop.
Simple implementation of the Simplex Method
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
from scipy.optimize import linprog | |
class simplex(object):
    """Dense-tableau simplex solver for ``max c @ x  s.t.  A @ x <= b, x >= 0``.

    Tableau layout (``self.tableau``, shape ``(nconst + 1, nvars + nconst + 2)``):

    * row 0 is the objective row, rows ``1..nconst`` are the constraints;
    * column 0 is the objective indicator (1 in row 0, 0 elsewhere);
    * columns ``1..nvars`` are the decision variables (row 0 holds ``-c``);
    * the next ``nconst`` columns are the slack variables (identity block);
    * the last column is the right-hand side; ``tableau[0, -1]`` is the
      current objective value.

    ``self.basis[k]`` is the tableau column that is basic in constraint row
    ``k + 1``.  It starts as the slack columns and is updated on every pivot.

    Constructing the object only builds the tableau; call ``run_simplex()``
    to optimise.  The initial slack basis is only a feasible starting point
    when every entry of ``b`` is non-negative; this is not checked.
    """

    # Magnitudes below this are treated as zero when choosing pivots, so that
    # round-off noise neither triggers extra pivots nor yields a tiny pivot.
    _TOL = 1e-9

    def __init__(self, A, b, c):
        """Build the initial tableau for constraint matrix A, bounds b, costs c."""
        self.make_tableau(A, b, c)
        self.nvars = np.shape(c)[0]

    def get_x(self):
        """Return the decision-variable values of the current basic solution.

        Variables that are in the basis take the right-hand side of their
        row; all non-basic variables are zero.
        """
        x = np.zeros(self.nvars)
        basis = np.asarray(self.basis)
        # Columns 1..nvars are decision variables; anything above is a slack.
        is_decision = basis <= self.nvars
        x[basis[is_decision] - 1] = self.tableau[1:, -1][is_decision]
        return x

    def get_objective(self):
        """Return the objective value stored in the tableau's top-right cell."""
        return self.tableau[0, -1]

    def make_tableau(self, A, b, c):
        """Create ``self.tableau`` and the initial slack basis ``self.basis``.

        Args:
            A: constraint matrix, shape ``(nconst, nvars)``.
            b: constraint right-hand sides, length ``nconst``.
            c: objective coefficients to maximise, length ``nvars``.

        Raises:
            ValueError: if the shape of A does not match b and c.
        """
        A = np.asarray(A, dtype=float)
        b = np.asarray(b, dtype=float)
        c = np.asarray(c, dtype=float)
        nvars = c.shape[0]
        nconst = b.shape[0]
        if A.shape != (nconst, nvars):
            raise ValueError(
                'A has shape {0}, expected {1}'.format(A.shape, (nconst, nvars))
            )

        tableau = np.zeros((nconst + 1, nvars + nconst + 2))
        tableau[0, 0] = 1
        tableau[0, 1:nvars + 1] = -c
        tableau[1:, 1:nvars + 1] = A
        tableau[1:, nvars + 1:-1] = np.eye(nconst)
        tableau[1:, -1] = b

        self.tableau = tableau
        self.basis = np.arange(nvars + 1, nvars + nconst + 1)

    def run_simplex(self):
        """Pivot in place until no reduced cost is negative.

        The entering column is the one with the most negative objective-row
        entry; the leaving row is chosen by the minimum-ratio test over rows
        whose entry in that column is strictly positive.  No anti-cycling
        rule is applied.

        Raises:
            ValueError: if the entering column has no positive entry, i.e.
                the objective can be increased without bound.
        """
        tableau = self.tableau
        tol = self._TOL
        if tableau.shape[1] <= 2:
            return  # no variable columns at all, nothing to optimise

        while True:
            # Search only the variable columns; the last column is the
            # objective value, not a reduced cost.
            pivot_col = tableau[0, 1:-1].argmin() + 1
            if tableau[0, pivot_col] >= -tol:
                return

            column = tableau[1:, pivot_col]
            eligible = column > tol
            if not eligible.any():
                raise ValueError('linear program is unbounded')

            # Ratio test restricted to positive entries.  A ratio of exactly
            # zero (degenerate row) is a valid choice and must not be dropped.
            ratios = np.full(column.shape, np.inf)
            ratios[eligible] = tableau[1:, -1][eligible] / column[eligible]
            pivot_row = ratios.argmin() + 1

            # Normalise the pivot row, then eliminate the pivot column from
            # every other row (including the objective row).
            tableau[pivot_row, :] /= tableau[pivot_row, pivot_col]
            factors = tableau[:, pivot_col].copy()
            factors[pivot_row] = 0.0
            tableau -= np.outer(factors, tableau[pivot_row, :])

            self.basis[pivot_row - 1] = pivot_col

    def print_tableau(self):
        """Print the tableau, one row per line, entries tab-separated to 4 d.p."""
        lines = [
            ''.join('{0:.4f}\t'.format(value) for value in row) + '\n'
            for row in self.tableau
        ]
        print(''.join(lines))
def main():
    """Solve one random LP with SciPy and with ``simplex`` and print both.

    The problem is ``max c @ x  s.t.  A @ x <= b, x >= 0`` with all data
    drawn uniformly from [0, 1).
    """
    num = 5  # number of constraints
    dim = 5  # number of decision variables

    A = np.random.rand(num, dim)  # constraint matrix
    b = np.random.rand(num)  # constraint right-hand sides
    c = np.random.rand(dim)  # objective coefficients to maximise

    # linprog minimises, so negate c.  The solver is left at SciPy's default:
    # the legacy method='simplex' was removed in SciPy 1.11.
    print(linprog(c=-c, A_ub=A, b_ub=b))

    print('-'*20)

    s = simplex(A, b, c)
    s.print_tableau()

    print('-'*10)

    s.run_simplex()
    s.print_tableau()

    x = s.get_x()
    # The tableau objective and c @ x should agree.
    print(s.get_objective(), x, x.dot(c))
# Run the demo only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment