A simple Neural ODE for learning bioprocess dilution, growth and production rates. Generated by Claude 3.5 Sonnet.
""" | |
I'm building a simple hybrid model to predict growth rate and production rate for a bioprocess. | |
Can you give me the most simplest form of Neural ODE in JAX. | |
My ODE im trying to solve is a growth rate and production rate in context of bioprocess run data. | |
Here are my model equations: | |
I want to first solve the ODE | |
D = dV/dt = f(t) + b | |
And then use that to solve the below system of ODEs | |
dX/dt = μX - DX | |
dP/dt = pX - DP | |
In the end, I want to learn μ, and P. | |
f(t) is the feed rate programmed by the bioreactor operator. | |
Therefore it can be looked up in both train and test data. | |
Give me simple train() and validate() functions as well. | |
""" | |
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint
import optax
# Core ODE functions
def dilution_rate(t, feed_rate, b):
    """Compute dilution rate from feed rate and base parameter."""
    return feed_rate(t) + b
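
# --- Illustrative only: a hypothetical feed-rate profile ---
# Any callable t -> feed rate works here. The batch-then-exponential-ramp
# shape below is a common fed-batch pattern; the name and constants are
# assumptions for this sketch, not values from real run data.
def example_feed_rate(t):
    batch_end = 10.0  # assumed duration of the initial batch phase (h)
    ramp = 0.05 * jnp.exp(0.02 * (t - batch_end))  # assumed exponential feed
    return jnp.where(t < batch_end, 0.0, ramp)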
def bioprocess_dynamics(state, t, params, feed_rate):
    """Pure function for bioprocess dynamics."""
    X, P = state
    mu, p, b = params
    D = dilution_rate(t, feed_rate, b)
    dX_dt = mu * X - D * X
    dP_dt = p * X - D * P
    return jnp.array([dX_dt, dP_dt])
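
# Quick sanity check (illustrative numbers, not from real data): with
# state = [X, P] = [1.0, 0.0], params = (mu, p, b) = (0.1, 0.05, 0.01) and
# zero feed, D reduces to b, so the derivatives are
#   dX/dt = mu*X - b*X = 0.09 and dP/dt = p*X - b*P = 0.05:
# bioprocess_dynamics(jnp.array([1.0, 0.0]), 0.0,
#                     jnp.array([0.1, 0.05, 0.01]), lambda t: 0.0)
# -> [0.09, 0.05]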
# Model prediction
def predict(params, times, feed_rate, X0, P0):
    """Forward pass: integrate the ODE system over `times`."""
    initial_state = jnp.array([X0, P0])

    # Close over feed_rate rather than passing it through odeint's *args:
    # odeint's extra args must be arrays or containers of arrays (it
    # differentiates through them), so a Python callable cannot go there.
    def dynamics(state, t, params):
        return bioprocess_dynamics(state, t, params, feed_rate)

    trajectory = odeint(dynamics, initial_state, times, params)
    return trajectory
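
# Note: `trajectory` has shape (len(times), 2); trajectory[:, 0] is X(t) and
# trajectory[:, 1] is P(t), which is what the transpose in loss_fn unpacks.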
# Loss function
def loss_fn(params, batch):
    """Compute MSE loss between predictions and true values."""
    times, feed_rate, X0, P0, X_true, P_true = batch
    predictions = predict(params, times, feed_rate, X0, P0)
    X_pred, P_pred = predictions.T
    return jnp.mean((X_pred - X_true) ** 2 + (P_pred - P_true) ** 2)
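
# Design note: X and P errors are weighted equally above; if biomass and
# product live on very different scales, normalizing each squared-error term
# (e.g. by the variance of the measurements) is a common refinement.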
# Training
def train(params, train_data, learning_rate=1e-3, epochs=100):
    """Train the model using gradient descent (Adam)."""
    optimizer = optax.adam(learning_rate)
    opt_state = optimizer.init(params)
    grad_fn = jax.grad(loss_fn)

    def step(params, opt_state, batch):
        grads = grad_fn(params, batch)
        updates, opt_state = optimizer.update(grads, opt_state)
        params = optax.apply_updates(params, updates)
        return params, opt_state

    for epoch in range(epochs):
        for batch in train_data:
            params, opt_state = step(params, opt_state, batch)
        if epoch % 10 == 0:
            # Report the loss on the most recent batch.
            loss = loss_fn(params, batch)
            print(f"Epoch {epoch}, Loss: {loss:.6f}")
    return params
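
# --- Optional: a jitted training step (sketch) ---
# jax.jit cannot trace the Python callable stored inside `batch`, so this
# variant splits out the arrays and closes over a fixed feed_rate instead.
# `make_jitted_step` is a name introduced here for illustration.
def make_jitted_step(optimizer, feed_rate):
    @jax.jit
    def step(params, opt_state, times, X0, P0, X_true, P_true):
        batch = (times, feed_rate, X0, P0, X_true, P_true)
        grads = jax.grad(loss_fn)(params, batch)
        updates, opt_state = optimizer.update(grads, opt_state)
        params = optax.apply_updates(params, updates)
        return params, opt_state
    return step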
# Validation
def validate(params, val_data):
    """Return the mean loss over the validation batches."""
    total_loss = 0.0
    for batch in val_data:
        total_loss += loss_fn(params, batch)
    return total_loss / len(val_data)
# Example usage:
"""
# Initialize parameters (mu, p, b)
params = jnp.array([0.1, 0.05, 0.01])

# Create data batches; feed_rate is any callable of t
def feed_rate(t):
    return some_function_of_t

train_data = [
    (times, feed_rate, X0, P0, X_true, P_true),
    # ... more batches
]
val_data = [
    (times, feed_rate, X0, P0, X_true, P_true),
    # ... more batches
]

# Train
trained_params = train(params, train_data)

# Validate
val_loss = validate(trained_params, val_data)
print(f"Validation loss: {val_loss:.6f}")
"""