Skip to content

Instantly share code, notes, and snippets.

@jargnar
Last active November 14, 2024 08:50
Show Gist options
  • Save jargnar/864af4a9043d61b7a3bc831b80da2a28 to your computer and use it in GitHub Desktop.
Save jargnar/864af4a9043d61b7a3bc831b80da2a28 to your computer and use it in GitHub Desktop.
A simple Neural ODE for learning bioprocess dilution, growth and production rates. Generated by Claude 3.5 Sonnet.
"""
I'm building a simple hybrid model to predict the growth rate and production rate of a bioprocess.
Can you give me the simplest possible Neural ODE in JAX?
The ODE I'm trying to solve models the growth rate and production rate in the context of bioprocess run data.
Here are my model equations:
I want to first solve the ODE
D = dV/dt = f(t) + b
And then use that to solve the below system of ODEs
dX/dt = μX - DX
dP/dt = pX - DP
In the end, I want to learn μ and p.
f(t) is the feed rate programmed by the bioreactor operator.
Therefore it can be looked up in both train and test data.
Give me simple train() and validate() functions as well.
"""
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint
import optax
# Core ODE functions
def dilution_rate(t, feed_rate, b):
    """Return the dilution rate D(t) = f(t) + b.

    Args:
        t: Time at which the dilution rate is evaluated.
        feed_rate: Callable f(t) giving the operator-programmed feed rate.
        b: Additive offset (one of the learned parameters in this model).

    Returns:
        ``feed_rate(t) + b``.
    """
    # NOTE(review): D is used directly as the dilution term; no volume V is
    # integrated anywhere in this file. If D should be F(t)/V(t), a volume
    # state would be required — confirm the intended model.
    feed = feed_rate(t)
    return feed + b
def bioprocess_dynamics(state, t, params, feed_rate):
    """Right-hand side of the biomass/product ODE system.

    Computes::

        dX/dt = mu * X - D * X
        dP/dt = p  * X - D * P

    where ``D = dilution_rate(t, feed_rate, b)``.

    Args:
        state: Current state, unpacked as ``(X, P)``.
        t: Current time.
        params: Parameters, unpacked as ``(mu, p, b)``.
        feed_rate: Callable f(t), forwarded to ``dilution_rate``.

    Returns:
        ``jnp.array([dX/dt, dP/dt])``.
    """
    biomass, product = state
    growth, production, offset = params
    dilution = dilution_rate(t, feed_rate, offset)
    return jnp.array([
        growth * biomass - dilution * biomass,
        production * biomass - dilution * product,
    ])
# Model prediction
def predict(params, times, feed_rate, X0, P0):
    """Integrate the bioprocess ODEs and return the state trajectory.

    Args:
        params: Parameter vector, unpacked by ``bioprocess_dynamics`` as
            ``(mu, p, b)``.
        times: Strictly increasing float time points at which to report the
            state. ``times[0]`` is the time of the initial condition.
        feed_rate: Callable f(t). It is evaluated on traced values inside
            ``odeint``, so it must be written with JAX operations (for example
            ``jnp.interp`` over logged feed data). Python branching on ``t``
            will fail under tracing.
        X0: Initial value of X.
        P0: Initial value of P.

    Returns:
        Array of shape ``(len(times), 2)``: one row per time point, with
        column 0 holding X and column 1 holding P.
    """
    initial_state = jnp.array([X0, P0])
    # Accept plain Python sequences as well as arrays; odeint reads t.dtype.
    times = jnp.asarray(times)

    # odeint only accepts arrays / pytrees of arrays as extra *args and raises
    # TypeError for a Python callable. So feed_rate is captured by closure
    # rather than forwarded; params stays in *args so gradients flow to it.
    def dynamics(state, t, params):
        return bioprocess_dynamics(state, t, params, feed_rate)

    return odeint(dynamics, initial_state, times, params)
