
Train model with optimized hyperparameters for classification tasks
Source: R/machine_learning.R
wrapper_train_best_hyperparams_classification.Rd
This function wraps cross-validation, hyperparameter optimization, and final training into a single workflow. It identifies the best hyperparameters using a custom cross-validation function, reconstructs the training set, preprocesses features, and retrains the model on the complete training data with the selected hyperparameters.
Usage
wrapper_train_best_hyperparams_classification(
train_data,
optimized,
ml_method,
fold_construction_fun,
fold_construction_args_fixed
)
Arguments
- train_data
A data frame containing the full training dataset, including predictors and the target variable.
- ml_method
A character string specifying the machine learning method to be passed to caret::train.
- fold_construction_fun
A function used to (re)construct training data partitions given the best hyperparameters.
- fold_construction_args_fixed
A named list of additional fixed arguments to pass to fold_construction_fun.
- fold_data
A list or object containing pre-constructed folds for cross-validation, typically created by fold_construction_fun.
- tuneGrid
(optional) A data frame of hyperparameter values to evaluate. If NULL, defaults are used.
- ncores
(optional) Integer specifying the number of cores for parallel processing during cross-validation. If NULL, defaults to serial execution.
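As a minimal sketch of how tuneGrid and fold_construction_args_fixed might be prepared, the following builds a tuning grid with expand.grid and shows one common way to combine a fixed argument list with a function through do.call. The function toy_fold_fun and its arguments data, k, and seed are hypothetical illustrations; the signature that fold_construction_fun must actually follow is not documented on this page, and the wrapper may call it differently.

```r
# Tuning grid: one row per hyperparameter combination.
# C and sigma are the tuning parameters of caret's "svmRadial" method.
tune_grid <- expand.grid(C = c(0.1, 1, 10), sigma = c(0.01, 0.1))
nrow(tune_grid)  # 3 values of C x 2 values of sigma

# Hypothetical fold constructor (NOT the package's interface):
# randomly assigns each row index to one of k folds.
toy_fold_fun <- function(data, k, seed) {
  set.seed(seed)
  fold_id <- sample(rep(seq_len(k), length.out = nrow(data)))
  split(seq_len(nrow(data)), fold_id)
}

# Fixed arguments are kept in a named list and spliced in with do.call.
fixed_args <- list(k = 5, seed = 1)
folds <- do.call(toy_fold_fun, c(list(data = iris), fixed_args))
lengths(folds)   # number of row indices held in each fold
```

Keeping the fixed arguments in a named list means the same list can be reused unchanged whenever the fold constructor is invoked, while the varying arguments are supplied separately.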
Value
A list with the following components:
- Model
A trained caret model object with results, predictions, and resampling info attached.
- training_set
The final preprocessed training dataset.
- custom_output
Additional output from fold_construction_fun.
Details
The workflow proceeds in the following steps:
- Runs cross-validation using compute_custom_k_fold_CV to identify the best hyperparameter set.
- Reconstructs the training set using fold_construction_fun and the selected hyperparameters.
- Preprocesses the training features by removing near-zero variance, highly correlated, and constant-within-class features (via preprocess_features).
- Retrains the model on the complete training data with the optimized hyperparameters.
The returned object mimics a caret model object but includes
additional elements derived from the custom cross-validation.
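To make the preprocessing step more concrete, here is a rough base-R sketch of the three kinds of filtering named above. It is not the package's preprocess_features implementation, whose thresholds and exact rules are not documented here. The sketch assumes a strict zero-variance check in place of a near-zero-variance test, a correlation cutoff of 0.95, and reads "constant within class" as a feature taking a single value within at least one class. A caret-based version could use caret::nearZeroVar and caret::findCorrelation for the first two filters.

```r
# Toy data: a class label plus four numeric features.
set.seed(42)
n <- 40
toy <- data.frame(
  class = factor(rep(c("a", "b"), each = n / 2)),
  x1 = rnorm(n),
  x2 = 1                                      # zero variance
)
toy$x3 <- toy$x1 * 2 + rnorm(n, sd = 0.001)   # almost collinear with x1
toy$x4 <- c(rep(0, n / 2), rnorm(n / 2))      # constant within class "a"

features <- setdiff(names(toy), "class")

# Filter 1: drop zero-variance features (simplified stand-in for
# a near-zero-variance filter).
keep <- features[vapply(toy[features], function(v) var(v) > 0, logical(1))]

# Filter 2: for each highly correlated pair, drop the later feature.
# The 0.95 cutoff is an assumption for this sketch.
cors <- abs(cor(toy[keep]))
cors[upper.tri(cors, diag = TRUE)] <- 0
keep <- keep[!apply(cors > 0.95, 1, any)]

# Filter 3: drop features that take a single value within any class.
const_within <- vapply(toy[keep], function(v) {
  any(tapply(v, toy$class, function(g) length(unique(g)) == 1))
}, logical(1))
keep <- keep[!const_within]

keep  # names of the features that survive all three filters
```

In this toy data, x2 is removed by the first filter, x3 by the second, and x4 by the third, so only x1 remains.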
Examples
if (FALSE) { # \dontrun{
  library(caret)
  # Example placeholders
  train_data <- your_training_data
  fold_data <- your_prepared_folds
  result <- wrapper_train_best_hyperparams_classification(
    train_data = train_data,
    fold_data = fold_data,
    ml_method = "rf",
    fold_construction_fun = your_fold_fun,
    fold_construction_args_fixed = list(arg1 = "value"),
    tuneGrid = expand.grid(mtry = 2:4),
    ncores = 4
  )
  result$Model
} # }