Last active: May 26, 2021 16:46
-
-
Save bwasti/cecc99ce23787faf5616d6ca6e4bf595 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import builtins
import re
import sys
import time

import numpy as np
import plotille
class Reprint:
    """Redraw a multi-line block of terminal output in place.

    While used as a context manager, ``builtins.print`` is replaced so that
    anything printed to stdout is buffered instead of written. Each call to
    ``flush()`` moves the cursor back up over the previously drawn frame and
    redraws it with the buffered text. The terminal cursor is hidden for the
    duration of the ``with`` block and shown again on exit.

    Relies on ANSI/VT100 escape sequences. NOTE(review): frames taller than
    the terminal presumably cannot be fully redrawn, since the cursor cannot
    move above the top of the screen.
    """

    def __init__(self):
        # Number of lines drawn by the previous flush(); how far to move up.
        self.h = 0
        # Pending frame text, exactly as print() would have written it.
        self.s = ""
        # Writer used for real output. __enter__ re-captures it; setting it
        # here as well lets flush() work outside a ``with`` block.
        self.print = builtins.print

    def add(self, s):
        """Append ``s`` to the pending frame as its own line(s)."""
        self.s += s + "\n"

    def flush(self):
        """Overwrite the previous frame with the pending one and clear the buffer."""
        esc = chr(27)
        back = esc + "[1F"  # cursor to the start of the previous line
        clr_line = esc + "[0K"  # erase from cursor to end of line
        clr_below = esc + "[0J"  # erase from cursor to end of screen
        self.print(back * self.h, end="")
        # The buffer normally ends with the newline of its last print();
        # drop that single terminator so it doesn't become an extra line.
        text = self.s[:-1] if self.s.endswith("\n") else self.s
        lines = text.split("\n")
        self.h = len(lines)
        for l in lines:
            self.print(f"{l}{clr_line}")
        # If this frame is shorter than the last one, wipe the stale rows
        # that would otherwise remain visible underneath it.
        self.print(clr_below, end="")
        self.s = ""

    def __enter__(self):
        print(chr(27) + "[?25l", end="")  # hide the cursor
        self.print = real_print = builtins.print

        def passive_print(*args, sep=" ", end="\n", file=None, flush=False):
            # Output aimed at another stream (e.g. stderr) is not part of the
            # frame; hand it straight to the real print().
            if file is not None and file is not sys.stdout:
                real_print(*args, sep=sep, end=end, file=file, flush=flush)
                return
            # Mirror print()'s handling of sep/end, including None = default.
            sep = " " if sep is None else sep
            end = "\n" if end is None else end
            self.s += sep.join(str(arg) for arg in args) + end

        builtins.print = passive_print
        return self

    def __exit__(self, exc_type, exc_value, tb):
        builtins.print = self.print
        print(chr(27) + "[?25h", end="")  # show the cursor again
        return False  # never suppress exceptions
# Matches ANSI CSI escape sequences (e.g. colour codes) per ECMA-48:
# ESC '[' parameter bytes, intermediate bytes, one final byte.
_ANSI_ESCAPE_RE = re.compile(r"\x1b\[[0-?]*[ -/]*[@-~]")


def _visible_len(text):
    """Return the on-screen width of ``text``, ignoring ANSI escape sequences."""
    return len(_ANSI_ESCAPE_RE.sub("", text))


def horiz_concat(s0, s1, padding=1):
    """Place two multi-line strings side by side.

    Every row of ``s0`` is padded with spaces to the visible width of its
    widest row plus ``padding``, then the matching row of ``s1`` is appended.
    If one input has fewer rows than the other, it is extended with blank
    rows, so no content is dropped.

    Args:
        s0: left column, rows separated by "\\n".
        s1: right column, rows separated by "\\n".
        padding: number of spaces between the widest left row and the right
            column.

    Returns:
        The combined rows joined with "\\n".
    """
    left = s0.split("\n")
    right = s1.split("\n")
    # Measure visible width so coloured rows don't push the right column out.
    width = max(_visible_len(row) for row in left) + padding
    rows = max(len(left), len(right))
    left += [""] * (rows - len(left))
    right += [""] * (rows - len(right))
    return "\n".join(
        l + " " * (width - _visible_len(l)) + r for l, r in zip(left, right)
    )
def _loss_figure(x, y):
    """Build a fresh 60x10 plotille line chart of ``y`` against ``x``."""
    fig = plotille.Figure()
    fig.width = 60
    fig.height = 10
    fig.x_label = "iters"
    fig.set_x_limits(min_=float(np.min(x)), max_=float(np.max(x)))
    fig.set_y_limits(min_=float(np.min(y)), max_=float(np.max(y)))
    fig.plot(x, y, lc="green")
    return fig


def train(iters, rp, delay=0.2):
    """Run a simulated training loop that redraws a live dashboard each step.

    Each iteration fabricates a decaying "loss" curve and two batches of
    normally distributed "weights". It then prints the current loss, a line
    chart of the curve and side-by-side weight histograms, and finally calls
    ``rp.flush()`` so the frame is redrawn in place.

    Args:
        iters: number of iterations (frames) to run.
        rp: object whose ``flush()`` is called once per frame. Presumably an
            entered ``Reprint`` so that print() output is buffered; verify
            against the caller.
        delay: seconds to sleep between frames.
    """
    for i in range(iters):
        # simulated training data
        x = np.arange(i + 5)
        y = 10 / (np.log(x + 1) + 1)
        weights0 = np.random.normal(size=1000)
        weights1 = np.random.normal(size=1000)

        print(f"loss: {np.min(y)}\n")
        # A new figure per frame: reusing one Figure and calling plot() each
        # iteration would stack every past series and re-render them all.
        print(_loss_figure(x, y).show())

        print("\nweight distributions:\n")
        w0 = plotille.histogram(weights0, height=10, bins=10, width=20)
        w1 = plotille.histogram(weights1, height=10, bins=10, width=20)
        print(horiz_concat(w0, w1))

        rp.flush()
        time.sleep(delay)
if __name__ == "__main__":
    # Demo entry point: 30 simulated training frames drawn in place.
    with Reprint() as screen:
        try:
            train(iters=30, rp=screen)
        except KeyboardInterrupt:
            # Ctrl-C just ends the demo quietly. Reprint.__exit__ still runs
            # afterwards and restores print() and the terminal cursor.
            pass
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.