Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix subset covariates to support out of order covariates. Covariates … #430

Merged
merged 2 commits into from
Apr 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 10 additions & 22 deletions R/Lrnr_base.R
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,13 @@ Lrnr_base <- R6Class(
if (length(delta_idx) > 0) {
delta_missing <- task_covs_missing[delta_idx]
task_covs_missing <- task_covs_missing[-delta_idx]

delta_missing_data <- matrix(0, nrow(task$data), length(delta_idx))
colnames(delta_missing_data) <- delta_missing
cols <- task$add_columns(data.table(delta_missing_data))

} else{
cols <- task$column_names
}

# error when task is missing covariates
Expand All @@ -68,29 +75,10 @@ Lrnr_base <- R6Class(
)
}

# subset task covariates to only includes those in learner covariates
covs_subset <- intersect(task_covs, learner_covs)

# return updated task
if (length(delta_idx) == 0) {
# re-order the covariate subset to match order of learner covariates
ordered_covs_subset <- covs_subset[match(covs_subset, learner_covs)]
return(task$next_in_chain(covariates = ordered_covs_subset))
} else {
# incorporate missingness indicators in task covariates subset & sort
covs_subset_delta <- c(covs_subset, delta_missing)
ord_covs <- covs_subset_delta[match(covs_subset_delta, learner_covs)]

# incorporate missingness indicators in task data
delta_missing_data <- matrix(0, nrow(task$data), length(delta_idx))
colnames(delta_missing_data) <- delta_missing
cols <- task$add_columns(data.table(delta_missing_data))

return(task$next_in_chain(
covariates = ord_covs,
return(task$next_in_chain(
covariates = learner_covs,
column_names = cols
))
}
))
} else {
return(task)
}
Expand Down
4 changes: 3 additions & 1 deletion R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -105,10 +105,12 @@ call_with_args <- function(fun, args, other_valid = list(), keep_all = FALSE,
# subset arguments to pass
args <- args[which(names(args) %in% all_valid)]

# don't warn on covariate param
invalid <- setdiff(invalid, "covariates")
# return warnings when dropping arguments
if (!silent & length(invalid) > 0) {
message(sprintf(
"Learner called function %s with unknown args: %s. These will be dropped.\nCheck the params supported by this learner.",
"Learner called function %s with unknown args: %s. These will be dropped.\nCheck the params supported by this learner.\n",
as.character(substitute(fun)), paste(invalid, collapse = ", ")
))
}
Expand Down
11 changes: 9 additions & 2 deletions tests/testthat/test-subset_covariates.R
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,14 @@ full_preds <- glm_fit_pre_subset$predict(task)
training_preds <- glm_fit_pre_subset$predict()
test_that("extra covariates in prediction set get dropped correctly", expect_equal(full_preds, training_preds))


shuffled_subset <- sample(covariate_subset)
task_pre_subset_shuffled <- sl3_Task$new(mtcars, covariates = shuffled_subset, outcome = outcome)
# debugonce(glm_fit_pre_subset$subset_covariates)
shuffled_preds <- glm_fit_pre_subset$predict(task_pre_subset_shuffled)
test_that("covariates out of order prediction set get shuffled correctly", expect_equal(full_preds, shuffled_preds))


task_train <- sl3_Task$new(mtcars, covariates = covariates, outcome = outcome)
task_predict <- sl3_Task$new(mtcars, covariates = covariate_subset, outcome = outcome)
glm_fit <- lrnr_glm$train(task_train)
Expand All @@ -47,11 +55,10 @@ task_missing_data <- suppressWarnings(
sl3_Task$new(missing_data, covariates = covs, outcome = Y)
)

lrnr_glm <- make_learner(Lrnr_glm_fast, name = "test")
lrnr_glm <- make_learner(Lrnr_glm_fast)
glm_fit <- lrnr_glm$train(task_missing_data)

task_complete_data <- sl3_Task$new(mtcars, covariates = covs, outcome = Y)

test_that("missingness indicators in prediction task works", {
expect_vector(glm_fit$predict(task_complete_data))
})
Loading