
Train machine learning or survival models with optional stacking and custom cross-validation
Source: R/machine_learning.R
compute_features.training.ML.Rd
This function trains one or more machine learning models using repeated k-fold cross-validation, with optional model stacking, feature selection, and support for both classification and survival tasks. It allows flexible cross-validation schemes, including:
Standard stratified k-fold cross-validation
Leave-One-Dataset-Out (LODO) stratified folds by cohort
User-defined custom fold construction via fold_construction_fun
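The first two schemes can be illustrated with a minimal base-R sketch. It shows the generic techniques (stratified fold assignment and holding out one cohort at a time) and is not taken from the package's internals. The toy labels, cohort names, and the helper `stratified_folds` are all hypothetical.

```r
# Sketch only: generic fold assignment in base R, not the package's own code.
set.seed(42)

# Hypothetical toy labels and cohort identifiers.
y <- factor(rep(c("responder", "non_responder"), times = c(30, 70)))
cohort <- rep(c("cohort_A", "cohort_B", "cohort_C", "cohort_D"), each = 25)

# Stratified k-fold: shuffle the samples within each class, then deal them
# out to folds 1..k so that every fold keeps roughly the original class ratio.
stratified_folds <- function(y, k) {
  fold_id <- integer(length(y))
  for (cls in levels(y)) {
    idx <- which(y == cls)
    idx <- idx[sample.int(length(idx))]
    fold_id[idx] <- rep_len(seq_len(k), length(idx))
  }
  fold_id
}

# Repeated CV: draw a fresh stratified assignment for each repetition.
repeated_folds <- lapply(seq_len(3), function(r) stratified_folds(y, k = 5))

# Leave-One-Dataset-Out: each cohort is held out once as the test set.
lodo_folds <- lapply(split(seq_along(cohort), cohort), function(test_idx) {
  list(train = setdiff(seq_along(cohort), test_idx), test = test_idx)
})

# Class counts per fold for the first repetition.
table(fold = repeated_folds[[1]], class = y)
```

Each repetition reshuffles the samples, so averaging a metric over repetitions reduces the dependence on any single split.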
Usage
compute_features.training.ML(
  features_train,
  task_type = c("classification", "survival"),
  target_var = NULL,
  trait.positive = NULL,
  time_var = NULL,
  event_var = NULL,
  metric = "Accuracy",
  stack,
  k_folds = 10,
  n_rep = 5,
  LODO = FALSE,
  batch_id = NULL,
  file_name = NULL,
  ncores = NULL,
  return = FALSE,
  fold_construction_fun = NULL,
  fold_construction_args_fixed = NULL,
  fold_construction_args_tunable = NULL
)
Arguments
- features_train
A data frame containing the features used for training (samples in rows, features in columns).
- task_type
Character. Specifies the type of prediction task. Either "classification" or "survival".
- target_var
Vector. The target variable to predict (required for classification tasks).
- trait.positive
Value in target_var that represents the positive class (used for metrics such as AUROC and AUPRC).
- time_var
Character. The name of the survival time variable (required for survival models).
- event_var
Character. The name of the event indicator variable (required for survival models; 1 = event occurred, 0 = censored).
- metric
Character. Performance metric used for model selection and tuning. Supported values are:
"Accuracy": classification accuracy
"AUROC": area under the ROC curve
"AUPRC": area under the precision-recall curve
"C-index": concordance index (for survival tasks)
- stack
Logical. Whether to perform model stacking (ensemble meta-learning). Default is FALSE.
- k_folds
Integer. Number of folds to use for cross-validation.
- n_rep
Integer. Number of repetitions for cross-validation (repeated CV).
- LODO
Logical. If TRUE, constructs cross-validation folds stratified by cohort (Leave-One-Dataset-Out scheme).
- batch_id
Character. Column name indicating cohort or batch membership for each sample. Required if LODO = TRUE.
- file_name
Character. File name used for saving performance plots in the "Results/" directory.
- ncores
Integer. Number of CPU cores to use for parallelization. Defaults to parallel::detectCores() - 1.
- return
Logical. Whether to return and save the generated plots. Default is FALSE.
- fold_construction_fun
Function. Optional user-defined function to construct cross-validation folds. This enables full control over how data splits and feature transformations are created. The function must accept a bestune argument:
If bestune = NULL, the function explores a parameter grid across folds (executed in parallel via foreach).
If bestune is provided, the optimized parameters are applied to the full dataset to rebuild features before final training.
The fold constructor should save individual folds as "Results/fold_*.rds" objects containing:
train_data: training data for that fold
test_data: testing data for that fold
obs_test: observed target or survival outcomes
params: parameter combination used (if applicable)
- fold_construction_args_fixed
List. Arguments passed to fold_construction_fun that remain fixed during both cross-validation and final training (e.g., annotation files, normalization flags, etc.).
- fold_construction_args_tunable
List. Arguments passed to fold_construction_fun that define hyperparameters to tune during cross-validation. Each element should contain one or more candidate values.
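A custom fold constructor following the contract described for fold_construction_fun might look like the sketch below. Only the `bestune` argument and the saved fields (`train_data`, `test_data`, `obs_test`, `params`) come from the description above. The function name, its other arguments, the variance-based feature selection, and the way the fixed and tunable lists are combined in the call are assumptions made for illustration. The sketch writes to a temporary directory rather than "Results/" and loops sequentially instead of using foreach, to stay self-contained.

```r
# Hypothetical helper: names of the n_keep highest-variance columns.
select_top_variance <- function(x, n_keep) {
  v <- vapply(x, stats::var, numeric(1))
  names(sort(v, decreasing = TRUE))[seq_len(n_keep)]
}

# Hypothetical fold constructor; only `bestune` and the saved field names
# are taken from the documentation, everything else is an assumption.
my_fold_constructor <- function(data, target, k, top_n, bestune = NULL,
                                out_dir = file.path(tempdir(), "Results")) {
  if (!is.null(bestune)) {
    # Final-training mode: apply the selected parameters to the full dataset.
    cols <- select_top_variance(data, bestune$top_n)
    return(data[, cols, drop = FALSE])
  }
  # Tuning mode: save one fold object per (parameter combination, fold).
  dir.create(out_dir, showWarnings = FALSE, recursive = TRUE)
  grid <- expand.grid(top_n = top_n)
  fold_id <- sample(rep_len(seq_len(k), nrow(data)))
  paths <- character(0)
  for (g in seq_len(nrow(grid))) {
    for (f in seq_len(k)) {
      is_test <- fold_id == f
      # Select features on the training rows only to avoid leakage.
      cols <- select_top_variance(data[!is_test, , drop = FALSE], grid$top_n[g])
      fold <- list(
        train_data = data[!is_test, cols, drop = FALSE],
        test_data  = data[is_test, cols, drop = FALSE],
        obs_test   = target[is_test],
        params     = grid[g, , drop = FALSE]
      )
      path <- file.path(out_dir, sprintf("fold_%d.rds", length(paths) + 1))
      saveRDS(fold, path)
      paths <- c(paths, path)
    }
  }
  invisible(paths)
}

set.seed(1)
toy <- as.data.frame(matrix(rnorm(30 * 6), nrow = 30))
y <- factor(sample(c("yes", "no"), 30, replace = TRUE))
args_fixed <- list(data = toy, target = y, k = 3)   # unchanged in both modes
args_tunable <- list(top_n = c(2, 4))               # candidate values to tune

fold_paths <- do.call(my_fold_constructor,
                      c(args_fixed, args_tunable, list(bestune = NULL)))
final_features <- do.call(my_fold_constructor,
                          c(args_fixed, args_tunable,
                            list(bestune = list(top_n = 4))))
```

With two candidate values and three folds, the tuning call writes six fold files. The second call returns the full dataset restricted to the four selected columns.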
Value
A list containing:
Trained model(s) or meta-learner (if stack = TRUE)
Feature set used for model training
Cross-validation performance results and plots
Best hyperparameter configuration (if applicable)
Details
The function supports both classification and survival analysis pipelines by setting
task_type = "classification" or task_type = "survival".
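As a rough illustration, the two task types could be set up as below. The argument names follow the Usage section. The toy data, and the assumption that the columns named by time_var and event_var live inside features_train, are hypothetical. The calls run only if compute_features.training.ML is already available in the session.

```r
set.seed(1)
n <- 60
features <- as.data.frame(matrix(rnorm(n * 5), nrow = n))   # hypothetical toy data
labels <- factor(sample(c("R", "NR"), n, replace = TRUE))

# Classification: target_var is a vector, trait.positive names the positive class.
classification_args <- list(
  features_train = features,
  task_type      = "classification",
  target_var     = labels,
  trait.positive = "R",
  metric         = "AUROC",
  stack          = FALSE,
  k_folds        = 5,
  n_rep          = 2
)

# Survival: time_var and event_var are column names
# (assumed here to be columns of features_train).
survival_data <- cbind(features, time = rexp(n), event = rbinom(n, 1, 0.6))
survival_args <- list(
  features_train = survival_data,
  task_type      = "survival",
  time_var       = "time",
  event_var      = "event",
  metric         = "C-index",
  stack          = FALSE,
  k_folds        = 5,
  n_rep          = 2
)

# Run only when the function is available (package attached by the user).
if (exists("compute_features.training.ML", mode = "function")) {
  fit_cls  <- do.call(compute_features.training.ML, classification_args)
  fit_surv <- do.call(compute_features.training.ML, survival_args)
}
```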
It also provides:
Automatic feature preprocessing (e.g., correlation filtering, low-variance removal).
Parallelized cross-validation across folds and repetitions.
Integration with custom model pipelines (e.g., CellTFusion, pathway-based deconvolution).
Unified handling of both survival and classification models.
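The two preprocessing steps named in the list above (low-variance removal and correlation filtering) can be sketched in base R as follows. This shows the general idea only. The thresholds, the toy data, and the helper `drop_correlated` are hypothetical, and the package's actual filtering rules may differ.

```r
set.seed(7)
x <- data.frame(a = rnorm(50), b = rnorm(50), constant = rep(1, 50))
x$a_copy <- x$a + rnorm(50, sd = 0.01)   # nearly duplicates column `a`

# Low-variance removal: drop columns whose variance is (close to) zero.
# Doing this first also avoids NA correlations from constant columns.
variances <- vapply(x, stats::var, numeric(1))
x_var <- x[, variances > 1e-8, drop = FALSE]

# Correlation filtering (greedy): drop a column when its absolute correlation
# with any earlier column exceeds the cutoff.
drop_correlated <- function(df, cutoff = 0.9) {
  cor_abs <- abs(stats::cor(df))
  cor_abs[lower.tri(cor_abs, diag = TRUE)] <- 0   # keep only pairs with i < j
  redundant <- apply(cor_abs, 2, function(col) any(col > cutoff))
  df[, !redundant, drop = FALSE]
}

x_filtered <- drop_correlated(x_var)
names(x_filtered)
```

In a cross-validation setting, filters like these are usually fitted on the training fold only and then applied to the test fold, so that no information from held-out samples leaks into feature selection.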
When a custom fold constructor is provided via fold_construction_fun, the default stratified k-fold
logic is bypassed, and the function will instead iterate through all Results/fold_*.rds files generated
by the custom routine. This allows hybrid pipelines combining biological preprocessing (e.g., CellTFusion)
with downstream model fitting.
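The iteration over saved fold files can be pictured with the following sketch. It writes two toy fold objects with the documented fields to a temporary directory, then reads every fold_*.rds file back and scores a stand-in model on each. The linear model, the RMSE metric, and the directory name are assumptions for illustration and do not reflect the models the function actually fits.

```r
results_dir <- file.path(tempdir(), "Results_demo")   # stand-in for "Results/"
dir.create(results_dir, showWarnings = FALSE)

# Write two toy fold objects with the documented fields.
set.seed(3)
for (i in 1:2) {
  test_rows <- if (i == 1) 1:5 else 6:10
  d <- data.frame(x = rnorm(10), y = rnorm(10))
  fold <- list(
    train_data = d[-test_rows, ],
    test_data  = d[test_rows, ],
    obs_test   = d$y[test_rows],
    params     = data.frame(alpha = 0.5)   # hypothetical tuning parameter
  )
  saveRDS(fold, file.path(results_dir, sprintf("fold_%d.rds", i)))
}

# Iterate over every saved fold and evaluate a stand-in model on it.
fold_files <- list.files(results_dir, pattern = "^fold_.*\\.rds$",
                         full.names = TRUE)
fold_rmse <- vapply(fold_files, function(path) {
  fold <- readRDS(path)
  fit <- stats::lm(y ~ x, data = fold$train_data)
  pred <- stats::predict(fit, newdata = fold$test_data)
  sqrt(mean((pred - fold$obs_test)^2))
}, numeric(1))

mean(fold_rmse)   # average held-out error across folds
```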