Skip to content

Instantly share code, notes, and snippets.

@jayhuang75
Created December 23, 2021 16:24
Show Gist options
  • Select an option

  • Save jayhuang75/b232bd4ac9ff66099d3e21a744e2a37b to your computer and use it in GitHub Desktop.

Select an option

Save jayhuang75/b232bd4ac9ff66099d3e21a744e2a37b to your computer and use it in GitHub Desktop.
go-train-delay-ml-pipeline-api
/// Location of the bincode-serialized linear-regression model.
///
/// The path is relative, so it is resolved against the process's current
/// working directory, not the binary's location.
const LR_MODEL_PATH: &str = "./models/go_train_delay_lr_2021-12-22UTC.model";

/// Predicts the delay (in minutes) for a single GO train trip.
///
/// The JSON request body is deserialized into [`PredictData`]. Its string
/// fields are packed into a one-row polars `DataFrame`, converted to a
/// smartcore matrix by `feature_to_matrix`, and fed to the linear-regression
/// model stored at [`LR_MODEL_PATH`]. The first prediction is returned as
/// `PredictResp { delay_mins }`.
///
/// # Errors
///
/// Errors from building the `DataFrame` or from `feature_to_matrix` are
/// propagated with `?` as [`AppError`]. The following failures are answered
/// with `500 Internal Server Error` and an [`AppError`] JSON body instead of
/// panicking:
/// - the model cannot be loaded;
/// - the prediction fails;
/// - the model returns no prediction.
#[post("/predict")]
pub async fn predict(predict: web::Json<PredictData>) -> Result<HttpResponse, AppError> {
    // NOTE(review): the model is re-read from disk, with blocking I/O, on every
    // request. Loading it once at startup into shared app state (e.g. a
    // `web::Data<AppState>` extractor) would remove both the latency and the
    // blocking call from the async executor.
    let lr_model = match load_lr_model(LR_MODEL_PATH) {
        Ok(model) => model,
        Err(message) => return Ok(predict_error_response(message)),
    };

    // Build a one-row frame from the request fields.
    let df = df! {
        "date" => &[predict.date.as_str()],
        "departure" => &[predict.departure.as_str()],
        "depart_scheduled" => &[predict.depart_scheduled.as_str()],
        "destination" => &[predict.destination.as_str()],
        "arrival_scheduled" => &[predict.arrival_scheduled.as_str()],
    }?;

    // Convert the features into a smartcore dense matrix.
    let x = feature_to_matrix(&df).await?;

    let delay = match lr_model.predict(&x) {
        // A single input row is expected to yield a single prediction; `first`
        // turns an unexpectedly empty result into an error response rather
        // than an index-out-of-bounds panic.
        Ok(predictions) => match predictions.first() {
            Some(&value) => value,
            None => {
                return Ok(predict_error_response(
                    "model returned no prediction".to_string(),
                ))
            }
        },
        Err(err) => return Ok(predict_error_response(err.to_string())),
    };

    Ok(HttpResponse::Ok().json(PredictResp { delay_mins: delay }))
}

/// Reads the bincode-serialized model at `path` and deserializes it.
///
/// Returns a human-readable message when the file cannot be read or its
/// bytes cannot be deserialized, so the caller can report the failure
/// instead of panicking inside a request handler.
fn load_lr_model(path: &str) -> Result<LinearRegression<f64, DenseMatrix<f64>>, String> {
    let bytes =
        std::fs::read(path).map_err(|err| format!("Can not load model {}: {}", path, err))?;
    bincode::deserialize(&bytes)
        .map_err(|err| format!("Can not deserialize the model: {}", err))
}

/// Builds a `500 Internal Server Error` response whose JSON body is an
/// [`AppError`] carrying `message` as both its message and its cause.
fn predict_error_response(message: String) -> HttpResponse {
    HttpResponse::InternalServerError().json(AppError {
        message: Some(message.clone()),
        cause: Some(message),
        // NOTE(review): `SerdeJsonError` is the only `AppErrorType` variant
        // visible here and is what the original handler used for prediction
        // failures; a dedicated variant for model/prediction errors would
        // describe these failures more accurately.
        error_type: AppErrorType::SerdeJsonError,
    })
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment