This repository has been archived by the owner on Jul 2, 2024. It is now read-only.
generated from tlverse/tlverse-workshops
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
131 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
################################################################################ | ||
# - Prediction tasks with sl3 | ||
# - Constructing and fitting a super learner (SL) in sl3 | ||
# - Looking up learner documentation | ||
# - Modifying learner parameters | ||
# - Difference between R6 (sl3) and S3 (most R packages) methods | ||
################################################################################ | ||
|
||
### 0. Load Data and R packages | ||
library(data.table) | ||
washb_data <- fread( | ||
paste0( | ||
"https://raw.githubusercontent.com/tlverse/tlverse-data/master/", | ||
"wash-benefits/washb_data.csv" | ||
), | ||
stringsAsFactors = TRUE | ||
) | ||
head(washb_data) | ||
|
||
library(devtools) | ||
install_github("tlverse/sl3@devel") | ||
|
||
### 1. Define the prediction task with `make_sl3_Task` | ||
library(sl3) | ||
task <- make_sl3_Task( | ||
data = washb_data, | ||
outcome = "whz", | ||
covariates = c("tr", "fracode", "month", "aged", "sex", "momage", "momedu", | ||
"momheight", "hfiacat", "Nlt18", "Ncomp", "watmin", "elec", | ||
"floor", "walls", "roof", "asset_wardrobe", "asset_table", | ||
"asset_chair", "asset_khat", "asset_chouki", "asset_tv", | ||
"asset_refrig", "asset_bike", "asset_moto", "asset_sewmach", | ||
"asset_mobile") | ||
) | ||
|
||
# let's examine the task | ||
task | ||
|
||
### 2. Instantiate the SL with `Lrnr_sl` | ||
sl3_list_learners(properties = "continuous") | ||
|
||
lrn_glm <- Lrnr_glm$new() | ||
lrn_mean <- Lrnr_mean$new() | ||
lrn_ridge <- Lrnr_glmnet$new(alpha = 0) | ||
lrn_lasso <- Lrnr_glmnet$new(alpha = 1) | ||
lrn_polspline <- Lrnr_polspline$new() | ||
lrn_earth <- Lrnr_earth$new() | ||
lrn_hal <- Lrnr_hal9001$new(max_degree = 2, num_knots = c(3,2), nfolds = 5) | ||
lrn_ranger <- Lrnr_ranger$new() | ||
lrn_xgb <- Lrnr_xgboost$new() | ||
lrn_gam <- Lrnr_gam$new() | ||
lrn_bayesglm <- Lrnr_bayesglm$new() | ||
stack <- Stack$new( | ||
lrn_glm, lrn_mean, lrn_ridge, lrn_lasso, lrn_polspline, lrn_earth, lrn_hal, | ||
lrn_ranger, lrn_xgb, lrn_gam, lrn_bayesglm | ||
) | ||
|
||
sl <- Lrnr_sl$new(learners = stack, metalearner = Lrnr_nnls$new()) | ||
|
||
### 3. Fit the SL to the task with `train` | ||
set.seed(4197) | ||
sl_fit <- sl$train(task) | ||
|
||
#### Additional functionality | ||
|
||
# cross-validated predictive performance | ||
cv_risk_table <- sl_fit$cv_risk(eval_fun = loss_squared_error) | ||
cv_risk_table[,c(1:3)] | ||
|
||
# cross-validated predictive performance of SL | ||
set.seed(569) | ||
cv_sl_fit <- cv_sl(lrnr_sl = sl_fit, task = task, eval_fun = loss_squared_error) | ||
cv_sl_fit$cv_risk[,c(1:2)] | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
################################################################################ | ||
# - Prediction tasks with sl3 | ||
# - Constructing and fitting a single learner in sl3 | ||
# - Learner documentation | ||
# - Modifying learner parameters | ||
# - Difference between R6 (sl3) and S3 (most R packages) methods | ||
################################################################################ | ||
|
||
### 0. Load Data and R packages | ||
library(sl3) | ||
library(data.table) | ||
data(cpp_imputed) # subset of data from collaborative perinatal project (CPP) | ||
|
||
### 1. Define the prediction task with `make_sl3_Task` | ||
task <- sl3_Task$new( | ||
cpp_imputed, outcome = "haz", | ||
covariates = c("apgar1","apgar5","parity","gagebrth","mage","meducyrs","sexn") | ||
) | ||
# let's examine the task | ||
task | ||
|
||
### 2. Instantiate the learner with `Lrnr_*$new()` | ||
earth_sl3 <- Lrnr_earth$new() | ||
# what is default? ?Lrnr_earth | ||
|
||
### 3. Fit the the learner to the task with `train` | ||
set.seed(4738) | ||
earth_fit_sl3 <- earth_sl3$train(task) | ||
|
||
### We may want to get predictions from the fitted learner with `predict` | ||
preds_earth_fit_sl3 <- earth_fit_sl3$predict(task) | ||
|
||
### specification and predictions from classic implementation, which uses S3 | ||
library(earth) | ||
set.seed(4738) | ||
earth_fit_classic <- earth(x = task$X, y = task$Y, degree = 2) | ||
preds_earth_fit_classic <- predict(earth_fit_classic, newdata = task$X, | ||
type = "response") | ||
|
||
############ check equality of predictions | ||
all.equal(preds_earth_fit_sl3, as.numeric(preds_earth_fit_classic)) | ||
|
||
################################################################################ | ||
# specify different earth arguments in the Lrnr wrapper | ||
earth_1way <- Lrnr_earth$new(degree = 1) | ||
set.seed(4738) | ||
earth_1way_fit_sl3 <- earth_1way$train(task) | ||
preds_earth_1way_fit_sl3 <- earth_1way_fit_sl3$predict(task) | ||
|
||
# get predictions from classic implementation | ||
set.seed(4738) | ||
earth_fit_classic_1way <- earth(x = task$X, y = task$Y) | ||
preds_fit_classic_1way <- predict(earth_fit_classic_1way, newdata = task$X, | ||
type = "response") | ||
|
||
############ check equality of predictions | ||
all.equal(preds_earth_1way_fit_sl3, as.numeric(preds_fit_classic_1way)) |
Binary file not shown.
Binary file not shown.