# Loss function
def loss_fn(params, batch):
    """Mean squared error of the predicted X and P trajectories for one batch.

    Args:
        params: Parameter vector passed to ``predict``.
        batch: Tuple ``(times, feed_rate, X0, P0, X_true, P_true)``.

    Returns:
        Scalar mean, over time points, of ``(X error)**2 + (P error)**2``.
        The two error terms are added unweighted, so whichever quantity has
        the larger numeric scale dominates the loss.
    """
    times, feed_rate, X0, P0, X_true, P_true = batch
    trajectory = predict(params, times, feed_rate, X0, P0)
    # One row per time point; column 0 is X, column 1 is P.
    x_err = trajectory[:, 0] - X_true
    p_err = trajectory[:, 1] - P_true
    return jnp.mean(x_err**2 + p_err**2)
# Training
def train(params, train_data, learning_rate=1e-3, epochs=100, log_every=10):
    """Fit ``params`` to ``train_data`` with Adam, one update per batch.

    Args:
        params: Initial parameter vector ``(mu, p, b)``.
        train_data: Iterable of batches
            ``(times, feed_rate, X0, P0, X_true, P_true)``. It is materialized
            into a list once, so a one-shot iterator or generator is reused for
            every epoch instead of being exhausted after the first.
        learning_rate: Adam step size.
        epochs: Number of full passes over ``train_data``.
        log_every: Print the epoch's mean loss every ``log_every`` epochs,
            counting from epoch 0. Use 0 or a negative value to disable.

    Returns:
        The parameters after the final epoch.

    Raises:
        ValueError: If ``train_data`` contains no batches.
    """
    batches = list(train_data)
    if not batches:
        raise ValueError("train_data must contain at least one batch")

    optimizer = optax.adam(learning_rate)
    opt_state = optimizer.init(params)
    # value_and_grad returns the loss from the same forward pass used for the
    # gradient, so logging does not cost an extra ODE solve.
    loss_and_grad = jax.value_and_grad(loss_fn)

    def step(params, opt_state, batch):
        loss, grads = loss_and_grad(params, batch)
        updates, opt_state = optimizer.update(grads, opt_state)
        params = optax.apply_updates(params, updates)
        return params, opt_state, loss

    for epoch in range(epochs):
        epoch_loss = 0.0
        for batch in batches:
            params, opt_state, loss = step(params, opt_state, batch)
            epoch_loss += loss
        if log_every > 0 and epoch % log_every == 0:
            # Mean of the pre-update batch losses seen during this epoch,
            # rather than a single batch's loss.
            print(f"Epoch {epoch}, Loss: {epoch_loss / len(batches):.6f}")
    return params
# Validation
def validate(params, val_data):
    """Return the mean per-batch loss of ``params`` over ``val_data``.

    Each batch contributes its ``loss_fn`` value with equal weight, regardless
    of how many time points it contains.

    Args:
        params: Parameter vector ``(mu, p, b)``.
        val_data: Iterable of batches
            ``(times, feed_rate, X0, P0, X_true, P_true)``. Generators are
            accepted; they are consumed once.

    Returns:
        The average of ``loss_fn(params, batch)`` over all batches.

    Raises:
        ValueError: If ``val_data`` contains no batches (previously this
            surfaced as a ZeroDivisionError).
    """
    batches = list(val_data)
    if not batches:
        raise ValueError("val_data must contain at least one batch")
    total_loss = sum((loss_fn(params, batch) for batch in batches), 0.0)
    return total_loss / len(batches)
# Example usage:
"""
# Initialize parameters (mu, p, b)
params = jnp.array([0.1, 0.05, 0.01])
# Create data batches
def feed_rate(t):
    # Must be JAX-traceable (t is a tracer inside odeint), e.g. jnp.interp(t, feed_times, feed_values)
    return some_function_of_t
train_data = [
(times, feed_rate, X0, P0, X_true, P_true),
# ... more batches
]
val_data = [
(times, feed_rate, X0, P0, X_true, P_true),
# ... more batches
]
# Train
trained_params = train(params, train_data)
# Validate
val_loss = validate(trained_params, val_data)
print(f"Validation loss: {val_loss:.6f}")
"""
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment