Created
May 16, 2023 15:59
-
-
Save ha7ilm/d8b946b8acfa0bfb35303e63e043fe7c to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# We only need to normalize the inputs of a neural network, not necessarily the outputs.
# However, the choice of activation function is important: e.g. this example does not work with `jnp.tanh`.
import jax
import jax.numpy as jnp
import plotly.graph_objects as go
from flax import linen as nn
from flax.training import train_state
from jax import grad, jit, vmap
from jax import random
from optax import adam
# Generate the data: 200 evenly spaced angles covering one period of the sine.
x = jnp.linspace(-jnp.pi, jnp.pi, 200).reshape(-1, 1)  # column vector, shape (200, 1)
# Targets are intentionally left un-normalized (amplitude 1000, offset 3000).
y = 1000 * jnp.sin(x) + 3000
# Normalize inputs to zero mean and unit standard deviation. This has to come
# after y is computed, because y is defined on the raw angles.
x = (x - jnp.mean(x)) / jnp.std(x)
# Create the neural network
class SineModel(nn.Module):
    """Small fully connected network: two 32-unit softplus layers and a linear output.

    Softplus is used as the hidden activation; the note at the top of this
    file reports that `jnp.tanh` did not work for this example.
    """

    def setup(self):
        """Declare the two hidden layers and the single-unit output layer."""
        self.layer1 = nn.Dense(features=32)
        self.layer2 = nn.Dense(features=32)
        self.output_layer = nn.Dense(features=1)

    def __call__(self, x):
        """Run the forward pass.

        Args:
            x: Input array; the dense layers act on its last axis.

        Returns:
            The output layer's activations (last axis of size 1). No
            activation is applied to the output, so it is unbounded.
        """
        x = jax.nn.softplus(self.layer1(x))
        x = jax.nn.softplus(self.layer2(x))
        return self.output_layer(x)
# Module definition only; its parameters are created separately via model.init.
model = SineModel()
# Define loss function
def loss_fn(params, batch):
    """Return the mean squared error of the model on one batch.

    Args:
        params: Variables accepted by ``model.apply``.
        batch: ``(inputs, targets)`` pair of arrays.

    Returns:
        Scalar mean of the squared differences between targets and predictions.
    """
    inputs, targets = batch
    preds = model.apply(params, inputs)
    # NOTE: targets should have the same shape as preds; a shape mismatch
    # would be silently broadcast by the subtraction below.
    return jnp.mean((targets - preds) ** 2)
# Initialize model and optimizer
rng = random.PRNGKey(0)  # fixed seed for the parameter initialization
params = model.init(rng, x)  # x is passed as the example input for initialization
optimizer = adam(learning_rate=0.01)
# Define training step
@jit
def train_step(state, batch):
    """Perform one optimizer update and return the resulting state.

    Args:
        state: Training state exposing ``params`` and ``apply_gradients``.
        batch: ``(inputs, targets)`` pair forwarded to ``loss_fn``.

    Returns:
        The value of ``state.apply_gradients(grads=...)``, i.e. the state
        after applying the gradients of ``loss_fn`` with respect to the
        parameters.
    """
    grads = grad(loss_fn)(state.params, batch)
    return state.apply_gradients(grads=grads)
# Prepare the data: the whole dataset is used as a single batch.
data = (x, y)

# Training loop
state = train_state.TrainState.create(apply_fn=model.apply, params=params, tx=optimizer)
for i in range(10000):
    state = train_step(state, data)
    if i % 100 == 0:
        # Loss is evaluated with the parameters produced by this step's update.
        print('Loss at step {}: {}'.format(i, loss_fn(state.params, data)))
# Plot data and model output
predictions = model.apply(state.params, x)
fig = go.Figure()
# Both traces share the normalized x values; ravel() flattens the column vectors.
fig.add_trace(go.Scatter(x=x.ravel(), y=y.ravel(), mode='markers', name='Data'))
fig.add_trace(go.Scatter(x=x.ravel(), y=predictions.ravel(), mode='lines', name='Model Output'))
fig.update_layout(title='Sine Function and Model Output', xaxis_title='X', yaxis_title='Y')
# Write the figure to an HTML file in the current working directory.
fig.write_html('sineout.html')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment