Skip to content

Commit

Permalink
Merge pull request #430 from tlverse/fix-subset-covariates
Browse files Browse the repository at this point in the history
fix subset covariates to support out of order covariates. Covariates …
  • Loading branch information
jeremyrcoyle authored Apr 29, 2024
2 parents 507e0a1 + 5462dd6 commit fdfe83f
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 25 deletions.
32 changes: 10 additions & 22 deletions R/Lrnr_base.R
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,13 @@ Lrnr_base <- R6Class(
if (length(delta_idx) > 0) {
delta_missing <- task_covs_missing[delta_idx]
task_covs_missing <- task_covs_missing[-delta_idx]

delta_missing_data <- matrix(0, nrow(task$data), length(delta_idx))
colnames(delta_missing_data) <- delta_missing
cols <- task$add_columns(data.table(delta_missing_data))

} else{
cols <- task$column_names
}

# error when task is missing covariates
Expand All @@ -68,29 +75,10 @@ Lrnr_base <- R6Class(
)
}

# subset task covariates to only include those in learner covariates
covs_subset <- intersect(task_covs, learner_covs)

# return updated task
if (length(delta_idx) == 0) {
# re-order the covariate subset to match order of learner covariates
ordered_covs_subset <- covs_subset[match(covs_subset, learner_covs)]
return(task$next_in_chain(covariates = ordered_covs_subset))
} else {
# incorporate missingness indicators in task covariates subset & sort
covs_subset_delta <- c(covs_subset, delta_missing)
ord_covs <- covs_subset_delta[match(covs_subset_delta, learner_covs)]

# incorporate missingness indicators in task data
delta_missing_data <- matrix(0, nrow(task$data), length(delta_idx))
colnames(delta_missing_data) <- delta_missing
cols <- task$add_columns(data.table(delta_missing_data))

return(task$next_in_chain(
covariates = ord_covs,
return(task$next_in_chain(
covariates = learner_covs,
column_names = cols
))
}
))
} else {
return(task)
}
Expand Down
4 changes: 3 additions & 1 deletion R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -105,10 +105,12 @@ call_with_args <- function(fun, args, other_valid = list(), keep_all = FALSE,
# subset arguments to pass
args <- args[which(names(args) %in% all_valid)]

# don't warn on covariate param
invalid <- setdiff(invalid, "covariates")
# return warnings when dropping arguments
if (!silent & length(invalid) > 0) {
message(sprintf(
"Learner called function %s with unknown args: %s. These will be dropped.\nCheck the params supported by this learner.",
"Learner called function %s with unknown args: %s. These will be dropped.\nCheck the params supported by this learner.\n",
as.character(substitute(fun)), paste(invalid, collapse = ", ")
))
}
Expand Down
11 changes: 9 additions & 2 deletions tests/testthat/test-subset_covariates.R
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,14 @@ full_preds <- glm_fit_pre_subset$predict(task)
# Predictions from the fit when `predict()` is called with no task argument.
training_preds <- glm_fit_pre_subset$predict()

# These must agree with `full_preds`, the predictions obtained from `task`.
test_that("extra covariates in prediction set get dropped correctly", {
  expect_equal(full_preds, training_preds)
})


# Build a prediction task whose covariates are the same set as
# `covariate_subset` but listed in a random order. No seed is set in the
# visible lines, so the permutation differs between runs.
shuffled_subset <- sample(covariate_subset)
task_pre_subset_shuffled <- sl3_Task$new(
  mtcars,
  covariates = shuffled_subset,
  outcome = outcome
)

# Predict with the already-trained fit on the shuffled task.
shuffled_preds <- glm_fit_pre_subset$predict(task_pre_subset_shuffled)

# Covariate order in the prediction task must not change the predictions.
test_that("covariates out of order prediction set get shuffled correctly", {
  expect_equal(full_preds, shuffled_preds)
})


# Training task built on `mtcars` with the covariates named in `covariates`.
task_train <- sl3_Task$new(mtcars, covariates = covariates, outcome = outcome)
# Prediction task built on the same data but with `covariate_subset` as its
# covariates; where it is used falls outside the lines visible here.
task_predict <- sl3_Task$new(mtcars, covariates = covariate_subset, outcome = outcome)
# Fit the learner on the training task.
glm_fit <- lrnr_glm$train(task_train)
Expand All @@ -47,11 +55,10 @@ task_missing_data <- suppressWarnings(
sl3_Task$new(missing_data, covariates = covs, outcome = Y)
)

lrnr_glm <- make_learner(Lrnr_glm_fast, name = "test")
lrnr_glm <- make_learner(Lrnr_glm_fast)
glm_fit <- lrnr_glm$train(task_missing_data)

task_complete_data <- sl3_Task$new(mtcars, covariates = covs, outcome = Y)

# Smoke test: the fit trained on `task_missing_data` must return a vector of
# predictions when given `task_complete_data` (built from `mtcars`).
# NOTE(review): presumably `task_missing_data` carries missingness-indicator
# covariates that `task_complete_data` lacks -- confirm against the setup code.
test_that("missingness indicators in prediction task works", {
expect_vector(glm_fit$predict(task_complete_data))
})

0 comments on commit fdfe83f

Please sign in to comment.