
Train machine learning models with optional stacking and feature selection
Source: R/machine_learning.R
compute_features.training.ML.Rd
This function trains one or more machine learning models using repeated k-fold cross-validation, with optional model stacking and optional Boruta feature selection. It supports stratified cross-validation, including folds stratified by cohort when cohort information is available.
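To make the resampling scheme concrete, here is a minimal base R sketch of repeated, class-stratified k-fold splitting. It is not the package's implementation: the helper name make_repeated_stratified_folds, its arguments, and the fold naming are assumptions made for illustration only.

```r
# Hypothetical helper (not part of the package): builds repeated,
# class-stratified k-fold TRAINING indices using base R only.
make_repeated_stratified_folds <- function(y, k = 10, n_rep = 5, seed = 1) {
  set.seed(seed)
  y <- as.character(y)
  folds <- list()
  for (r in seq_len(n_rep)) {
    # Assign a fold id to every sample, shuffling within each class so
    # that each fold keeps roughly the original class proportions.
    fold_id <- integer(length(y))
    for (cls in unique(y)) {
      idx <- which(y == cls)
      idx <- idx[sample.int(length(idx))]  # random order within the class
      fold_id[idx] <- rep_len(seq_len(k), length(idx))
    }
    for (f in seq_len(k)) {
      # Training indices are all samples outside the held-out fold.
      folds[[sprintf("Fold%02d.Rep%d", f, r)]] <- which(fold_id != f)
    }
  }
  folds
}

# Toy target: 30 cases and 70 controls.
y <- factor(rep(c("case", "control"), times = c(30, 70)))
folds <- make_repeated_stratified_folds(y, k = 5, n_rep = 2, seed = 42)
length(folds)  # k * n_rep list elements
```

Each list element holds the row indices used for training in one resample; the remaining rows form the held-out fold.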
Usage
compute_features.training.ML(
  features_train,
  target_var,
  trait.positive,
  metric = "Accuracy",
  stack,
  k_folds = 10,
  n_rep = 5,
  feature.selection = FALSE,
  seed,
  LODO = FALSE,
  n_boruta = 100,
  boruta_fix = FALSE,
  batch_id = NULL,
  file_name = NULL,
  ncores = NULL,
  return = FALSE,
  fold_construction_fun = NULL,
  fold_construction_args = list()
)
Arguments
- features_train
A data frame containing the features used for training.
- target_var
A vector containing the target variable to predict.
- trait.positive
Value in target_var to be considered as the positive class.
- metric
Character. Metric used for hyperparameter tuning and model selection. Supported values are "Accuracy", "AUROC", and "AUPRC".
- stack
Logical. Whether to perform model stacking. Default is FALSE.
- k_folds
Integer. Number of folds to use in cross-validation.
- n_rep
Integer. Number of repetitions of the cross-validation.
- feature.selection
Logical. Whether to apply Boruta feature selection before model training. Default is FALSE.
- seed
Integer. Random seed for reproducibility.
- LODO
Logical. If TRUE, constructs folds stratified by cohorts (Leave-One-Dataset-Out CV).
- n_boruta
Integer. Number of iterations to run Boruta. Since Boruta involves randomness, repeated runs improve consistency. Default is 100.
- boruta_fix
Logical. Whether to fix Boruta’s internal parameters. See compute_boruta() for details.
- batch_id
A vector indicating the cohort or batch for each sample (required only if LODO = TRUE).
- file_name
Character. File name used to save plots in the Results/ directory.
- ncores
Integer. Number of cores to use for parallelization. If NULL (the default), detectCores() - 1 is used.
- return
Logical. Whether to return and save the plots generated by the function.
- fold_construction_fun
Function. A custom function used to construct the cross-validation folds. It should return a list of training indices for each fold.
- fold_construction_args
List. Named list of additional arguments to pass to fold_construction_fun.
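As an illustration of the return shape described for fold_construction_fun (a list of training indices, one element per fold), the sketch below builds leave-one-cohort-out folds in base R. The function name leave_one_cohort_out and its single batch_id argument are hypothetical, and this page does not say which arguments the package supplies to the custom function on its own, so the calling convention shown here is an assumption.

```r
# Hypothetical custom fold constructor (illustrative, not from the package).
# Returns one list element per cohort: the row indices used for TRAINING
# when that cohort is held out.
leave_one_cohort_out <- function(batch_id) {
  batch_id <- as.character(batch_id)
  cohorts <- unique(batch_id)
  folds <- lapply(cohorts, function(held_out) which(batch_id != held_out))
  names(folds) <- paste0("held_out_", cohorts)
  folds
}

# Toy cohort labels: 4 samples from A, 3 from B, 5 from C.
batch_id <- rep(c("A", "B", "C"), times = c(4, 3, 5))

# do.call() mimics forwarding a named argument list such as
# fold_construction_args to the custom function.
folds <- do.call(leave_one_cohort_out, list(batch_id = batch_id))
str(folds)
```

Under that assumption, such a function would be supplied as fold_construction_fun = leave_one_cohort_out together with fold_construction_args = list(batch_id = batch_id).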