Created June 26, 2020 22:58
-
-
Save logancyang/20ac537b9a6de84ffa040748f82c2101 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class Tensor:
    """A value node in a reverse-mode automatic-differentiation graph.

    Each arithmetic operator returns a new ``Tensor`` that records the
    operand nodes in ``_prev``, a label in ``_op``, and a ``_backward``
    closure that pushes ``out.grad`` into the operands' ``grad`` using the
    chain rule. Gradients are accumulated with ``+=``, so a node used in
    several places receives the sum of all its contributions.

    NOTE(review): nothing in this block calls the ``_backward`` closures in
    order; presumably a driver defined elsewhere does that — confirm.
    """

    def __init__(self, data, _parents=(), _op=''):
        """Create a node holding ``data``.

        Args:
            data: The forward value. It must support the arithmetic used by
                the operators below (``+``, ``*``, ``**``).
            _parents: Nodes that produced this one (internal use).
            _op: Label of the operation that produced this node (internal use).
        """
        self.data = data
        # Accumulated gradient; incremented by the _backward closures of
        # nodes computed from this one.
        self.grad = 0
        # Leaf nodes have nothing to propagate to.
        self._backward = lambda: None
        # A set that contains previous (parent) nodes that produced
        # the current node with operation _op
        self._prev = set(_parents)
        self._op = _op

    @staticmethod
    def _as_tensor(value):
        """Return ``value`` unchanged if it is a Tensor, else wrap it as a leaf."""
        return value if isinstance(value, Tensor) else Tensor(value)

    def __add__(self, other):
        """Return ``self + other``; a non-Tensor ``other`` is wrapped as a leaf."""
        other = self._as_tensor(other)
        out = Tensor(self.data + other.data, (self, other), '+')

        def _backward():
            # d(a + b)/da = d(a + b)/db = 1, so both parents get out.grad.
            self.grad += out.grad
            other.grad += out.grad

        out._backward = _backward
        return out

    def __mul__(self, other):
        """Return ``self * other``; a non-Tensor ``other`` is wrapped as a leaf."""
        other = self._as_tensor(other)
        out = Tensor(self.data * other.data, (self, other), '*')

        def _backward():
            # Product rule: each factor's gradient is the other factor's value.
            self.grad += other.data * out.grad
            other.grad += self.data * out.grad

        out._backward = _backward
        return out

    def __pow__(self, other):
        """Return ``self ** other`` for a constant (non-Tensor) exponent.

        Raises:
            TypeError: If ``other`` is a Tensor. The backward pass below only
                differentiates with respect to the base, so a Tensor exponent
                cannot be supported.
        """
        if isinstance(other, Tensor):
            raise TypeError(
                'Tensor ** Tensor is not supported; the exponent must be a constant'
            )
        out = Tensor(self.data**other, (self,), f'**{other}')

        def _backward():
            # Power rule: d(x**n)/dx = n * x**(n - 1).
            self.grad += (other * self.data**(other-1)) * out.grad

        out._backward = _backward
        return out

    def __neg__(self):
        """Return ``-self``, recorded in the graph as ``self * -1``."""
        # Relies on __mul__ wrapping the plain int -1 as a leaf Tensor.
        return self * (-1)

    def __sub__(self, other):
        """Return ``self - other``, recorded as ``self + (-other)``."""
        return self + (-other)
... |
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.