This function trains one or more machine learning models using repeated k-fold cross-validation, with optional model stacking, feature selection, and support for both classification and survival tasks. It allows flexible cross-validation schemes, including:

  • Standard stratified k-fold cross-validation

  • Leave-One-Dataset-Out (LODO) stratified folds by cohort

  • Custom fold construction via a user-supplied fold_construction_fun

Usage

compute_features.training.ML(
  features_train,
  task_type = c("classification", "survival"),
  target_var = NULL,
  trait.positive = NULL,
  time_var = NULL,
  event_var = NULL,
  metric = "Accuracy",
  stack,
  k_folds = 10,
  n_rep = 5,
  LODO = FALSE,
  batch_id = NULL,
  file_name = NULL,
  ncores = NULL,
  return = FALSE,
  fold_construction_fun = NULL,
  fold_construction_args_fixed = NULL,
  fold_construction_args_tunable = NULL
)
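A minimal sketch of a classification call, assuming the package that provides compute_features.training.ML is already attached. The toy data, the seed, and every argument value below are illustrative assumptions rather than recommended settings. The call is guarded so the snippet also runs in a session where the function is not available.

```r
# Toy data: 60 samples (rows) x 5 features (columns); purely illustrative.
set.seed(1)
features_train <- as.data.frame(matrix(rnorm(60 * 5), nrow = 60))
colnames(features_train) <- paste0("feature_", 1:5)

# Binary target, one value per sample, in the same order as the rows above.
target <- factor(rep(c("responder", "non_responder"), each = 30))

# Collect the documented arguments in one list.
args <- list(
  features_train = features_train,
  task_type      = "classification",
  target_var     = target,
  trait.positive = "responder",
  metric         = "AUROC",
  stack          = FALSE,
  k_folds        = 5,
  n_rep          = 2,
  ncores         = 2
)

# Only run the call if the function exists in this session.
if (exists("compute_features.training.ML", mode = "function")) {
  res <- do.call(compute_features.training.ML, args)
}
```

Passing stack explicitly is deliberate here, since the usage signature lists it without a default value.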

Arguments

features_train

A data frame containing the features used for training (samples in rows, features in columns).

task_type

Character. Specifies the type of prediction task. Either "classification" or "survival".

target_var

Vector. The target variable to predict (required for classification tasks).

trait.positive

Value in target_var that represents the positive class (used for metrics like AUROC and AUPRC).

time_var

Character. The name of the survival time variable (required for survival models).

event_var

Character. The name of the event indicator variable (required for survival models; 1 = event occurred, 0 = censored).

metric

Character. Performance metric used for model selection and tuning. Supported values are:

  • "Accuracy" — classification accuracy

  • "AUROC" — area under the ROC curve

  • "AUPRC" — area under the precision-recall curve

  • "C-index" — concordance index (for survival tasks)
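For orientation, two of the classification metrics can be written in a few lines of base R. This is a sketch of their standard definitions, not the package's internal code. The vectors y_true and y_prob and the 0.5 threshold are illustrative assumptions.

```r
# Illustrative labels and predicted probabilities for the positive class.
y_true <- c("yes", "yes", "no", "no", "yes", "no")
y_prob <- c(0.9, 0.6, 0.4, 0.7, 0.8, 0.2)
positive <- "yes"  # plays the role of trait.positive

# Accuracy: share of samples whose thresholded prediction matches the label.
y_pred <- ifelse(y_prob >= 0.5, positive, "no")
accuracy <- mean(y_pred == y_true)

# AUROC via the rank-sum (Mann-Whitney) identity: the probability that a
# random positive sample is scored higher than a random negative one.
is_pos <- y_true == positive
n_pos <- sum(is_pos)
n_neg <- sum(!is_pos)
auroc <- (sum(rank(y_prob)[is_pos]) - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)
```

Accuracy depends on the chosen threshold, whereas AUROC depends only on how the samples are ranked, which is why trait.positive matters for the latter.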

stack

Logical. Whether to perform model stacking (ensemble meta-learning). Default is FALSE.

k_folds

Integer. Number of folds to use for cross-validation.

n_rep

Integer. Number of repetitions for cross-validation (repeated CV).
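To make the interplay of k_folds and n_rep concrete, the following base R sketch builds repeated stratified folds. It illustrates the general technique only. The helper stratified_folds is a hypothetical name, and the package may construct its folds differently.

```r
# Assign each sample to one of k folds, balancing classes within folds.
stratified_folds <- function(y, k) {
  fold_id <- integer(length(y))
  for (cls in unique(y)) {
    idx <- which(y == cls)
    # Shuffle the class members, then deal them out to folds 1..k in turn.
    shuffled <- idx[sample.int(length(idx))]
    fold_id[shuffled] <- rep_len(seq_len(k), length(idx))
  }
  fold_id
}

set.seed(42)
y <- rep(c("case", "control"), times = c(20, 30))
k_folds <- 5
n_rep <- 3

# One fold assignment per repetition: a list of n_rep integer vectors.
repeats <- lapply(seq_len(n_rep), function(r) stratified_folds(y, k_folds))

# Test indices for repetition 1, fold 1; the remaining samples form the training set.
test_idx <- which(repeats[[1]] == 1)
```

With these settings every model is evaluated on k_folds * n_rep = 15 train/test splits, and each fold keeps the 20:30 class ratio.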

LODO

Logical. If TRUE, constructs cross-validation folds stratified by cohort (Leave-One-Dataset-Out scheme).

batch_id

Character. Column name indicating cohort or batch membership for each sample. Required if LODO = TRUE.
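The idea behind LODO can be sketched as follows: each cohort is held out once while the model is trained on the others. This is plain leave-one-cohort-out splitting. How the function additionally stratifies within cohorts is not described here, so that part is omitted. The data frame and its column names are illustrative assumptions.

```r
# Illustrative sample sheet; "cohort" plays the role of the batch_id column.
samples <- data.frame(
  sample_id = paste0("s", 1:9),
  cohort    = rep(c("A", "B", "C"), times = c(4, 3, 2))
)

# One fold per cohort: test on that cohort, train on all the others.
all_idx <- seq_len(nrow(samples))
lodo_folds <- lapply(split(all_idx, samples$cohort), function(test_idx) {
  list(train = setdiff(all_idx, test_idx), test = test_idx)
})
```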

file_name

Character. File name used for saving performance plots in the "Results/" directory.

ncores

Integer. Number of CPU cores to use for parallelization. If NULL (the default), parallel::detectCores() - 1 cores are used.

return

Logical. Whether to return and save the generated plots. Default is FALSE.

fold_construction_fun

Function. Optional user-defined function to construct cross-validation folds. This enables full control over how data splits and feature transformations are created. The function must accept a bestune argument:

  • If bestune = NULL, the function explores a parameter grid across folds (executed in parallel via foreach).

  • If bestune is provided, optimized parameters are applied to the full dataset to rebuild features before final training.

The fold constructor should save individual folds as "Results/fold_*.rds" objects containing:

  • train_data — training data for that fold

  • test_data — testing data for that fold

  • obs_test — observed target or survival outcomes

  • params — parameter combination used (if applicable)
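A fold constructor following this contract might look like the sketch below. Apart from bestune and the four saved components, everything is an assumption made for illustration: the function name, the arguments features, target, n_top, k, and out_dir, the variance-based feature selection, and the choice to append the outcome to train_data as a target column. The loop is sequential for brevity, and the demonstration writes to a temporary directory instead of "Results/".

```r
# Hypothetical fold constructor; all names except `bestune` are assumptions.
my_fold_constructor <- function(features, target, n_top = c(2, 3), k = 3,
                                bestune = NULL, out_dir = "Results") {
  # Keep the n features with the highest variance (a stand-in transformation).
  top_features <- function(x, n) {
    v <- vapply(x, var, numeric(1))
    x[, names(sort(v, decreasing = TRUE))[seq_len(n)], drop = FALSE]
  }

  # Final-training mode: apply the selected parameters to the full dataset.
  if (!is.null(bestune)) {
    return(top_features(features, bestune$n_top))
  }

  # Exploration mode: save one fold per (parameter value, fold) pair.
  dir.create(out_dir, showWarnings = FALSE, recursive = TRUE)
  fold_id <- sample(rep_len(seq_len(k), nrow(features)))
  counter <- 0
  for (n in n_top) {
    for (f in seq_len(k)) {
      counter <- counter + 1
      test <- fold_id == f
      # Select features on the training rows only, to avoid information leakage.
      keep <- colnames(top_features(features[!test, , drop = FALSE], n))
      saveRDS(
        list(
          # Assumption: the outcome travels with the training data.
          train_data = cbind(features[!test, keep, drop = FALSE],
                             target = target[!test]),
          test_data  = features[test, keep, drop = FALSE],
          obs_test   = target[test],
          params     = data.frame(n_top = n)
        ),
        file.path(out_dir, paste0("fold_", counter, ".rds"))
      )
    }
  }
  invisible(counter)
}

# Demonstration on toy data.
set.seed(1)
toy <- as.data.frame(matrix(rnorm(30 * 4), nrow = 30))
labels <- rep(c("a", "b"), 15)
demo_dir <- file.path(tempdir(), "Results_constructor")
my_fold_constructor(toy, labels, out_dir = demo_dir)                # writes 6 folds
full <- my_fold_constructor(toy, labels, bestune = list(n_top = 2)) # rebuilds features
```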

fold_construction_args_fixed

List. Arguments passed to fold_construction_fun that remain fixed during both cross-validation and final training (e.g., annotation files, normalization flags, etc.).

fold_construction_args_tunable

List. Arguments passed to fold_construction_fun that define hyperparameters to tune during cross-validation. Each element should contain one or more candidate values.
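The two argument lists could be laid out as in this sketch. The element names (normalize, n_top, threshold) are hypothetical, and expanding the tunable list with expand.grid is an assumption about how a parameter grid could be formed, not a statement about the function's internals.

```r
# Fixed arguments: passed unchanged to every call (illustrative name).
fold_construction_args_fixed <- list(normalize = TRUE)

# Tunable arguments: each element lists its candidate values.
fold_construction_args_tunable <- list(
  n_top     = c(50, 100, 200),
  threshold = c(0.7, 0.9)
)

# Every combination of candidate values, one row per configuration.
grid <- expand.grid(fold_construction_args_tunable, stringsAsFactors = FALSE)
```

Because the number of configurations is the product of the candidate counts (3 x 2 = 6 here), long candidate lists quickly multiply the number of folds to build.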

Value

A list containing:

  • Trained model(s) or meta-learner (if stack = TRUE)

  • Feature set used for model training

  • Cross-validation performance results and plots

  • Best hyperparameter configuration (if applicable)

Details

Set task_type = "classification" or task_type = "survival" to select the classification or the survival analysis pipeline.

The function also provides:

  • Automatic feature preprocessing (e.g., correlation filtering, low-variance removal).

  • Parallelized cross-validation across folds and repetitions.

  • Integration with custom model pipelines (e.g., CellTFusion, pathway-based deconvolution).

  • Unified handling of both survival and classification models.
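The two preprocessing steps named above, low-variance removal and correlation filtering, can be sketched in base R as follows. The thresholds (1e-8 and 0.9), the greedy keep-first strategy, and the toy columns are assumptions for illustration. The function's actual cutoffs and procedure are not documented here.

```r
set.seed(7)
x <- data.frame(a = rnorm(50), b = rnorm(50), constant = rep(1, 50))
x$a_copy <- x$a + rnorm(50, sd = 0.01)  # nearly duplicates column `a`

# 1. Low-variance removal: drop columns whose variance is (almost) zero.
variances <- vapply(x, var, numeric(1))
x_var <- x[, variances > 1e-8, drop = FALSE]

# 2. Correlation filtering: walk the columns in order and drop any column
#    that correlates above the cutoff with a column already kept.
cutoff <- 0.9
cor_mat <- abs(cor(x_var))
keep <- character(0)
for (col in colnames(x_var)) {
  if (all(cor_mat[col, keep] <= cutoff)) keep <- c(keep, col)
}
x_filtered <- x_var[, keep, drop = FALSE]
```

Removing constant columns first matters because their correlation with other columns is undefined.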

When a custom fold constructor is provided via fold_construction_fun, the default stratified k-fold logic is bypassed, and the function will instead iterate through all Results/fold_*.rds files generated by the custom routine. This allows hybrid pipelines combining biological preprocessing (e.g., CellTFusion) with downstream model fitting.
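Iterating over saved fold files can be sketched as below. The toy folds and the temporary directory stand in for the files a custom constructor would write to "Results/". Only the file naming pattern and the four component names come from the description above.

```r
# Write two toy folds to a temporary directory standing in for "Results/".
results_dir <- file.path(tempdir(), "Results_demo")
dir.create(results_dir, showWarnings = FALSE)
for (i in 1:2) {
  saveRDS(
    list(train_data = data.frame(x = rnorm(8)),
         test_data  = data.frame(x = rnorm(2)),
         obs_test   = c("a", "b"),
         params     = data.frame(n_top = i)),
    file.path(results_dir, paste0("fold_", i, ".rds"))
  )
}

# Collect every fold file and load it.
fold_files <- list.files(results_dir, pattern = "^fold_.*\\.rds$",
                         full.names = TRUE)
folds <- lapply(fold_files, readRDS)

# Check that each fold exposes the four documented components.
for (fold in folds) {
  stopifnot(all(c("train_data", "test_data", "obs_test", "params") %in% names(fold)))
}
```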