Last active: March 19, 2017 20:06
Save pv/c87650ff0cff9a3b5710 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from __future__ import division, absolute_import, print_function | |
| import numpy as np | |
| import os | |
| import gzip | |
| import subprocess | |
| import time | |
| import joblib | |
| import matplotlib.pyplot as plt | |
| from cycler import cycler | |
| from scipy.sparse.linalg import gmres, lgmres, LinearOperator, spilu | |
| from scipy.sparse import rand, eye, diags | |
| from scipy.io import mmread | |
| try: | |
| from scipy.sparse.linalg import gcrotmk | |
| except ImportError: | |
| gcrotmk = None | |
# Everything this script writes -- downloaded matrices, the joblib result
# cache, and the output figure -- goes under ``_cache`` next to the script.
_script_dir = os.path.abspath(os.path.dirname(__file__))
BASEDIR = os.path.join(_script_dir, '_cache')
# Disk cache used to memoize benchmark runs (see the ``@mem.cache`` decorator).
mem = joblib.Memory(os.path.join(BASEDIR, 'cache'))
del _script_dir
def perf_prof(solver, A, b, max_matvec, max_time, **kw):
    """Record the convergence history of an iterative solver on ``A x = b``.

    The right-hand side is scaled to unit 2-norm, so every recorded residual
    is relative to ``||b||``.  Products with ``A`` made by the solver go
    through a counting ``LinearOperator``.  The residual norms that the
    callback computes itself (for solvers other than ``gmres``) use ``A``
    directly and are *not* counted.

    Parameters
    ----------
    solver : callable or None
        ``scipy.sparse.linalg``-style solver called as
        ``solver(A, b, callback=..., x0=..., maxiter=..., **kw)`` and returning
        ``(x, info)``.  ``None`` (e.g. an optional solver that failed to
        import) yields only the starting point.
    A : object with ``.dot``, ``.shape`` and ``.dtype``
        System matrix (dense array or sparse matrix).
    b : ndarray
        Right-hand side.
    max_matvec : int
        Abort once more than this many counted products have been made.
    max_time : float
        Abort once more than this many wall-clock seconds have elapsed.
        Both limits are checked only when the solver invokes the callback,
        so they may be overshot by one iteration.
    **kw
        Forwarded unchanged to ``solver``.

    Returns
    -------
    ndarray of shape (k, 2)
        Rows ``(matvec_count, relative_residual_norm)``.  The first row is
        always ``(0, 1.0)``.  If the solver returns on its own, the last row
        is the true residual of the returned solution.
    """
    print("PERF", solver)

    class _Abort(Exception):
        """Private signal used to leave the solver from inside the callback.

        A dedicated exception is used instead of StopIteration, which PEP 479
        turns into RuntimeError if it ever propagates through a generator.
        """

    res = [(0, 1.0)]
    if solver is None:
        return np.asarray(res)

    count = [0]  # mutable cell so the nested functions can update it
    b = b / np.linalg.norm(b)
    b_norm = 1.0
    start = time.time()

    def matvec(x):
        count[0] += 1
        return A.dot(x)

    def callback(v):
        if solver is gmres:
            # gmres passes a residual norm rather than the current iterate
            # (assumed relative -- legacy callback type; confirm per SciPy).
            r = v * b_norm
        else:
            # Other solvers pass the current iterate; measure its residual.
            r = np.linalg.norm(A.dot(v) - b)
        res.append((count[0], r / b_norm))
        if count[0] > max_matvec or time.time() - start > max_time:
            raise _Abort()

    x0 = np.zeros_like(b)
    Aop = LinearOperator(matvec=matvec, shape=A.shape, dtype=A.dtype)
    try:
        x, _ = solver(Aop, b, callback=callback, x0=x0, maxiter=100000, **kw)
    except _Abort:
        pass  # budget exhausted; keep the history gathered so far
    else:
        res.append((count[0], np.linalg.norm(A.dot(x) - b) / b_norm))
    return np.asarray(res)
@mem.cache
def perf(solver, problem, *a, **kw):
    """Benchmark ``solver`` on the Matrix Market ``problem``, memoized on disk.

    Extra positional/keyword arguments are forwarded to ``perf_prof``.

    Returns
    -------
    tuple
        ``(residual_history, matrix_shape)``.
    """
    matrix, rhs = load(problem)
    history = perf_prof(solver, matrix, rhs, *a, **kw)
    return history, matrix.shape
def load(problem):
    """Load a Matrix Market test problem as ``(A, b)``.

    ``A`` is read from ``<problem>.mtx.gz`` and converted to CSR.  ``b`` is
    the first right-hand side from ``<problem>_rhs1.mtx.gz``, flattened to
    1-D.  If that file cannot be opened or read (``IOError``), ``b`` is
    instead a reproducible random vector of length ``A.shape[0]``.
    """
    from scipy.sparse import issparse

    # ``with`` closes the handles even when mmread raises.
    with _load('%s.mtx.gz' % problem) as f:
        A = mmread(f).tocsr()
    try:
        with _load('%s_rhs1.mtx.gz' % problem) as f:
            rhs = mmread(f)
    except IOError:
        # A private RandomState(1234) yields exactly the values that
        # np.random.seed(1234) + np.random.rand would, without resetting the
        # global RNG as a side effect.
        b = np.random.RandomState(1234).rand(A.shape[0])
    else:
        # Coordinate-format files come back from mmread as sparse matrices;
        # densify them so ravel() gives the vector entries.
        if issparse(rhs):
            rhs = rhs.toarray()
        b = np.array(rhs).ravel()
    return A, b
def _load(fn):
    """Open ``fn`` from the NIST Matrix Market FTP repository, downloading once.

    The file is cached under ``BASEDIR`` at a path mirroring the remote URL.
    If ``curl`` fails, a ``<file>.missing`` marker is written so the download
    is not retried on later runs.  In that case the final open raises
    ``IOError``, which callers use to detect unavailable files.

    Returns
    -------
    file object
        Opened in binary mode, through ``gzip`` when ``fn`` ends in ``.gz``.
    """
    scheme = 'ftp://'
    repo = scheme + 'math.nist.gov/pub/MatrixMarket2/'
    remote_fn = repo + fn
    local_fn = os.path.join(
        BASEDIR,
        os.path.normpath(remote_fn[len(scheme):]).lstrip(os.path.sep))
    missing_fn = local_fn + '.missing'
    if not os.path.isfile(local_fn) and not os.path.isfile(missing_fn):
        dn = os.path.dirname(local_fn)
        if not os.path.isdir(dn):
            os.makedirs(dn)
        print("RECV:", remote_fn)
        # Download to a temporary name and rename only on success.  This way
        # an interrupted or failed transfer never leaves a truncated file
        # that later runs would accept as a valid cache entry.
        part_fn = local_fn + '.part'
        try:
            subprocess.check_call(['curl', '-o', part_fn, remote_fn])
        except subprocess.CalledProcessError:
            if os.path.exists(part_fn):
                os.remove(part_fn)
            with open(missing_fn, 'w'):
                pass
        else:
            if os.path.exists(part_fn):
                os.rename(part_fn, local_fn)
    if local_fn.endswith('.gz'):
        return gzip.open(local_fn, 'rb')
    return open(local_fn, 'rb')
def main():
    """Compare GMRES, GCROT(m,k) and LGMRES variants on Matrix Market problems.

    Each solver in ``solvers`` is run on each entry of ``problems`` through
    the disk-cached ``perf`` helper.  Relative residual is plotted against the
    number of matrix-vector products, one subplot per problem.  The figure is
    saved to ``BASEDIR/profiles.png``.
    """
    m = 30
    max_matvec = 5000
    max_time = 10
    # NOTE(review): SciPy 1.12 deprecated ``tol`` in favour of ``rtol`` for
    # these solvers, and later releases removed it; confirm against the
    # installed SciPy.
    tol = 1e-12
    # (solver, keyword arguments, legend template).  All parameters derive
    # from m; per the figure title they are meant to give every method the
    # same storage budget.
    solvers = [
        (gmres, dict(restart=m), 'GMRES({restart})'),
        # gcrotmk is None when this SciPy lacks it; perf_prof then records
        # only the starting point.
        (gcrotmk, dict(m=m//2, k=m//4), 'GCROTMK({m},{k})'),
        (lgmres, dict(inner_m=m-3*3, outer_k=3), 'LGMRES({inner_m},{outer_k})'),
        (lgmres, dict(inner_m=m-3*3, outer_k=3, prepend_outer_v=True),
         'LGMRESx({inner_m},{outer_k})'),
    ]
    # Alternative inputs tried during development (kept for reference):
    #n = 1000
    #A = diags([-1, 2, 1], [-1, 0, 1], shape=(n, n), format='csr')
    #A[0,0] = 123
    #b = np.random.rand(n)
    #problem = "SPARSKIT/drivcav/e05r0100"
    #problem = "SPARSKIT/drivcav/e05r0200"
    #problem = "SPARSKIT/drivcav/e30r2000"
    #problem = "SPARSKIT/drivcav/e40r0500"
    #problem = "Harwell-Boeing/sherman/sherman1"
    #problem = "Harwell-Boeing/sherman/sherman5"
    #problem = "misc/hamm/add32"
    #problem = "Harwell-Boeing/bcsstruc3/bcsstk25"
    #problem = "misc/qcd/conf5.4-00l8x8-0500"
    problems = [
        "misc/hamm/add20",
        "Harwell-Boeing/oilgen/orsreg_1",
        "Harwell-Boeing/oilgen/orsirr_1",
        "NEP/mvmmcd/cdde2",
        "NEP/matpde/pde900",
        "NEP/brussel/rdb1250",
        "Harwell-Boeing/sherman/sherman1",
        "SPARSKIT/drivcav_old/cavity10",
    ]

    plt.clf()
    plt.gcf().set_size_inches(30/2.54, 15/2.54)  # 30 cm x 15 cm

    # Grid with about 0.75*n cells per row-squared.  At least one row, so a
    # single problem does not cause a division by zero.
    nrow = max(1, int(np.sqrt(0.75*len(problems))))
    ncol = -(-len(problems) // nrow)  # ceiling division

    for j, problem in enumerate(problems):
        jcol = j % ncol
        plt.subplot(nrow, ncol, j+1)
        print("\n{}".format(problem))
        # To try ILU preconditioning, the matrix must be loaded first, e.g.:
        #   A, _ = load(problem)
        #   Mx = spilu(A.tocsc())
        #   M = LinearOperator(matvec=Mx.solve, shape=A.shape, dtype=A.dtype)
        M = None
        # Pair the first four colours with four line styles so curves stay
        # distinguishable in greyscale.
        plt.gca().set_prop_cycle(plt.rcParams['axes.prop_cycle'][:4]
                                 + cycler(linestyle=['-', '--', '-.', ':']))
        for algo, arg, fmt in solvers:
            r, shape = perf(algo, problem, M=M, max_matvec=max_matvec,
                            max_time=max_time, tol=tol, **arg)
            plt.semilogy(r[:, 0], r[:, 1], label=fmt.format(**arg))
        plt.title("{0} ({1[0]}x{1[1]})".format(problem.split('/')[-1], shape))
        # Label the x axis on every subplot with no subplot below it.  This
        # includes columns left empty in a partially filled last row.
        if j + ncol >= len(problems):
            plt.xlabel('#matvec')
        if jcol == 0:
            plt.ylabel('rel. residual 2-norm')
        # Make the y range reach at least down to 1e-1.
        ymin, ymax = plt.ylim()
        plt.ylim(min(1e-1, ymin), ymax)
        if j == 0:
            plt.legend(loc='best')
    plt.suptitle('Comparison for same storage usage')
    plt.subplots_adjust(left=0.1, bottom=0.1, hspace=0.3)
    plt.savefig(os.path.join(BASEDIR, 'profiles.png'), dpi=150)
# Run the benchmark only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.