@jargnar
Last active November 16, 2024 05:16
A simple Neural ODE for learning bioprocess dilution, growth and production rates. Generated by OpenAI o1.
"""
I'm building a simple hybrid model to predict
growth rate and production rate for a bioprocess.
My Neural ODE solves for growth rate and
production rate in context of bioprocess run data.
Here are my model equations:
First
D = dV/dt = f(t) + b
And then
dX/dt = μX - DX
dP/dt = pX - DP
In the end, I want to learn μ and p.
f(t) is the feed rate programmed
by the bioreactor operator.
Therefore it can be looked up
in both train and test data.
"""
import torch
from torchdiffeq import odeint
# Constants and initial conditions
b = 0.1
y0 = torch.tensor([1.0, 0.0, 1.0]) # Initial values: X0, P0, V0
t = torch.linspace(0., 10., 100) # Time span: t0 to t1 with 100 points
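# Feed rate function f(t): linear interpolation of the programmed feed profile.
# PyTorch has no torch.interp, so the lookup is written out by hand.
f_t_values = torch.linspace(0.5, 1.0, len(t)) # Example feed rate data
def f(tq):
    tq = torch.clamp(tq, min=float(t[0]), max=float(t[-1])) # Stay inside the tabulated range
    i = torch.clamp(torch.searchsorted(t, tq.reshape(1))[0] - 1, 0, len(t) - 2)
    w = (tq - t[i]) / (t[i + 1] - t[i]) # Fractional position between grid points
    return f_t_values[i] + w * (f_t_values[i + 1] - f_t_values[i])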
# Neural network to learn μ and p
net = torch.nn.Sequential(
    torch.nn.Linear(3, 16),
    torch.nn.Tanh(),
    torch.nn.Linear(16, 2)
)
# ODE function mirroring your equations
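def ode(t, y):
    X, P, V = y
    # Dilution rate: volumetric inflow (feed + b) divided by the current volume
    D = (f(t) + b) / V
    # The network maps the state (X, P, V) to the two rates mu and p
    mu, p = net(torch.stack([X, P, V]))
    dXdt = mu * X - D * X
    dPdt = p * X - D * P
    dVdt = f(t) + b
    return torch.stack([dXdt, dPdt, dVdt])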
# Observed data for training (replace with your actual data)
observed_t = t
observed_y = torch.randn(len(t), 3) # Shape: [time_steps, variables]
# Loss function
def loss_fn(pred_y, true_y):
    return torch.mean((pred_y - true_y) ** 2)
# Optimizer
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
# Training loop
for epoch in range(500):
    optimizer.zero_grad()
    pred_y = odeint(ode, y0, observed_t)
    loss = loss_fn(pred_y, observed_y)
    loss.backward()
    optimizer.step()
    if epoch % 50 == 0:
        print(f'Epoch {epoch}, Loss: {loss.item():.4f}')
# Validation (replace with your actual validation data)
with torch.no_grad():
    pred_y_val = odeint(ode, y0, observed_t)
    val_loss = loss_fn(pred_y_val, observed_y)
    print(f'Validation Loss: {val_loss.item():.4f}')
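
A quick way to sanity-check the balance equations before training anything is to integrate them with fixed rates. The following is a minimal pure-Python sketch, not part of the original gist: `simulate` and the linear `feed` ramp are hypothetical names and values chosen for illustration, and it assumes the dilution rate used in `ode()` above, D = (f(t) + b) / V. With μ = p = 0 nothing grows or is produced, so the total biomass X·V should stay at its initial value while the volume increases.

```python
def simulate(mu, p, feed, b=0.1, y0=(1.0, 0.0, 1.0), t_end=10.0, steps=10000):
    """Forward-Euler integration of the X, P, V balances with constant mu and p."""
    X, P, V = y0
    dt = t_end / steps
    for i in range(steps):
        t = i * dt
        inflow = feed(t) + b          # dV/dt
        D = inflow / V                # dilution rate, as in ode() above
        dX = mu * X - D * X
        dP = p * X - D * P
        X, P, V = X + dt * dX, P + dt * dP, V + dt * inflow
    return X, P, V

# Hypothetical feed profile: linear ramp from 0.5 to 1.0 over 10 time units.
def feed(t):
    return 0.5 + 0.05 * t

# No growth, no production: biomass is only diluted.
X, P, V = simulate(mu=0.0, p=0.0, feed=feed)
print(f"X*V = {X * V:.3f}, P = {P:.3f}, V = {V:.3f}")
```

If X·V drifts noticeably from 1.0 here, the dilution term is probably inconsistent with the volume equation, which is worth ruling out before blaming the network.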