Created June 28, 2022 06:46
-
-
Save ytbilly3636/034b4f8bc8f8f2b6a0f72520d9ecf22d to your computer and use it in GitHub Desktop.
Visualize reservoir responses to sound inputs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import copy | |
class Reservoir(object):
    """Fixed random recurrent network ("reservoir") driven by an input signal.

    Each call advances the state one step as
    ``x = tanh(u @ W_in.T + x @ W_res.T)`` and returns a copy of the new state.

    Args:
        i_size: Number of input features per step.
        r_size: Number of reservoir nodes.
        i_coef: Input weights are drawn uniformly from ``[-i_coef, i_coef)``.
        r_coef: Spectral radius (largest absolute eigenvalue) that the random
            recurrent weight matrix is rescaled to.
    """

    def __init__(self, i_size, r_size, i_coef=1.0, r_coef=0.999):
        self._w_i = np.random.uniform(-i_coef, i_coef, (r_size, i_size)).astype(np.float32)
        w_r = np.random.rand(r_size, r_size).astype(np.float32)
        # Rescale so that the spectral radius of the recurrent weights is r_coef.
        spectral_radius = np.max(np.abs(np.linalg.eigvals(w_r)))
        self._w_r = w_r / spectral_radius * r_coef
        # Start from the zero state so the reservoir is usable right away;
        # otherwise __call__ fails with AttributeError until reset() is called.
        self.reset()

    def reset(self):
        """Set the reservoir state to zeros of shape (1, r_size)."""
        self._x = np.zeros((1, self._w_r.shape[0]), dtype=np.float32)

    def __call__(self, u):
        """Advance the reservoir by one step and return the new state.

        Args:
            u: Input array whose last axis has length ``i_size``
                (e.g. shape (1, i_size)).

        Returns:
            A float32 copy of the new state, so callers cannot mutate the
            reservoir's internal state through it.
        """
        self._x = np.tanh(u.dot(self._w_i.T) + self._x.dot(self._w_r.T), dtype=np.float32)
        return self._x.copy()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
'''
Visualize reservoir responses to sound inputs.

Put reservoir.py (which defines Reservoir) and this script in the same
directory, then run this script with Python, e.g.
$ python "reservoir visualization.py"
(quote the file name if it contains a space). Stop with Ctrl-C.
'''
import sys | |
import pyaudio | |
import numpy as np | |
import matplotlib.pyplot as plt | |
from reservoir import Reservoir | |
# audio: mono, 16-bit (paInt16) input stream sampled at 44100 Hz
pa = pyaudio.PyAudio()
stream = pa.open(format=pyaudio.paInt16, channels=1, rate=44100, input=True)
# reservoir (the Reservoir class is imported from reservoir.py)
# r_coef is the spectral radius of the recurrent weights:
#   small r_coef -> the reservoir state converges quickly
#   big r_coef   -> the reservoir state tends to diverge / keep fluctuating
RES_SIZE = 100
res = Reservoir(i_size=1, r_size=RES_SIZE, r_coef=0.8)
res.reset()
# 2d plot (alternative, currently disabled)
# BUF_PLOT = 50
# buffer_2dplot = np.zeros((BUF_PLOT, ), np.float32)

# 3d plot: history of the last BUF_PLOT states of 3 selected reservoir nodes
BUF_PLOT = 5
buffer_3dplot = np.zeros((BUF_PLOT, 3), np.float32)
fig = plt.figure()
# Figure.gca() no longer accepts keyword arguments (removed in Matplotlib 3.6);
# add_subplot(projection='3d') is the supported way to create 3D axes.
ax = fig.add_subplot(projection='3d')
# One color value per buffered point, from oldest (0) to newest (1).
c = np.linspace(0, 1, BUF_PLOT)
# loop
WARMUP_STEPS = 20  # steps used to choose which 3 nodes to visualize
init_count = 0
x_buf = []
try:
    while True:
        # Read 4410 frames (0.1 s at 44100 Hz). Don't raise on input overflow:
        # the plotting below can be slower than real time.
        data = stream.read(4410, exception_on_overflow=False)
        data = np.frombuffer(data, dtype=np.int16)
        # The scaled mean sample value is the single reservoir input.
        mean = np.mean(data) / 1024
        x = res(np.array([[mean]], dtype=np.float32))

        # During the first WARMUP_STEPS steps, collect states; then select the
        # 3 nodes whose variance over that period is largest.
        if init_count < WARMUP_STEPS:
            x_buf.append(x)
            # Must be incremented before `continue`, or warm-up never ends.
            init_count += 1
            continue
        if init_count == WARMUP_STEPS:
            res.reset()
            x_buf = np.concatenate(x_buf, axis=0)
            x_var = np.var(x_buf, axis=0)
            # argsort is ascending; reverse so order[:3] are the most variable nodes.
            order = np.argsort(x_var)[::-1]
            init_count += 1

        # Alternative 2d view (disabled; needs buffer_2dplot):
        # plt.clf()
        # buffer_2dplot[0:-1] = buffer_2dplot[1:]
        # buffer_2dplot[-1] = mean
        # plt.ylim(-1, 1)
        # plt.plot(buffer_2dplot)
        # plt.pause(0.001)

        # Shift the history and append the 3 selected nodes of the newest
        # state, so the latest BUF_PLOT states are shown (colored old -> new).
        buffer_3dplot[:-1] = buffer_3dplot[1:]
        buffer_3dplot[-1] = x[0, order[:3]]
        ax.cla()
        ax.set_xlim(-0.2, 0.2)
        ax.set_ylim(-0.2, 0.2)
        ax.set_zlim(-0.2, 0.2)
        ax.scatter(buffer_3dplot[:, 0], buffer_3dplot[:, 1], buffer_3dplot[:, 2], c=c, cmap='viridis')
        plt.pause(0.001)
except KeyboardInterrupt:
    pass
finally:
    # Release the audio device whether we stop via Ctrl-C or an error.
    stream.stop_stream()
    stream.close()
    pa.terminate()
sys.exit()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.