Skip to content
This repository has been archived by the owner on Jul 2, 2024. It is now read-only.

Commit

Permalink
add slide decks and coding
Browse files Browse the repository at this point in the history
  • Loading branch information
rachaelvp committed Aug 5, 2023
1 parent 16513e3 commit ef34ecc
Show file tree
Hide file tree
Showing 4 changed files with 131 additions and 0 deletions.
74 changes: 74 additions & 0 deletions live coding R code/fit_sl.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
################################################################################
# - Prediction tasks with sl3
# - Constructing and fitting a super learner (SL) in sl3
# - Looking up learner documentation
# - Modifying learner parameters
# - Difference between R6 (sl3) and S3 (most R packages) methods
################################################################################

### 0. Load Data and R packages
library(data.table)
washb_data <- fread(
paste0(
"https://raw.githubusercontent.com/tlverse/tlverse-data/master/",
"wash-benefits/washb_data.csv"
),
stringsAsFactors = TRUE
)
head(washb_data)

library(devtools)
install_github("tlverse/sl3@devel")

### 1. Define the prediction task with `make_sl3_Task`
library(sl3)
task <- make_sl3_Task(
data = washb_data,
outcome = "whz",
covariates = c("tr", "fracode", "month", "aged", "sex", "momage", "momedu",
"momheight", "hfiacat", "Nlt18", "Ncomp", "watmin", "elec",
"floor", "walls", "roof", "asset_wardrobe", "asset_table",
"asset_chair", "asset_khat", "asset_chouki", "asset_tv",
"asset_refrig", "asset_bike", "asset_moto", "asset_sewmach",
"asset_mobile")
)

# let's examine the task
task

### 2. Instantiate the SL with `Lrnr_sl`
sl3_list_learners(properties = "continuous")

lrn_glm <- Lrnr_glm$new()
lrn_mean <- Lrnr_mean$new()
lrn_ridge <- Lrnr_glmnet$new(alpha = 0)
lrn_lasso <- Lrnr_glmnet$new(alpha = 1)
lrn_polspline <- Lrnr_polspline$new()
lrn_earth <- Lrnr_earth$new()
lrn_hal <- Lrnr_hal9001$new(max_degree = 2, num_knots = c(3,2), nfolds = 5)
lrn_ranger <- Lrnr_ranger$new()
lrn_xgb <- Lrnr_xgboost$new()
lrn_gam <- Lrnr_gam$new()
lrn_bayesglm <- Lrnr_bayesglm$new()
stack <- Stack$new(
lrn_glm, lrn_mean, lrn_ridge, lrn_lasso, lrn_polspline, lrn_earth, lrn_hal,
lrn_ranger, lrn_xgb, lrn_gam, lrn_bayesglm
)

sl <- Lrnr_sl$new(learners = stack, metalearner = Lrnr_nnls$new())

### 3. Fit the SL to the task with `train`
set.seed(4197)
sl_fit <- sl$train(task)

#### Additional functionality

# cross-validated predictive performance
cv_risk_table <- sl_fit$cv_risk(eval_fun = loss_squared_error)
cv_risk_table[,c(1:3)]

# cross-validated predictive performance of SL
set.seed(569)
cv_sl_fit <- cv_sl(lrnr_sl = sl_fit, task = task, eval_fun = loss_squared_error)
cv_sl_fit$cv_risk[,c(1:2)]

57 changes: 57 additions & 0 deletions live coding R code/sl3_intro.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
################################################################################
# - Prediction tasks with sl3
# - Constructing and fitting a single learner in sl3
# - Learner documentation
# - Modifying learner parameters
# - Difference between R6 (sl3) and S3 (most R packages) methods
################################################################################

### 0. Load Data and R packages
library(sl3)
library(data.table)
data(cpp_imputed) # subset of data from collaborative perinatal project (CPP)

### 1. Define the prediction task with `make_sl3_Task`
task <- sl3_Task$new(
cpp_imputed, outcome = "haz",
covariates = c("apgar1","apgar5","parity","gagebrth","mage","meducyrs","sexn")
)
# let's examine the task
task

### 2. Instantiate the learner with `Lrnr_*$new()`
earth_sl3 <- Lrnr_earth$new()
# what is default? ?Lrnr_earth

### 3. Fit the the learner to the task with `train`
set.seed(4738)
earth_fit_sl3 <- earth_sl3$train(task)

### We may want to get predictions from the fitted learner with `predict`
preds_earth_fit_sl3 <- earth_fit_sl3$predict(task)

### specification and predictions from classic implementation, which uses S3
library(earth)
set.seed(4738)
earth_fit_classic <- earth(x = task$X, y = task$Y, degree = 2)
preds_earth_fit_classic <- predict(earth_fit_classic, newdata = task$X,
type = "response")

############ check equality of predictions
all.equal(preds_earth_fit_sl3, as.numeric(preds_earth_fit_classic))

################################################################################
# specify different earth arguments in the Lrnr wrapper
earth_1way <- Lrnr_earth$new(degree = 1)
set.seed(4738)
earth_1way_fit_sl3 <- earth_1way$train(task)
preds_earth_1way_fit_sl3 <- earth_1way_fit_sl3$predict(task)

# get predictions from classic implementation
set.seed(4738)
earth_fit_classic_1way <- earth(x = task$X, y = task$Y)
preds_fit_classic_1way <- predict(earth_fit_classic_1way, newdata = task$X,
type = "response")

############ check equality of predictions
all.equal(preds_earth_1way_fit_sl3, as.numeric(preds_fit_classic_1way))
Binary file added slides/JSM2023_TLintro.pdf
Binary file not shown.
Binary file added slides/SL.pdf
Binary file not shown.

0 comments on commit ef34ecc

Please sign in to comment.