Created
November 8, 2023 17:31
-
-
Save proger/a5b4a8ba6a555ff68d7439cd6072869d to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| "accumulate repeating characters: convolution frontend, RNN backend" | |
| import collections | |
| from itertools import islice | |
def conv(iterable, n=2):
    """Slide a width-n window over iterable, yielding each full window as a tuple.

    Nothing is yielded until n items have been seen, so an input shorter
    than n produces no output. Consecutive windows overlap by n-1 items.
    """
    stream = iter(iterable)
    # Pre-fill with the first n-1 items; every later item completes a window.
    prefix = islice(stream, n - 1)
    window = collections.deque(maxlen=n)
    window.extend(prefix)
    for item in stream:
        window.append(item)  # maxlen=n drops the oldest item automatically
        yield tuple(window)
def forget(h, x):
    """Data-dependent forget gate: keep memory h while the pair repeats, else reset.

    x is a pair (previous, current). If the two differ, the memory is wiped
    to the empty string; otherwise h is passed through unchanged.
    """
    previous, current = x
    if previous != current:
        return ''
    return h
def input_projection(x):
    """Project an input pair (previous, current) to the current element."""
    _, current = x
    return current
def output_projection(x):
    """Read out the hidden state unchanged (identity projection)."""
    readout = x
    return readout
def rnn(forget, input, output, xs, h0=''):
    """Run a minimal gated recurrence over xs, yielding output(h) after every step.

    At each step the new state is ``forget(h, x) + input(x)``: the gate
    decides how much of the previous state survives, and the projected input
    is appended to it.

    Args:
        forget: callable ``(h, x) -> state`` returning the retained state.
        input: callable ``x -> state`` projecting the current input. (The name
            shadows the builtin but is kept for keyword-argument compatibility.)
        output: callable ``h -> value`` applied to the state after each step.
        xs: iterable of inputs.
        h0: initial state. Defaults to ``''``, matching the original
            string-accumulating behavior. Pass e.g. ``0`` or ``[]`` for other
            state types that support ``+``.

    Yields:
        ``output(h)`` once per element of xs.
    """
    h = h0
    for x in xs:
        h = forget(h, x) + input(x)
        yield output(h)
# Demo: accumulate runs of repeated characters.
# conv() emits (previous, current) pairs, so the very first real character
# only appears as `current` if something precedes it -- hence the leading
# '_' sentinel, which differs from '0' and therefore starts with a reset.
# Local names avoid shadowing the builtin `input`.
text = '_01112233445678'
pairs = conv(text)
for line in rnn(forget, input_projection, output_projection, pairs):
    print(line)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment