Skip to content

Instantly share code, notes, and snippets.

@slopp
Created August 7, 2017 14:55
Show Gist options
  • Save slopp/169f54e3d6743f26e1783fdf458b71ca to your computer and use it in GitHub Desktop.
Save slopp/169f54e3d6743f26e1783fdf458b71ca to your computer and use it in GitHub Desktop.
Initial Test of tfestimators
# Python environment ----
# Prepend the r-tensorflow virtualenv's bin directory to PATH, presumably so
# that the virtualenv's Python is found before any system Python. The current
# PATH must be read with Sys.getenv(): R never expands shell variables, so
# writing "$PATH" inside the string would replace the entire search path with
# that literal text.
Sys.setenv(PATH = paste(
  "/Users/Sean/.virtualenvs/r-tensorflow/bin",
  Sys.getenv("PATH"),
  sep = .Platform$path.sep
))
library(tfestimators)
library(readr)
library(stringr)
library(purrr)
# Column names ----
# Names for the 15 columns of the raw CSV files. They are passed to
# read_csv(col_names = ...), so the files are read as having no header row.
# The last column, income_bracket, is the label the classifier predicts.
col_names <- c(
  "age", "workclass", "fnlwgt", "education", "education_num",
  "marital_status", "occupation", "relationship", "race", "gender",
  "capital_gain", "capital_loss", "hours_per_week", "native_country",
  "income_bracket"
)
# Data ----
# Read the training and test splits with the shared column names, treating
# "?" as a missing value. The first line of the test file is skipped.
train <- read_csv(
  "data/train_raw.csv",
  col_names = col_names,
  na = "?"
)
test <- read_csv(
  "data/test_raw.csv",
  col_names = col_names,
  na = "?",
  skip = 1
)
# Remove every literal "." from the test labels. Presumably the test file
# spells its labels with a trailing period (e.g. ">50K."), so this makes them
# match the training labels. TODO confirm against the raw file.
test[["income_bracket"]] <- str_replace_all(
  test[["income_bracket"]],
  fixed("."),
  ""
)
# Input ----

#' Build a tfestimators input function from a data frame
#'
#' @param data A data frame (here, the tibbles read from the raw CSVs)
#'   holding the feature and response columns.
#' @param features Character vector of feature column names handed to
#'   `input_fn()`. Defaults to `"age"`, the only column the feature columns
#'   in this script are built from.
#' @param response Name of the label column. Defaults to `"income_bracket"`.
#' @return Whatever `input_fn()` returns for `data`.
infn <- function(data, features = "age", response = "income_bracket") {
  input_fn(data, features = features, response = response)
}
# Feature engineering ----
# Turn the raw numeric age into a bucketized column split at the boundaries
# below. That bucketized column is the only feature column given to the model.
age <- column_numeric("age")
age_buckets <- column_bucketized(
  age,
  boundaries = c(18, 25, 30, 35, 40, 45, 50, 55, 60, 65)
)
columns <- feature_columns(age_buckets)
# Model ----
# Two-class linear classifier over the feature columns built above. The
# label vocabulary lists the string values of income_bracket, which are the
# labels the classifier is trained on.
model <- linear_classifier(
  n_classes = 2,
  label_vocabulary = c("<=50K", ">50K"),
  feature_columns = columns
)
# Train, evaluate and predict ----
# `train` is also the name of the training data frame. In call position R
# skips non-function bindings, so train(...) still resolves to the train()
# function, while the argument refers to the data.
train(model, infn(train))
res <- evaluate(model, infn(test))
pred <- predict(
  model,
  infn(test),
  predict_keys = prediction_keys()$CLASSES
)
# Flatten predictions ----
# Pull the "classes" entry out of every element of `pred` as a single string,
# then store the result on the test set next to the true label.
# NOTE(review): assumes predict() yields rows in the same order as `test`;
# confirm.
preds <- map_chr(pred, "classes")
test[["predicted"]] <- preds
# Confusion matrix ----
# Cross-tabulate the actual income bracket (rows) against the predicted one
# (columns). Both vectors become factors with the same fixed levels, so the
# result is always 2 x 2, even when a class is never predicted. A single
# table() call replaces four hand-written comparisons. Rows whose label is NA
# are left out of the counts instead of turning a cell into NA.
income_levels <- c("<=50K", ">50K")
conf_matrix <- unclass(table(
  factor(test$income_bracket, levels = income_levels),
  factor(test$predicted, levels = income_levels)
))
dimnames(conf_matrix) <- list(
  paste("Actual", income_levels),
  paste("Predicted", income_levels)
)
conf_matrix
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment