Last active
November 30, 2023 11:04
-
-
Save proger/7782d1baa2e787e7398b81d0ebb1461f to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| "this module sets phasors to stun using gradient descent" | |
import math

import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize
# One full turn in radians. Taken from the standard library rather than the
# truncated literal 6.28, so the quarter-turn target below is exact to float
# precision.
tau = math.tau
# Target angle for the optimisation: a quarter of a turn.
stun = 0.25 * tau
class Cyclic(nn.Module):
    """Wrap a value onto one turn by reducing it modulo 1.

    Written as a module so it can be handed to
    ``parametrize.register_parametrization`` as the mapping applied to a raw
    parameter each time the parametrized attribute is read.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` reduced modulo 1."""
        wrapped = x % 1
        return wrapped
class Phasor(nn.Module):
    """A single learnable angle, stored as a fraction of a full turn."""

    def __init__(self):
        super().__init__()
        # Start close to zero from the other side: the raw value 1.9 is
        # exposed as roughly 0.9 once ``Cyclic`` reduces it modulo 1.
        initial_phase = torch.tensor([1.9])
        self.phase = nn.Parameter(initial_phase)
        # From here on every read of ``self.phase`` goes through ``Cyclic``,
        # which keeps the exposed angle within bounds. The raw tensor itself
        # is not clamped; it stays reachable as
        # ``self.parametrizations.phase.original``.
        parametrize.register_parametrization(self, "phase", Cyclic())

    def forward(self):
        """Return the wrapped phase scaled by ``tau``."""
        return tau * self.phase
# Demo: ten plain gradient-descent steps pulling the phasor output toward
# ``stun`` under a squared-error loss.
p = Phasor()
print(p)

# ``p.phase`` is recomputed through the parametrization on every access, so
# the tensor that actually receives gradients and must be updated is the raw
# parameter stored behind it. It is the same object for the whole run.
param = p.parametrizations.phase.original

for i in range(10):
    phasor = p()
    loss = (phasor - stun) ** 2
    print(i, f'output {phasor}', f'loss {loss.item():.3f}', 'weight', p.phase.detach())
    loss.backward()
    # Manual SGD step. The update runs under ``no_grad`` so autograd does not
    # record it, instead of reaching around autograd through ``.data``.
    with torch.no_grad():
        param -= 0.01 * param.grad
    # Drop the gradient so the next backward pass starts from scratch rather
    # than accumulating.
    param.grad = None
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment