Last active
August 10, 2020 09:26
-
-
Save bubbobne/d92523258b2883cc90229e14ed5009ba to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
library(tidyverse) | |
library(caret) | |
#' Predict labels from an ordinal brms regression model.
#'
#' For each row of `data`, computes the label probabilities with
#' `predict(model, newdata = data, re_formula = NA)` (no group-level
#' effects), keeps the most probable label, joins the observed label taken
#' from `labeled_column`, flags whether the two agree, and prints a caret
#' confusion matrix of predicted vs. observed labels.
#'
#' @param model the fitted model. Assumed to be an ordinal `brmsfit` whose
#'   `predict()` returns a matrix with columns named `"P(Y = <label>)"`
#'   (TODO confirm for other families).
#' @param data a data frame with all explanatory variables and the observed
#'   labels.
#' @param labeled_column name (string) of the column of `data` that holds the
#'   observed labels.
#' @return an ungrouped tibble with one row per row of `data`, in the order of
#'   `data`: `rowname` (row name of `data`), `predict_label` (most probable
#'   label), `value` (its probability), the observed label column, and
#'   `is_ok` (`TRUE` when the prediction matches the observed label). The
#'   confusion matrix is printed as a side effect.
#' @examples
#' \dontrun{
#' preds <- getPredictions(fit, test_data, "quality")
#' mean(preds$is_ok)
#' }
getPredictions <- function(model, data, labeled_column) {
  probs <- predict(
    model,
    newdata = data,
    re_formula = NA
  )
  # Column names look like "P(Y = 3)"; strip the wrapper to keep the label.
  labels <- trimws(gsub("P\\(Y =", "", gsub("\\)", "", colnames(probs))))
  probs <- as.data.frame(probs)
  colnames(probs) <- labels
  # predict() rows follow the order of `data`. Reuse data's row names so the
  # join below pairs the right rows even when `data` has non-default row
  # names (e.g. a subset of a larger data frame).
  rownames(probs) <- rownames(data)

  output <- probs %>%
    rownames_to_column() %>%
    pivot_longer(-rowname, names_to = "predict_label", values_to = "value") %>%
    group_by(rowname) %>%
    # ties.method = "first" keeps exactly one label per row. The default
    # ("average") gives tied maxima a rank such as 1.5, which dropped the row.
    filter(rank(-value, ties.method = "first") == 1) %>%
    ungroup()

  observed <- data %>%
    rownames_to_column() %>%
    select(rowname, all_of(labeled_column))
  output <- output %>%
    inner_join(observed, by = "rowname", na_matches = "never")

  predicted <- output$predict_label
  # [[ ]] extracts a plain vector; [ , col] on a tibble returns a tibble.
  truth <- as.character(output[[labeled_column]])
  output$is_ok <- predicted == truth

  # caret::confusionMatrix() requires two factors with identical levels.
  lvls <- union(labels, unique(truth[!is.na(truth)]))
  confusion_matrix <- confusionMatrix(
    factor(predicted, levels = lvls),
    factor(truth, levels = lvls)
  )
  print(confusion_matrix)
  output
}
Author
bubbobne
commented
Aug 10, 2020
- `filter(rank(-value) == 1)` does not return any value when several labels share the maximum value (tied values get an averaged rank, e.g. 1.5, so none of them equals 1).
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment