#!/usr/bin/env python3

"""Basic implementation of the two-phase simplex method.

Solves ``maximize c^T x  subject to  A x <= b, x >= 0``.
"""

import sys
import numpy as np

class Solver:
    """Two-phase simplex solver in (Vanderbei-style) dictionary form.

    Solves ``maximize c^T x  subject to  A x <= b, x >= 0`` by appending one
    slack variable per constraint.  State kept on the instance:

    * ``x`` -- primal values of all ``n + m`` variables (the ``n`` original
      variables first, then the ``m`` slacks);
    * ``z`` -- dual values (reduced costs) of the same ``n + m`` variables;
    * ``beta`` / ``nu`` -- indices of the basic / nonbasic variables;
    * ``A`` -- the ``m x (n + m)`` matrix ``[A | I]``;
    * ``c`` -- the objective padded with zeros for the slack variables.
    """

    # Values >= -tol count as nonnegative in feasibility/optimality checks,
    # and step-direction entries with |value| <= tol are ignored by the ratio
    # test, so floating-point round-off cannot trigger spurious pivots.
    tol = 1e-9

    def __init__(self, A, b, c):
        """Build the initial dictionary with the slack variables as basis.

        :param A: ``(m, n)`` constraint matrix.
        :param b: ``m`` right-hand-side values.
        :param c: ``n`` objective coefficients.
        :raises ValueError: if the shapes of ``A``, ``b`` and ``c`` disagree.
        """
        A = np.asarray(A, dtype=float)
        b = np.asarray(b, dtype=float).ravel()
        c = np.asarray(c, dtype=float).ravel()
        if A.ndim != 2:
            raise ValueError('A must be a 2-D matrix')
        rows, cols = A.shape
        if b.size != rows or c.size != cols:
            raise ValueError(
                'A is {}x{} but b has {} entries and c has {}'.format(
                    rows, cols, b.size, c.size))
        # Initial basic solution: originals at 0, slacks equal to b.
        self.x = np.concatenate((np.zeros(cols), b))
        # Reduced costs of the initial dictionary: zeta = c^T x => z_N = -c.
        self.z = np.concatenate((-c, np.zeros(rows)))
        self.c = np.concatenate((c, np.zeros(rows)))
        self.A = np.concatenate((A, np.identity(rows)), axis=1)
        self.beta = np.arange(cols, cols + rows)
        self.nu = np.arange(cols)

    def _binv_n(self):
        """Return ``B^-1 N`` for the current basis (``B = A[:, beta]``)."""
        return np.linalg.solve(self.A[:, self.beta], self.A[:, self.nu])

    def _is_nonnegative(self, values):
        """Return True if every entry of ``values`` is >= ``-tol``."""
        return bool((values >= -self.tol).all())

    def _ratio_test(self, delta, current, labels):
        """Pick the blocking position for a step along ``-delta``.

        Computes ``step = (max_k delta_k / current_k)^-1`` using the
        convention ``0/0 = 0``; a zero ``current_k`` with positive
        ``delta_k`` gives an infinite ratio, i.e. a degenerate step of 0.
        Ties are broken by the smallest entry of ``labels`` (Bland's rule).

        :param delta: step direction of the variables being decreased.
        :param current: their current (nonnegative) values.
        :param labels: variable index of each position, for tie-breaking.
        :returns: ``(position, step)``, or None if nothing blocks the step
            (the problem is unbounded in this direction).
        """
        if delta.size == 0:
            return None
        # Round-off-sized directions must not be pivoted on.
        delta = np.where(np.abs(delta) > self.tol, delta, 0.0)
        # Round-off can leave tiny negatives; treat them as degenerate zeros.
        current = np.maximum(current, 0.0)
        with np.errstate(divide='ignore', invalid='ignore'):
            ratios = delta / current
        ratios[np.isnan(ratios)] = 0.0
        best = ratios.max()
        if best <= 0:
            return None
        ties = np.flatnonzero(ratios == best)
        position = ties[np.argmin(labels[ties])]
        return int(position), 1.0 / best

    def _pivot(self, r, k, t, s, dx_beta, dz_nu):
        """Apply steps 8-9: update both solutions and swap ``beta[r]``/``nu[k]``.

        :param r: position (in ``beta``) of the leaving variable.
        :param k: position (in ``nu``) of the entering variable.
        :param t: primal step length.
        :param s: dual step length.
        :param dx_beta: primal step direction of the basic variables.
        :param dz_nu: dual step direction of the nonbasic variables.
        """
        leaving = self.beta[r]
        entering = self.nu[k]
        self.x[self.beta] = self.x[self.beta] - t * dx_beta
        self.x[leaving] = 0.0     # exactly zero once nonbasic (no round-off)
        self.x[entering] = t
        self.z[self.nu] = self.z[self.nu] - s * dz_nu
        self.z[entering] = 0.0    # basic variables have zero reduced cost
        self.z[leaving] = s
        self.beta[r] = entering
        self.nu[k] = leaving

    def solve_primal(self):
        """Run the primal simplex method from a primal-feasible dictionary.

        Entering and leaving variables follow Bland's smallest-index rule, so
        degenerate pivots cannot cycle.

        :returns: a copy of the values of the ``n`` original variables.
        :raises Exception: ``'Primal is unbounded'`` if no basic variable
            limits the increase of the entering variable.
        """
        while True:
            # Step 1: optimal once every nonbasic reduced cost is >= 0.
            z_nu = self.z[self.nu]
            improving = np.flatnonzero(z_nu < -self.tol)
            if improving.size == 0:
                break
            # Step 2: entering variable (smallest variable index).
            k = improving[np.argmin(self.nu[improving])]
            # Step 3: primal step direction dx_B = B^-1 N e_k.
            binv_n = self._binv_n()
            dx_beta = binv_n[:, k]
            # Steps 4-5: primal step length and leaving variable.
            blocking = self._ratio_test(dx_beta, self.x[self.beta], self.beta)
            if blocking is None:
                raise Exception('Primal is unbounded')
            r, t = blocking
            # Step 6: dual step direction dz_N = -(B^-1 N)^T e_r.
            dz_nu = -binv_n[r, :]
            # Step 7: dual step length (drives z of the entering var to 0).
            s = z_nu[k] / dz_nu[k]
            # Steps 8-9: update solutions and basis.
            self._pivot(r, k, t, s, dx_beta, dz_nu)
        return self.x[:len(self.nu)].copy()

    def solve_dual(self):
        """Run the dual simplex method from a dual-feasible dictionary.

        Variable choices follow Bland's smallest-index rule, so degenerate
        pivots cannot cycle.

        :returns: a copy of the values of the ``n`` original variables.
        :raises Exception: ``'Dual is unbounded'`` if the dual is unbounded,
            which means the primal problem is infeasible.
        """
        while True:
            # Step 1: optimal once every basic variable is >= 0.
            x_beta = self.x[self.beta]
            infeasible = np.flatnonzero(x_beta < -self.tol)
            if infeasible.size == 0:
                break
            # Step 2: basic variable to drive out (smallest variable index).
            r = infeasible[np.argmin(self.beta[infeasible])]
            # Step 3: dual step direction dz_N = -(B^-1 N)^T e_r.
            binv_n = self._binv_n()
            dz_nu = -binv_n[r, :]
            # Steps 4-5: dual step length and entering variable.
            blocking = self._ratio_test(dz_nu, self.z[self.nu], self.nu)
            if blocking is None:
                raise Exception('Dual is unbounded')
            k, s = blocking
            # Step 6: primal step direction dx_B = B^-1 N e_k.
            dx_beta = binv_n[:, k]
            # Step 7: primal step length (drives x of the leaving var to 0).
            t = x_beta[r] / dx_beta[r]
            # Steps 8-9: update solutions and basis.
            self._pivot(r, k, t, s, dx_beta, dz_nu)
        return self.x[:len(self.nu)].copy()

    def solve(self):
        """Solve the LP, choosing the phase(s) from the current dictionary.

        * primal and dual feasible -- already optimal;
        * primal feasible only -- primal simplex;
        * dual feasible only -- dual simplex;
        * neither -- Phase I runs the dual simplex on the auxiliary objective
          ``|z|`` to reach a primal-feasible basis, then Phase II runs the
          primal simplex on the true objective.

        :returns: values of the ``n`` original variables at the optimum.
        :raises Exception: ``'Primal is unbounded'`` for an unbounded LP, or
            ``'Dual is unbounded'`` when the LP is infeasible.
        """
        primal_feasible = self._is_nonnegative(self.x)
        dual_feasible = self._is_nonnegative(self.z)
        if primal_feasible and dual_feasible:
            return self.x[:len(self.nu)].copy()
        if primal_feasible:
            return self.solve_primal()
        if dual_feasible:
            return self.solve_dual()
        # Phase I: any objective with z_N >= 0 is dual feasible.
        self.z = np.abs(self.z)
        self.solve_dual()
        # Phase II: the Phase I pivots changed the basis, so the original
        # reduced costs no longer apply; re-express the true objective in the
        # current nonbasic variables: z_N = (B^-1 N)^T c_B - c_N.
        binv_n = self._binv_n()
        self.z = np.zeros_like(self.z)
        self.z[self.nu] = binv_n.T.dot(self.c[self.beta]) - self.c[self.nu]
        return self.solve_primal()

if __name__ == '__main__':
    if len(sys.argv) != 4:
        print('Usage:', sys.argv[0], 'my_A.csv my_b.csv my_c.csv')
        sys.exit(1)
    try:
        A = np.genfromtxt(sys.argv[1], delimiter=',')
        # genfromtxt collapses a one-row or one-column file to 1-D and a
        # single value to 0-D; restore the shapes Solver expects, taking the
        # number of right-hand-side entries as the constraint count.
        b = np.atleast_1d(np.genfromtxt(sys.argv[2], delimiter=','))
        c = np.atleast_1d(np.genfromtxt(sys.argv[3], delimiter=','))
        if A.ndim < 2:
            A = np.reshape(A, (b.size, -1))
        solver = Solver(A, b, c)
        print(solver.solve())
    except Exception as err:
        # Top-level CLI boundary: report unbounded/infeasible problems and
        # malformed input as a message instead of a traceback.
        print('Error:', err, file=sys.stderr)
        sys.exit(1)