@rejuvyesh
Created April 14, 2022 20:45
# Install tinycudann via:
# pip install git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch
import time

import torch
import tinycudann as tcnn


class TCNNMatrixExponentEstimator1(torch.nn.Module):
    """Small fully fused MLP (4 -> hidden -> 4) approximating the 2x2 matrix exponential."""

    def __init__(self, hidden=16) -> None:
        super().__init__()
        config = {
            "network": {
                "otype": "FullyFusedMLP",
                "activation": "ReLU",
                "output_activation": "None",
                "n_neurons": hidden,
                "n_hidden_layers": 1,
            }
        }
        self.model = tcnn.Network(4, 4, config["network"])

    def forward(self, x):
        return self.model(x)


def f(x):
    # Ground truth: exp() of a single flattened 2x2 matrix, returned flattened.
    return torch.matrix_exp(x.reshape((2, 2))).reshape((4,))


def apply_matrix_exponential(x):
    # Apply f row-wise over a batch of flattened 2x2 matrices.
    return torch.stack([f(x_i) for x_i in torch.unbind(x)])


def train():
    device = torch.device("cuda")
    dtype = torch.float16
    epochs = 10000

    # Generate the targets in float32, then cast inputs and targets to half
    # precision to match the fully fused MLP's native dtype.
    trainx = torch.randn(10000, 2 * 2).to(device)
    trainy = apply_matrix_exponential(trainx)
    testx = torch.randn(10000, 2 * 2).to(device)
    testy = apply_matrix_exponential(testx)
    trainx = trainx.to(dtype=dtype)
    trainy = trainy.to(dtype=dtype)
    testx = testx.to(dtype=dtype)
    testy = testy.to(dtype=dtype)

    model = TCNNMatrixExponentEstimator1().to(device, dtype=torch.float32)
    adam = torch.optim.Adam(model.parameters(), lr=1e-3)
    loss_fn = torch.nn.MSELoss()

    print('Initial Train Loss: {:.4f}'.format(loss_fn(model(trainx), trainy)))
    print('Initial Test Loss: {:.4f}'.format(loss_fn(model(testx), testy)))

    # Three timed rounds of full-batch training.
    for _ in range(3):
        t_start = time.time()
        for _ in range(epochs):
            adam.zero_grad()
            loss_fn(model(trainx), trainy).backward()
            adam.step()
        print('Took: {:.2f} seconds'.format(time.time() - t_start))
        print('Train Loss: {:.4f}'.format(loss_fn(model(trainx), trainy)))
        print('Test Loss: {:.4f}'.format(loss_fn(model(testx), testy)))


if __name__ == '__main__':
    train()
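
For comparison, the tcnn network above is architecturally a 4 -> 16 -> 4 ReLU MLP. Below is a minimal sketch, not part of the original gist, of an equivalent plain-PyTorch baseline that could be swapped in for TCNNMatrixExponentEstimator1 to get a rough speed and accuracy reference. Note that FullyFusedMLP uses its own fused half-precision kernels internally, so the two will not match exactly.

import torch


# Hypothetical plain-PyTorch baseline with the same layer sizes as the tcnn
# config (4 inputs, one hidden layer of 16 ReLU units, 4 linear outputs).
class TorchMatrixExponentEstimator(torch.nn.Module):
    def __init__(self, hidden=16) -> None:
        super().__init__()
        self.model = torch.nn.Sequential(
            torch.nn.Linear(4, hidden),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden, 4),
        )

    def forward(self, x):
        return self.model(x)


# Usage sketch: in train(), either keep the data in float32 or call .half(),
# and replace
#     model = TCNNMatrixExponentEstimator1().to(device, dtype=torch.float32)
# with
#     model = TorchMatrixExponentEstimator().to(device)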