Skip to content

Instantly share code, notes, and snippets.

@xqm32
Created December 13, 2022 09:46
Show Gist options
  • Select an option

  • Save xqm32/ddd9ce8afe6f42313cbd31f0a74069cb to your computer and use it in GitHub Desktop.

Select an option

Save xqm32/ddd9ce8afe6f42313cbd31f0a74069cb to your computer and use it in GitHub Desktop.
Yet Another Solver
import matplotlib.pyplot as plt
import numpy as np
class Solution:
    """Piecewise-quadratic objective built from hinge terms, with plotting helpers.

    The objective is::

        f(x) = a/2 * x**2 + d*x + sum_i max(0, b_i*x + c_i)

    Hinge term ``i`` is active where ``b_i*x + c_i >= 0`` and switches at its
    breakpoint ``x = -c_i / b_i``.  Between consecutive breakpoints ``f`` is a
    single quadratic, so its derivative ``fp`` is a line of slope ``a`` there.
    """

    def __init__(
        self,
        a: float,
        b: list[float],
        c: list[float],
        d: float,
    ) -> None:
        """Store the coefficients and precompute the breakpoints.

        Args:
            a: Quadratic coefficient (the objective contains ``a/2 * x**2``).
            b: Slopes of the hinge terms.
            c: Intercepts of the hinge terms; must be the same length as ``b``.
            d: Linear coefficient.

        Raises:
            ValueError: If ``b`` and ``c`` have different lengths.
        """
        if len(b) != len(c):
            # zip() would otherwise silently drop the unmatched tail.
            raise ValueError(
                f"b and c must have the same length (got {len(b)} and {len(c)})"
            )
        self.a: float = a
        # Copy so later mutation of the caller's sequences cannot affect this
        # instance; list() also accepts tuples and NumPy arrays.
        self.b: list[float] = list(b)
        self.c: list[float] = list(c)
        self.d: float = d
        # Breakpoints where a hinge term switches on/off.  A term with b_i == 0
        # is constant (always on or always off), so it has no breakpoint, and
        # dividing by it would raise ZeroDivisionError.
        self.bps: list[float] = [
            -ci / bi for bi, ci in zip(self.b, self.c) if bi != 0
        ]

    def _active_terms(self, x: float) -> list[tuple[float, float]]:
        """Return the ``(b_i, c_i)`` pairs whose hinge is active at ``x``.

        A term is active when ``b_i*x + c_i >= 0``; a term sitting exactly at
        its breakpoint counts as active.
        """
        return [(bi, ci) for bi, ci in zip(self.b, self.c) if bi * x + ci >= 0]

    def f(self, x: float) -> float:
        """Evaluate ``a/2*x**2 + d*x + sum_i max(0, b_i*x + c_i)`` at ``x``."""
        return (
            self.a / 2 * x * x
            + self.d * x
            + sum(bi * x + ci for bi, ci in self._active_terms(x))
        )

    def fp(self, x: float) -> float:
        """Evaluate the derivative of ``f`` at ``x``.

        Returns ``a*x + d + sum(b_i)`` over the terms active at ``x``.  A term
        exactly at its breakpoint is included, so at a breakpoint this is the
        one-sided derivative from the side where that term is on.
        """
        return self.a * x + self.d + sum(bi for bi, _ in self._active_terms(x))

    def _fp_zero_at(self, x: float) -> "float | None":
        """Return where the derivative line of the piece containing ``x`` hits zero.

        On that piece the derivative is ``a*z + d + sum(active b_i)``; this
        returns the ``z`` solving ``... == 0``.  That ``z`` is a stationary
        point of ``f`` only if it lies in the same piece as ``x``.  Returns
        ``None`` when ``a == 0``: the line is flat and has no unique root.
        """
        if self.a == 0:
            return None
        return -(self.d + sum(bi for bi, _ in self._active_terms(x))) / self.a

    def fp_zero(self, x: float):
        """Mark, on the x-axis of the current plot, the root from ``_fp_zero_at(x)``.

        Returns the root, or ``None`` (and draws nothing) when ``a == 0``.
        """
        z = self._fp_zero_at(x)
        if z is not None:
            plt.plot(z, 0, "o", color="#9b59b6", markersize=3)
        return z

    def plt_fp(self, f=False, fp=True):
        """Plot ``f`` and/or its derivative around the breakpoints, then show it.

        Args:
            f: Draw the objective ``f``.
            fp: Draw the derivative ``fp``, its value just right of each
                breakpoint, the slope-``a`` line through each such point, and
                the root of each of those lines.
        """
        # Plot window: one unit beyond the outermost breakpoints, or [-1, 1]
        # when there are none (min()/max() of an empty list would raise).
        lo, hi = (min(self.bps), max(self.bps)) if self.bps else (0.0, 0.0)
        xs = np.linspace(lo - 1, hi + 1, 10000)
        # Sample slightly right of each breakpoint so each point lies on the
        # piece to its right rather than exactly on the switching point.
        bps = [bp + 1e-4 for bp in self.bps]

        # axis X
        plt.xlabel("x")
        plt.axhline(color="grey", linestyle="--")
        # axis Y
        plt.ylabel("y")
        plt.axvline(color="grey", linestyle="--")

        if f:
            f_xs = np.array([self.f(x) for x in xs])
            plt.plot(xs, f_xs, color="#16a085")

        if fp:
            fp_xs = np.array([self.fp(x) for x in xs])
            fp_bps = np.array([self.fp(x) for x in bps])
            plt.plot(xs, fp_xs)
            plt.plot(bps, fp_bps, "o", color="#2980b9", markersize=3)
            for bp, fp_bp in zip(bps, fp_bps):
                # On every piece fp is a line with this instance's slope a.
                plt.axline(
                    (bp, fp_bp),
                    slope=self.a,
                    color="grey",
                    linestyle=":",
                    linewidth=0.5,
                )
            # Roots of those lines, drawn on the x-axis.
            for bp in bps:
                self.fp_zero(bp)

        plt.show()
# Demo instance with three hinge terms (m == len(b) == len(c)).  Kept at module
# level, outside the __main__ guard, so these remain importable globals.
m = 3
a = 1.5
d = 1
b = [1, 1.2, -0.9]
c = [0.1, -1.4, -1.2]


def random_instance(m: int = 10, seed=None):
    """Return a random problem ``(a, b, c, d)`` with ``m`` hinge terms.

    ``a`` is |N(0, 1)| so the quadratic coefficient is non-negative; ``d`` and
    every ``b_i`` / ``c_i`` are drawn from N(0, 1).  Pass ``seed`` for
    reproducible draws.  ``b`` and ``c`` are returned as lists.
    """
    rng = np.random.default_rng(seed)
    ra = abs(rng.normal(0, 1))
    rd = rng.normal(0, 1)
    rb = list(rng.normal(0, 1, m))
    rc = list(rng.normal(0, 1, m))
    return ra, rb, rc, rd


if __name__ == "__main__":
    # For a random problem instead: Solution(*random_instance(10)).plt_fp()
    Solution(a, b, c, d).plt_fp()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment