# Source: GitHub Gist S1ro1/0fefc8972ebab03463beccd08ab3c162
# Author: @S1ro1
# Created: October 31, 2024 09:19
import torch
def print_test_end() -> None:
    """Print the dashed separator line that marks the end of one test's output."""
    print("---------------")
def test_vectors_bwd() -> None:
    """Print reference gradients for the dot product of two vectors.

    Builds two 1x3 leaf tensors, takes the dot product of their single rows,
    backpropagates from the scalar result, and prints ``a.grad`` followed by
    ``b.grad``. Since d(a.b)/da = b and d(a.b)/db = a, the printed gradients
    should equal the other operand's values.
    """
    # Banner previously read "BTW"; corrected to match the function name and
    # the "BWD" banners used by the sibling tests.
    print("TEST VECTORS BWD")
    a = torch.tensor([[1.0, -2.0, 3.0]], requires_grad=True)
    b = torch.tensor([[0.5, 0.5, 1.0]], requires_grad=True)
    # torch.dot works on 1-D tensors, so take the single row of each operand.
    c = torch.dot(a[0], b[0])
    c.backward()
    print(a.grad)
    print(b.grad)
    print_test_end()
def test_matrix_vector_bwd() -> None:
    """Print the reference gradient of a matrix-vector product w.r.t. the vector.

    Computes ``c = a @ b.T`` for a 2x2 matrix ``a`` and a 1x2 row vector ``b``,
    backpropagates an all-ones upstream gradient of shape 2x1, and prints
    ``b.grad``. With that upstream gradient, ``b.grad`` is the column sums of
    ``a``.
    """
    print("TEST MATRIX VECTOR BWD")
    # Leaf tensors created with requires_grad=True get .grad populated by
    # backward() on their own; retain_grad() is a no-op for leaves, so the
    # calls that used to be here were dropped.
    a = torch.tensor([[1.0, -1.0], [0.0, 1.0]], requires_grad=True)
    b = torch.tensor([[2.0, 3.0]], requires_grad=True)
    c = a @ b.T
    # c is not a scalar, so backward() needs an explicit upstream gradient
    # with the same 2x1 shape as c.
    upstream = torch.tensor([[1.0, 1.0]]).T
    c.backward(upstream)
    # NOTE(review): a.grad is populated as well but is not printed here,
    # unlike the other tests -- confirm whether that omission is intended.
    print(b.grad)
    print_test_end()
def test_sum_bwd() -> None:
    """Print the reference gradient of ``torch.sum`` over a 1x3 tensor.

    Every element contributes to the sum with weight one, so the printed
    ``a.grad`` should be all ones with the same shape as ``a``.
    """
    print("TEST SUM BWD")
    a = torch.tensor([[1.0, -2.5, 3.0]], requires_grad=True)
    b = torch.sum(a)
    b.backward()
    print(a.grad)
    print_test_end()
def test_relu_sum() -> None:
    """Print the reference gradient of ``sum(relu(a))`` for a 1x3 tensor.

    The input mixes positive and negative entries so the printed ``a.grad``
    shows both ReLU branches: one where the input is positive, zero where it
    is negative.
    """
    print("TEST RELU SUM")
    a = torch.tensor([[1.0, -2.0, 3.0]], requires_grad=True)
    b = torch.relu(a)
    c = torch.sum(b)
    c.backward()
    print(a.grad)
    print_test_end()
def test_add_sum() -> None:
    """Print reference gradients of ``sum(a + b)`` for two 1x3 tensors.

    Addition passes the upstream gradient through unchanged to both operands,
    so the printed ``a.grad`` and ``b.grad`` should both be all ones.
    """
    print("TEST ADD SUM")
    a = torch.tensor([[1.0, -2.0, 3.0]], requires_grad=True)
    b = torch.tensor([[1.0, 2.0, 3.0]], requires_grad=True)
    c = a + b
    d = torch.sum(c)
    d.backward()
    print(a.grad)
    print(b.grad)
    print_test_end()
def main() -> None:
    """Run every gradient check in file order, printing results to stdout."""
    test_vectors_bwd()
    test_matrix_vector_bwd()
    test_sum_bwd()
    test_relu_sum()
    test_add_sum()


if __name__ == "__main__":
    main()
# Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.