@Orbifold
Created August 23, 2018 12:21
R version of the TensorFlow Estimators workflow, using the tfestimators package to train a DNN classifier on the iris dataset. See https://tensorflow.rstudio.com.
library(tfestimators)
response <- function() "Species"
features <- function() setdiff(names(iris), response())
# split into training (80%) and test (20%) datasets
set.seed(123)
partitions <- modelr::resample_partition(iris, c(test = 0.2, train = 0.8))
iris_train <- as.data.frame(partitions$train)
iris_test <- as.data.frame(partitions$test)
# construct feature columns
feature_columns <- feature_columns(
  column_numeric(features())
)
# construct classifier
classifier <- dnn_classifier(
  feature_columns = feature_columns,
  hidden_units = c(10, 20, 10),
  n_classes = 3
)
# construct input function
iris_input_fn <- function(data) {
  input_fn(data, features = features(), response = response())
}
# train classifier with training dataset
train(classifier, input_fn = iris_input_fn(iris_train))
# predict on and evaluate with the test dataset
predictions <- predict(classifier, input_fn = iris_input_fn(iris_test))
evaluation <- evaluate(classifier, input_fn = iris_input_fn(iris_test))
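For readers without modelr, the 80/20 split in the script above can be approximated in base R by sampling row indices. This is a sketch that assumes any random 20% holdout is acceptable; it is not modelr's `resample_partition()` algorithm, so even with `set.seed(123)` the selected rows will differ from the script's. The names `test_idx`, `iris_train_base` and `iris_test_base` are illustrative only.

```r
# Base-R sketch of a random 80/20 train/test split of iris.
# This approximates the modelr::resample_partition() step; it is not modelr's exact algorithm.
set.seed(123)
n <- nrow(iris)

# Draw 20% of the row indices, without replacement, for the test set.
test_idx <- sample.int(n, size = round(0.2 * n))

iris_test_base  <- iris[test_idx, ]   # 30 rows held out for evaluation
iris_train_base <- iris[-test_idx, ]  # remaining 120 rows for training

# Every row lands in exactly one partition.
stopifnot(length(intersect(rownames(iris_test_base), rownames(iris_train_base))) == 0)
```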
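The `hidden_units = c(10, 20, 10)` argument in the script above describes three fully connected hidden layers. Assuming each layer is a standard dense layer with a bias (an inputs × units weight matrix plus one bias per unit) and that the output is a logits layer with `n_classes` units, the number of trainable parameters for the four numeric iris features can be worked out by hand. `dense_param_count()` is a hypothetical helper written only for this illustration; it does arithmetic on the assumed architecture and does not inspect the trained estimator.

```r
# Hypothetical helper: count weights and biases in a stack of dense layers.
# Assumes every layer is fully connected with one bias per output unit.
dense_param_count <- function(n_inputs, hidden_units, n_outputs) {
  sizes   <- c(n_inputs, hidden_units, n_outputs)  # e.g. 4, 10, 20, 10, 3
  fan_in  <- head(sizes, -1)                       # inputs feeding each layer
  fan_out <- tail(sizes, -1)                       # units in each layer
  per_layer <- fan_in * fan_out + fan_out          # weights + biases
  names(per_layer) <- c(paste0("hidden", seq_along(hidden_units)), "logits")
  per_layer
}

per_layer <- dense_param_count(n_inputs = 4, hidden_units = c(10, 20, 10), n_outputs = 3)
per_layer       # hidden1 50, hidden2 220, hidden3 210, logits 33
sum(per_layer)  # 513
```

Widening or adding hidden layers changes these counts quickly, which is one way to reason about how much capacity a given `hidden_units` setting gives a 150-row dataset.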
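`evaluate()` typically reports metrics such as loss and accuracy, but it can help to check predictions by hand as well. The sketch below assumes predicted class ids are 0-based integer indices into `levels(iris$Species)` (0 = setosa, 1 = versicolor, 2 = virginica); that encoding is an assumption, so confirm it against the columns your `predictions` object actually contains before relying on it. `accuracy_from_class_ids()` is a hypothetical helper, demonstrated on invented inputs rather than real model output.

```r
# Hypothetical helper: fraction of predicted class ids that match the true labels.
# Assumes class_ids are 0-based indices into levels(truth).
accuracy_from_class_ids <- function(class_ids, truth) {
  stopifnot(is.factor(truth), length(class_ids) == length(truth))
  true_ids <- as.integer(truth) - 1L  # R factor codes are 1-based
  mean(class_ids == true_ids)
}

# Toy example with invented predictions (not output from the classifier above).
truth <- factor(c("setosa", "versicolor", "virginica", "setosa"),
                levels = levels(iris$Species))
accuracy_from_class_ids(c(0L, 1L, 1L, 0L), truth)  # 3 of 4 correct: 0.75
```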