Last active
November 7, 2023 15:35
-
-
Save peace098beat/abc44a62af51c2aba0f4c5ab25309364 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
元ネタ | |
https://github.com/arjendeetman/GCMMA-MMA-Python/blob/master/examples/mma_toy2.py | |
""" | |
import json
import logging
import os
from dataclasses import dataclass, field

import numpy as np
@dataclass
class SolverState:
    """Mutable state of the MMA outer loop for the toy problem in this file.

    Holds the current and previous design points, bounds, asymptotes, the
    subproblem constants and the most recent function/gradient values, and
    can be checkpointed to / restored from a JSON file.
    """

    # Problem sizes handed to mmasub/kktcheck; the default vectors below are
    # sized to match (2 rows for the m-sized ones, 3 rows for the n-sized ones).
    m: int = 2
    n: int = 3
    epsimin: float = 1e-7  # not read anywhere in the visible file
    # Constant helper columns.
    eeen: np.ndarray = field(default_factory=lambda: np.ones((3, 1)))
    eeem: np.ndarray = field(default_factory=lambda: np.ones((2, 1)))
    zeron: np.ndarray = field(default_factory=lambda: np.zeros((3, 1)))
    zerom: np.ndarray = field(default_factory=lambda: np.zeros((2, 1)))
    # Starting design point as a (3, 1) column (integer dtype by default).
    xval: np.ndarray = field(default_factory=lambda: np.array([[4, 3, 2]]).T)
    # Derived in __post_init__: previous iterates, bounds and asymptotes.
    xold1: np.ndarray = field(init=False)
    xold2: np.ndarray = field(init=False)
    xmin: np.ndarray = field(init=False)
    xmax: np.ndarray = field(init=False)
    low: np.ndarray = field(init=False)
    upp: np.ndarray = field(init=False)
    move: float = 1.0
    # Constants forwarded unchanged to mmasub/kktcheck.
    c: np.ndarray = field(default_factory=lambda: 1000 * np.ones((2, 1)))
    d: np.ndarray = field(default_factory=lambda: np.ones((2, 1)))
    a0: int = 1
    a: np.ndarray = field(init=False)
    # Iteration bookkeeping.
    outeriter: int = 0
    maxoutit: int = 11
    kkttol: float = 0
    kktnorm: float = field(init=False)
    outit: int = 0
    # Latest objective/constraint values and gradients; None until evaluated.
    f0val: float = field(init=False)
    df0dx: np.ndarray = field(init=False)
    fval: np.ndarray = field(init=False)
    dfdx: np.ndarray = field(init=False)

    def __post_init__(self):
        """Fill in the fields that are derived from the constructor arguments."""
        self.xold1 = self.xval.copy()
        self.xold2 = self.xval.copy()
        self.xmin = self.zeron.copy()
        self.xmax = 5 * self.eeen
        self.low = self.xmin.copy()
        self.upp = self.xmax.copy()
        self.a = self.zerom.copy()
        # Start strictly above the tolerance so the outer loop runs at least once.
        self.kktnorm = self.kkttol + 10
        # Function values are not computed here; the owner of this state is
        # expected to evaluate the problem functions before they are used.
        self.f0val, self.df0dx, self.fval, self.dfdx = (None, None, None, None)

    @staticmethod
    def _to_jsonable(value):
        """Convert NumPy arrays and NumPy scalars to plain Python objects."""
        if isinstance(value, np.ndarray):
            return value.tolist()
        if isinstance(value, np.generic):
            # e.g. np.int64 / np.bool_, which json cannot serialise directly.
            return value.item()
        return value

    def save_state(self, filepath='solver_state.json'):
        """Write every attribute of this state to ``filepath`` as JSON.

        Args:
            filepath: Destination path of the JSON checkpoint.

        Raises:
            TypeError: If an attribute still cannot be represented in JSON.
            OSError: If the file cannot be written.
        """
        state_dict = {name: self._to_jsonable(value) for name, value in self.__dict__.items()}
        # Serialise first so that a failure does not truncate an existing file.
        payload = json.dumps(state_dict, indent=4)
        with open(filepath, 'w', encoding='utf-8') as f:
            f.write(payload)

    def load_state(self, filepath='solver_state.json'):
        """Restore attributes from a JSON file written by :meth:`save_state`.

        Every JSON list is turned back into a NumPy array; all other values
        are assigned as loaded.

        Args:
            filepath: Path of the JSON checkpoint to read.
        """
        with open(filepath, 'r', encoding='utf-8') as f:
            state_dict = json.load(f)
        for name, value in state_dict.items():
            setattr(self, name, np.array(value) if isinstance(value, list) else value)
class MMASolver:
    """Drive the MMA outer loop on a :class:`SolverState` and log each iterate.

    NOTE(review): ``mmasub`` and ``kktcheck`` are called below but are neither
    defined nor imported in the visible part of this file; presumably they
    come from the ``MMA`` module of the GCMMA-MMA-Python project this script
    is based on -- confirm and import them before running.
    """

    def __init__(self, log_file_name="mma_toy2.log"):
        """Create a fresh state and attach the file logger.

        Args:
            log_file_name: Name of the log file, created next to this module.
        """
        self.state = SolverState()  # Initialize the SolverState
        self.logger = self.setup_logger(log_file_name)
        self.logger.info("MMASolver initialized\n")

    def setup_logger(self, log_file_name):
        """Attach an INFO-level file handler to the root logger and return it.

        The handler is only added once per target file, so creating several
        solvers does not duplicate log lines or leak open file handles.

        Args:
            log_file_name: Name of the log file, relative to this module's directory.

        Returns:
            The root logger.
        """
        path = os.path.dirname(os.path.realpath(__file__))
        file = os.path.join(path, log_file_name)
        logger = logging.getLogger()
        # FileHandler stores its target as an absolute path in baseFilename.
        target = os.path.abspath(file)
        already_attached = any(
            isinstance(existing, logging.FileHandler) and existing.baseFilename == target
            for existing in logger.handlers
        )
        if not already_attached:
            handler = logging.FileHandler(file)
            formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
            handler.setFormatter(formatter)
            logger.addHandler(handler)
        logger.setLevel(logging.INFO)
        return logger

    def _ensure_function_values(self):
        """Evaluate the problem functions at xval if they are still unset.

        A fresh SolverState leaves f0val/df0dx/fval/dfdx as None; they must be
        populated before they are logged or passed to the subproblem solver.
        """
        state = self.state
        if state.f0val is None or state.df0dx is None or state.fval is None or state.dfdx is None:
            state.f0val, state.df0dx, state.fval, state.dfdx = External_CONSOL(state.xval)

    def _log_vectors(self, trailing_newline=False):
        """Log the iteration counter with xval, and f0val with fval."""
        state = self.state
        outvector1 = np.concatenate((np.array([state.outeriter]), state.xval.flatten()))
        outvector2 = np.concatenate((np.array([state.f0val]), state.fval.flatten()))
        suffix = "\n" if trailing_newline else ""
        self.logger.info(f"outvector1 = {outvector1}")
        self.logger.info(f"outvector2 = {outvector2}{suffix}")

    def run_iter(self):
        """Iterate until kktnorm <= kkttol or outit reaches maxoutit."""
        self._ensure_function_values()
        if self.state.outeriter == 0:
            self._log_vectors(trailing_newline=True)
        while (self.state.kktnorm > self.state.kkttol) and (self.state.outit < self.state.maxoutit):
            self.run_iter_one()
        self.logger.info("Optimization Finished")

    def run_iter_one(self):
        """Perform a single outer iteration and log the new iterate."""
        # Allows calling this method directly on a fresh or restored state.
        self._ensure_function_values()
        self.state.outit += 1
        self.state.outeriter += 1
        # Solve the MMA subproblem at the point xval
        xmma, ymma, zmma, lam, xsi, eta, mu, zet, s, self.state.low, self.state.upp = mmasub(
            self.state.m, self.state.n, self.state.outeriter, self.state.xval, self.state.xmin, self.state.xmax,
            self.state.xold1, self.state.xold2, self.state.f0val, self.state.df0dx, self.state.fval,
            self.state.dfdx, self.state.low, self.state.upp, self.state.a0, self.state.a, self.state.c,
            self.state.d, self.state.move
        )
        # Update vectors
        self.state.xold2 = self.state.xold1.copy()
        self.state.xold1 = self.state.xval.copy()
        self.state.xval = xmma.copy()
        # Re-calculate function values and gradients
        self.state.f0val, self.state.df0dx, self.state.fval, self.state.dfdx = External_CONSOL(self.state.xval)
        # Calculate the residual vector of the KKT conditions
        residu, self.state.kktnorm, residumax = kktcheck(
            self.state.m, self.state.n, xmma, ymma, zmma, lam, xsi, eta, mu, zet, s,
            self.state.xmin, self.state.xmax, self.state.df0dx, self.state.fval,
            self.state.dfdx, self.state.a0, self.state.a, self.state.c, self.state.d
        )
        # Log the vectors and KKT norm
        self._log_vectors()
        self.logger.info(f"kktnorm = {self.state.kktnorm}\n")

    def save_state(self, filepath='solver_state.json'):
        """Checkpoint the solver state to ``filepath`` and log the action."""
        self.state.save_state(filepath)
        self.logger.info(f"State saved to {filepath}")

    def load_state(self, filepath='solver_state.json'):
        """Restore the solver state from ``filepath`` and log the action."""
        self.state.load_state(filepath)
        self.logger.info(f"State loaded from {filepath}")
# CONSOL!!
def External_CONSOL(xval):
    """Evaluate the toy problem's objective, constraints and their gradients.

    The objective is the sum of squares of the three variables. Each of the
    two constraints is the squared distance to a fixed centre minus 9.

    Args:
        xval: The three design variables, as a (3, 1) column or any
            array-like holding exactly three numbers (it is reshaped to a
            column).

    Returns:
        Tuple ``(f0val, df0dx, fval, dfdx)``: the scalar objective, its
        (3, 1) gradient, the (2, 1) constraint values and their (2, 3)
        gradients (one row per constraint).

    Raises:
        ValueError: If ``xval`` does not contain exactly three values.
    """
    column = np.asarray(xval)
    if column.size != 3:
        raise ValueError(f"xval must contain exactly 3 values, got shape {column.shape}")
    column = column.reshape(3, 1)

    # One row per constraint: the centre each squared distance is measured from.
    centers = np.array([[5, 2, 1],
                        [3, 4, 3]])
    offsets = column.T - centers  # (2, 3) by broadcasting the (1, 3) row

    f0val = column[0][0]**2 + column[1][0]**2 + column[2][0]**2
    df0dx = 2 * column
    fval = ((offsets**2).sum(axis=1) - 9).reshape(2, 1)
    dfdx = 2 * offsets
    return f0val, df0dx, fval, dfdx
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
javaでpython
https://chat.openai.com/share/dcd4018c-9c4c-4ec7-997f-9005fb15bc32