Last active: May 10, 2024 09:12
Save shink/ff8e666f17dd6f7f115cae2fae8e075b to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
if __name__ == '__main__':
    # Reproduction script adapted from the pykan "hello KAN" example.
    # It fits a KAN to f(x, y) = exp(sin(pi*x) + y^2), prunes it, and replaces
    # the learned activations with symbolic functions. It then fine-tunes the
    # network and prints the recovered formula.
    #
    # NOTE(review): this gist was filed because the post-symbolic
    # `model.train(...)` call (step 6 below) raised an error. The exact
    # exception and the pykan/torch versions are not recorded here; capture
    # them when re-running.
    import torch
    from kan import KAN, create_dataset

    # 1. Build a KAN with 2 inputs, 5 hidden neurons and 1 output.
    #    It uses cubic splines (k=3) on 5 grid intervals (grid=5).
    #    seed=0 fixes the initialization.
    model = KAN(width=[2, 5, 1], grid=5, k=3, device='cpu', seed=0)

    # 2. Target function f(x, y) = exp(sin(pi * x) + y^2).
    #    x is presumably an (N, 2) tensor (create_dataset is called with
    #    n_var=2). Indexing with [[0]] / [[1]] keeps the column dimension,
    #    so the result is (N, 1).
    def f(x):
        return torch.exp(torch.sin(torch.pi * x[:, [0]]) + x[:, [1]] ** 2)

    dataset = create_dataset(f, n_var=2)
    print(dataset['train_input'].shape)
    print(dataset['train_label'].shape)

    # 3. Run one forward pass before plotting the freshly initialized KAN.
    model(dataset['train_input'])
    model.plot(beta=100)

    # 4. Train with the regularization terms lamb / lamb_entropy, then plot.
    model.train(dataset, opt="LBFGS", steps=20, lamb=0.01, lamb_entropy=10.)
    model.plot()

    # The first prune() discards its return value. It is presumably called
    # only for its effect on the mask that plot(mask=True) draws; confirm
    # against pykan. The second call keeps the pruned model it returns.
    model.prune()
    model.plot(mask=True)
    model = model.prune()
    model(dataset['train_input'])
    model.plot()
    model.train(dataset, opt="LBFGS", steps=50)

    # 5. Replace learned activations with symbolic functions.
    mode = "auto"  # or "manual"
    if mode == "manual":
        # Pin each activation by hand. The arguments are presumably
        # (layer, input neuron, output neuron, function name).
        model.fix_symbolic(0, 0, 0, 'sin')
        model.fix_symbolic(0, 1, 0, 'x^2')
        model.fix_symbolic(1, 0, 0, 'exp')
    elif mode == "auto":
        # Let pykan choose each activation from this candidate library.
        lib = ['x', 'x^2', 'x^3', 'x^4', 'exp', 'log', 'sqrt', 'tanh', 'sin', 'abs']
        model.auto_symbolic(lib=lib)
    else:
        # Previously an unknown mode silently skipped the symbolic step.
        raise ValueError(f"unknown mode {mode!r}; expected 'manual' or 'auto'")

    # 6. Fine-tune the symbolic network. This is the call that reported the
    #    error in the original gist.
    model.train(dataset, opt="LBFGS", steps=50)

    # 7. In a notebook, the bare expression was displayed implicitly.
    #    A script discards it, so print the formula explicitly.
    print(model.symbolic_formula()[0][0])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.