#### Getting started with sparklyr ####
# ccampbell
# 15-May-2018
# This tutorial will demonstrate that
# a local Spark instance is a useful
# tool for a data scientist working on
# mid-sized datasets.
# Spark requires Java 8;
# point JAVA_HOME at your local installation
Sys.setenv("JAVA_HOME" = "C:/Program Files/Java/jre1.8.0_172")
#### Installation ####
install.packages("devtools")
# sparklyr is under active development;
# install the latest version from GitHub
devtools::install_github("rstudio/sparklyr")
# install a local copy of Spark
sparklyr::spark_install(version = "2.1.0")
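# optional sanity check, as a sketch: list the Spark
# versions installed locally (output depends on your machine)
sparklyr::spark_installed_versions()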
#### Spark connection ####
library(dplyr)
library(sparklyr)
sc <- spark_connect(master = "local")
# list tables held in Spark memory (none yet)
src_tbls(sc)
# copy for clarity
mtcars_sp <- mtcars
# e.g. generate extra tables to merge with main dataset
mtcars_tbl <- copy_to(dest = sc, df = mtcars_sp)
# remove copy from R
rm(mtcars_sp)
class(mtcars_tbl)
# [1] "tbl_spark" "tbl_sql"   "tbl_lazy"  "tbl"
# object exists as a table in Spark
src_tbls(sc)
# [1] "mtcars" | |
#### Spark SQL ####
# get data subset
m1 <- filter(mtcars_tbl,
    cyl == 6 & am == 0) %>%
    select(mpg, wt, qsec)
m1
# # Source:   lazy query [?? x 3]
# # Database: spark_connection
#     mpg    wt  qsec
#   <dbl> <dbl> <dbl>
# 1  21.4  3.22  19.4
# 2  18.1  3.46  20.2
# 3  19.2  3.44  18.3
# 4  17.8  3.44  18.9
# another lazy Spark DataFrame
class(m1)
# larger than needed for small data,
# but roughly the same size for 1e6 rows,
# since no data is held in R until `collect`-ed back
object.size(m1)
# for comparison
object.size(filter(mtcars,
    cyl == 6 & am == 0) %>%
    select(mpg, wt, qsec))
# because the tbl_spark is a representation of the data
# plus connection information
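# a minimal sketch: `collect` (used again later) pulls the
# rows into R, at which point object.size reflects the data itself
m1_df <- collect(m1)
object.size(m1_df)
rm(m1_df)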
# Spark SQL to count rows
summarise(mtcars_tbl, n = n())
# # Source:   lazy query [?? x 1]
# # Database: spark_connection
#       n
#   <dbl>
# 1  32.0
# group_by is supported
m2 <- summarise(group_by(mtcars_tbl, am), n = n())
m2
# # Source:   lazy query [?? x 2]
# # Database: spark_connection
#      am     n
#   <dbl> <dbl>
# 1  0     19.0
# 2  1.00  13.0
# view the Spark SQL query itself
dbplyr::sql_render(m2)
# <SQL> SELECT `am`, count(*) AS `n`
# FROM `mtcars_sp`
# GROUP BY `am`
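# the same query can be run as raw SQL through sparklyr's
# DBI interface (a sketch; assumes the table name "mtcars_sp")
DBI::dbGetQuery(sc,
    "SELECT `am`, count(*) AS `n` FROM `mtcars_sp` GROUP BY `am`")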
#### Transforming Spark DataFrames ####
objects("package:sparklyr",
    pattern = "^ft_")
# [1] "ft_binarizer"
# [2] "ft_bucketed_random_projection_lsh"
# [3] "ft_bucketizer"
# [4] "ft_chisq_selector"
# ...
# cut:
# breakpoints for 3 roughly equal-sized bins
splt <- summarise(mtcars_tbl,
    # mapped to Spark SQL
    T0 = min(disp),
    # percentile_approx is a Hive UDF (see references)
    T1 = percentile_approx(disp, 1/3),
    T2 = percentile_approx(disp, 2/3),
    T3 = max(disp))
class(splt)
splt
# # Source:   lazy query [?? x 4]
# # Database: spark_connection
#      T0    T1    T2    T3
#   <dbl> <dbl> <dbl> <dbl>
# 1  71.1   145   301   472
# `collect` draws the data down into R
splt_df <- collect(splt)
# just a normal tbl_df in R
class(splt_df)
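# `unlist` turns the one-row tbl_df into the named numeric
# vector of cutpoints expected by ft_bucketizer below
unlist(splt_df)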
# cut using breaks
mtcars_tbl <- ft_bucketizer(
    x = mtcars_tbl,
    input_col = "disp",
    output_col = "fct_disp",
    # vector of cutpoints
    # collect first!
    splits = unlist(splt_df))
# column has been created
mtcars_tbl %>% select(mpg, am, disp, fct_disp) %>% head
# # Source:   lazy query [?? x 4]
# # Database: spark_connection
#     mpg    am  disp fct_disp
#   <dbl> <dbl> <dbl>    <dbl>
# 1  21.0  1.00   160     1.00
# 2  21.0  1.00   160     1.00
# 3  22.8  1.00   108     0
# 4  21.4  0      258     1.00
# 5  18.7  0      360     2.00
# 6  18.1  0      225     1.00
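# a related transformer, as a sketch: ft_binarizer thresholds
# a numeric column into 0/1 (the threshold of 20 and the
# column name "high_mpg" are assumptions for illustration)
mtcars_tbl %>%
    ft_binarizer(input_col = "mpg",
        output_col = "high_mpg",
        threshold = 20) %>%
    select(mpg, high_mpg) %>%
    head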
#### Updating Spark DataFrames ####
objects("package:sparklyr", pattern = "^sdf_")
sdf_nrow(mtcars_tbl)
# [1] 32
# update with an arbitrary column:
# data must be passed to Spark,
# so it is currently only possible to `mutate`
# with single values or expressions combining
# columns and Hive functions, e.g. with
# `rn` as defined below:
#
# # adds the whole list to each row
# mtcars_tbl %>% select(mpg, cyl) %>% mutate(info1 = rn)
# # adds the first value only
# mtcars_tbl %>% select(mpg, cyl) %>% mutate(info1 = rn[1])
rn <- rownames(mtcars)
info <- data_frame(info = rn)
info_tbl <- copy_to(dest = sc, df = info)
# bind the new column in Spark
mtcars_tbl <- sdf_bind_cols(mtcars_tbl, info_tbl)
mtcars_tbl %>%
    select(mpg, cyl, info) %>%
    head
# # Source:   lazy query [?? x 3]
# # Database: spark_connection
#     mpg   cyl info
#   <dbl> <dbl> <chr>
# 1  18.1  6.00 Valiant
# 2  14.3  8.00 Duster 360
# 3  16.4  8.00 Merc 450SE
# 4  17.3  8.00 Merc 450SL
# 5  15.2  8.00 Merc 450SLC
# 6  19.2  8.00 Pontiac Firebird
# create training and test sets as a
# list of tbl_spark objects
partitions <- sdf_partition(x = mtcars_tbl,
    training = 0.7,
    test = 0.3,
    seed = 26325)
class(partitions)
# [1] "list"
# how many rows in each set
lapply(X = partitions, FUN = sdf_nrow)
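# a sketch: the combined table can also be registered under a
# readable name in the Spark catalogue (the name "mtcars_info"
# is an assumption for illustration)
mtcars_info_tbl <- sdf_register(mtcars_tbl, name = "mtcars_info")
src_tbls(sc)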
#### Machine Learning ####
objects("package:sparklyr", pattern = "^ml_")
# fit a generalized linear model (logistic)
# to the training dataset in Spark
fit1 <- ml_generalized_linear_regression(
    x = partitions$training,
    response = "am",
    features = c("wt", "cyl"),
    family = "binomial")
fit1
class(fit1)
# [1] "ml_model_generalized_linear_regression" "ml_model_regression" "ml_model_prediction"
# [4] "ml_model"
# fitted values,
# in Spark (returns a tbl_spark)
fitted_tbl <- sdf_predict(fit1)
select(fitted_tbl, mpg, am, prediction) %>%
    head
# confusion matrix for the fitted values
summarise(group_by(mutate(
    fitted_tbl,
    pred = as.numeric(prediction > 0.5)),
    am, pred), n = n())
# # Source:   lazy query [?? x 3]
# # Database: spark_connection
# # Groups:   am
#      am  pred     n
#   <dbl> <dbl> <dbl>
# 1  0     1.00  1.00
# 2  0     0    12.0
# 3  1.00  1.00  8.00
# 4  1.00  0     1.00
library(tidyr)
library(ggplot2)
# dependent variable (am) vs. prediction
tabsk <- select(fitted_tbl, am, prediction) %>%
    collect %>%
    gather(key = "type", value = "value")
tabsk <- mutate(tabsk, index = rep(seq_len(n() / 2), times = 2))
# row index for axis labels
ind <- seq_len(sdf_nrow(fitted_tbl))
# visual comparison by probability
ggplot(data = tabsk,
        mapping = aes(x = type, y = -index, fill = value)) +
    geom_tile() +
    scale_y_continuous(breaks = -ind,
        labels = ind) +
    ylab("") +
    ggtitle("Prediction comparison by row")
# predictions for the test set
test_tbl <- sdf_predict(
    x = partitions$test,
    model = fit1)
# test predictions
select(test_tbl, mpg, am, prediction)
# confusion matrix
summarise(group_by(mutate(
    test_tbl,
    pred = as.numeric(prediction > 0.5)),
    am, pred), n = n())
# # Source:   lazy query [?? x 3]
# # Database: spark_connection
# # Groups:   am
#      am  pred     n
#   <dbl> <dbl> <dbl>
# 1  1.00  0     1.00
# 2  0     0     6.00
# 3  1.00  1.00  3.00
# predict for arbitrary unlabelled data
new_sp <- data.frame("wt" = 2.2, "cyl" = 4)
new_tbl <- copy_to(dest = sc, df = new_sp)
# `predict` is a wrapper for ml_predict
predict(fit1, newdata = new_tbl)
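# equivalently, call ml_predict directly on the model
# and the new data (a sketch; returns a tbl_spark)
ml_predict(fit1, new_tbl)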
# saving a Spark ML object
ml_save(x = fit1,
    path = "fit1_glm")
rm(fit1)
# load the saved model
fit1 <- ml_load(sc, path = "fit1_glm")
# note, the class has changed
class(fit1)
# [1] "ml_pipeline_model" "ml_transformer"    "ml_pipeline_stage"
# random forest classifier;
# the formula interface is used here because features_col
# expects the name of a single assembled vector column,
# not a character vector of predictor names
frst <- ml_random_forest(
    x = partitions$training,
    formula = am ~ mpg + cyl + disp + hp + drat + wt + qsec + vs + gear + carb,
    type = "classification")
ml_tree_feature_importance(sc = sc, model = frst)
fittedf_tbl <- sdf_predict(frst)
select(fittedf_tbl, mpg, am, prediction)
ml_binary_classification_eval(
    predicted_tbl_spark = fittedf_tbl,
    label = "am",
    score = "probability")
# [1] 1
# confusion matrix
summarise(group_by(mutate(
    fittedf_tbl,
    pred = as.numeric(prediction > 0.5)),
    am, pred), n = n())
# # Source:   lazy query [?? x 3]
# # Database: spark_connection
# # Groups:   am
#      am  pred     n
#   <dbl> <dbl> <dbl>
# 1     0     0    12
# 2     1     1     8
testf_tbl <- sdf_predict(frst, newdata = partitions$test)
select(testf_tbl, mpg, am, prediction)
# outcome is binary,
# so get back the area under the ROC curve
ml_binary_classification_eval(
    predicted_tbl_spark = testf_tbl,
    label = "am",
    score = "probability")
# [1] 0.9285714
summarise(group_by(mutate(
    testf_tbl,
    pred = as.numeric(prediction > 0.5)),
    am, pred), n = n())
# # Source:   lazy query [?? x 3]
# # Database: spark_connection
# # Groups:   am
#      am  pred     n
#   <dbl> <dbl> <dbl>
# 1     0     0     6
# 2     1     0     1
# 3     1     1     4
# 4     0     1     1
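# a sketch of overall test accuracy computed in Spark:
# the average agreement between thresholded prediction and label
testf_tbl %>%
    mutate(pred = as.numeric(prediction > 0.5)) %>%
    summarise(accuracy = mean(as.numeric(pred == am), na.rm = TRUE))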
# remove the table from Spark
db_drop_table(sc, table = "mtcars_sp")
#### Data Download ####
# a robust download method is needed for a large file
capabilities("libcurl")
# libcurl
#    TRUE
u1 <- "https://data.cms.gov/api/views/sk9b-znav/rows.csv?accessType=DOWNLOAD"
f1 <- "Medicare_Provider_Payment_Data__Physician_PUF_CY2015.csv"
# pick a download method available on this platform
curl <- switch(.Platform$OS.type,
    "windows" = "curl", "unix" = "libcurl")
system.time(download.file(url = u1,
    destfile = f1, method = curl))
#    user  system elapsed
#    0.86    1.55 1917.69
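# note: on slow connections the default 60 second timeout can be
# raised before calling download.file (a sketch; the value is arbitrary)
# options(timeout = 3600)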
# import
cost_tbl <- spark_read_csv(sc = sc, name = "cost", path = f1)
# the handle allows processing from R,
# but we are not updating the object in Spark here
summarise(group_by(cost_tbl, Medicare_Participation_Indicator), n = n())
types <- summarise(group_by(cost_tbl, Provider_Type), n = n())
print(types, n = 6)
# # Source:   lazy query [?? x 2]
# # Database: spark_connection
#   Provider_Type              n
#   <chr>                  <dbl>
# 1 Pathology              62976
# 2 Family Practice       410449
# 3 Obstetrics/Gynecology  45442
# 4 General Surgery        67803
# 5 Ophthalmology         101484
# 6 Endocrinology          22640
q1 <- filter(cost_tbl, Provider_Type == "Endocrinology")
# a lazy query
object.size(q1)
# 11984 bytes
# get the data from Spark
d1 <- as.data.frame(q1)
object.size(d1)
# 11795424 bytes
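# `collect` is the explicit way to draw the same rows into R
# (returns a tbl_df rather than a plain data.frame)
d1_c <- collect(q1)
dim(d1_c)
rm(d1_c)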
#### Disconnect ####
spark_disconnect(sc)
#### References ####
# http://spark.rstudio.com/
# https://cwiki.apache.org/confluence/display/Hive/LanguageManual+UDF