
Compute SHAP Values for Machine Learning Models
This function calculates SHAP (SHapley Additive exPlanations) values to assess feature importance for a trained machine learning model. It supports classification and survival tasks and computes SHAP values on each cross-validation resample in parallel. The per-resample results are summarized into a single set of SHAP values, and a SHAP stability plot can optionally be saved.
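For intuition, the Shapley value of feature j averages the change in model output when j is added to every coalition S of the other features, weighted by |S|! (p - |S| - 1)! / p!. The sketch below enumerates every coalition exactly for a small toy model, filling "absent" features from a baseline vector. This baseline-Shapley variant and the helper name `shapley_exact()` are illustrative assumptions; they are not how compute_shap_values() or fastshap::explain() work internally (fastshap approximates Shapley values by Monte Carlo sampling for general models).

```r
# Exact Shapley values by enumerating every coalition of features.
# Hypothetical helper for illustration: "absent" features take their
# value from `baseline`, "present" features take their value from `x`.
shapley_exact <- function(f, x, baseline) {
  p <- length(x)
  phi <- numeric(p)
  for (j in seq_len(p)) {
    others <- setdiff(seq_len(p), j)
    for (k in 0:length(others)) {
      # All coalitions S of size k drawn from the other features
      subsets <- if (k == 0) {
        list(integer(0))
      } else {
        lapply(combn(length(others), k, simplify = FALSE), function(ix) others[ix])
      }
      # Shapley weight |S|! (p - |S| - 1)! / p!
      w <- factorial(k) * factorial(p - k - 1) / factorial(p)
      for (S in subsets) {
        z_with <- baseline
        z_with[c(S, j)] <- x[c(S, j)]
        z_without <- baseline
        z_without[S] <- x[S]
        phi[j] <- phi[j] + w * (f(z_with) - f(z_without))
      }
    }
  }
  phi
}

# Toy model with an interaction term
f <- function(z) 2 * z[1] + z[1] * z[2]
phi <- shapley_exact(f, x = c(1, 3), baseline = c(0, 0))
phi       # 3.5 1.5
sum(phi)  # equals f(x) - f(baseline) = 5
```

The additivity check in the last line is the property that gives SHAP its name: the per-feature contributions sum to the difference between the prediction and the baseline prediction.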
Usage
compute_shap_values(
  model_trained,
  data_train,
  task_type = "classification",
  target_col = NULL,
  trait.positive,
  time_col = NULL,
  event_col = NULL,
  n_cores = 2,
  file.name = NULL
)
Arguments
- model_trained
A trained machine learning model object (e.g., output from caret or a custom ML pipeline) that includes cross-validation resamples.
- data_train
A data frame containing the training data used for the model.
- task_type
Character. Either "classification" (default) or "survival".
- target_col
Character. Name of the target column for classification tasks. Required if task_type = "classification".
- trait.positive
Value representing the positive class in classification tasks.
- time_col
Character. Column name representing survival time. Required if task_type = "survival".
- event_col
Character. Column name representing the survival event indicator. Required if task_type = "survival".
- n_cores
Integer. Number of cores for parallel computation. Default is 2.
- file.name
Character. Optional filename prefix for saving SHAP stability plots. If NULL, plots are not saved.
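The task-dependent requirements above (target_col for classification; time_col and event_col for survival) could be checked up front along the following lines. `check_shap_args()` is a hypothetical helper written for illustration; it is not part of the documented API, and the package may validate its arguments differently.

```r
# Hypothetical up-front check of the task-dependent arguments
check_shap_args <- function(data_train,
                            task_type = c("classification", "survival"),
                            target_col = NULL, time_col = NULL, event_col = NULL) {
  task_type <- match.arg(task_type)
  # Arguments that must be supplied for this task type
  needed <- switch(task_type,
    classification = list(target_col = target_col),
    survival = list(time_col = time_col, event_col = event_col)
  )
  missing_args <- names(needed)[vapply(needed, is.null, logical(1))]
  if (length(missing_args) > 0) {
    stop("Missing argument(s) for task_type = '", task_type, "': ",
         paste(missing_args, collapse = ", "), call. = FALSE)
  }
  # The named columns must exist in the training data
  absent <- setdiff(unlist(needed), names(data_train))
  if (length(absent) > 0) {
    stop("Column(s) not found in data_train: ",
         paste(absent, collapse = ", "), call. = FALSE)
  }
  invisible(task_type)
}

d <- data.frame(time = 1:3, status = c(1, 0, 1), y = c("a", "b", "a"))
check_shap_args(d, "survival", time_col = "time", event_col = "status")
```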
Value
A data frame of SHAP values with one row per training sample and one column per feature. Values are summarized across cross-validation resamples (median per sample; see Details).
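Because the result has one row per training sample and one column per feature, a common follow-up is a global importance ranking by mean absolute SHAP value. The numbers below are made up and only stand in for the output of compute_shap_values(); mean |SHAP| is one conventional summary, not something this function is documented to compute.

```r
# Made-up SHAP values standing in for the output of compute_shap_values()
shap_df <- data.frame(
  age = c(0.20, -0.10, 0.05),
  bmi = c(-0.30, 0.40, 0.10)
)

# Global importance: mean absolute SHAP value per feature, largest first
global_importance <- sort(colMeans(abs(shap_df)), decreasing = TRUE)
global_importance
```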
Details
The function performs the following steps:
1. Sets up classification or survival prediction functions based on the task type.
2. Loops over all cross-validation resamples in parallel, refitting the model on each training fold.
3. Computes SHAP values with fastshap::explain() for each resample, skipping resamples with trivial predictions.
4. Combines SHAP values across resamples and summarizes them (median per sample).
5. Generates a SHAP stability plot and optionally saves it if file.name is provided.
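Steps 2 to 4 can be sketched in base R as below. The sketch rests on several simplifying assumptions that are not taken from the package: each resample is a vector of training-row indices, lm() on numeric features stands in for the refitted learner, SHAP values come from the exact linear-model formula beta_j * (x_ij - mean of feature j in the training fold) instead of fastshap::explain(), and every row of the training data is explained in each resample. The helper names are hypothetical.

```r
library(parallel)

# Exact SHAP values for a linear model with independent features:
# phi_ij = beta_j * (x_ij - mean of feature j in the background data)
linear_shap <- function(fit, newdata, background) {
  beta <- coef(fit)[-1]                       # drop the intercept
  feats <- names(beta)
  X <- as.matrix(newdata[, feats, drop = FALSE])
  mu <- colMeans(background[, feats, drop = FALSE])
  sweep(sweep(X, 2, mu), 2, beta, `*`)        # centre, then scale by beta
}

# Hypothetical resample loop: refit on each training fold, explain all rows,
# then take the per-sample, per-feature median across resamples.
shap_over_resamples <- function(data, target, resamples, n_cores = 1) {
  features <- setdiff(names(data), target)
  per_fold <- mclapply(resamples, function(idx) {
    fold <- data[idx, , drop = FALSE]
    fit <- lm(reformulate(features, response = target), data = fold)
    linear_shap(fit, data, background = fold)
  }, mc.cores = n_cores)                      # mc.cores > 1 is not supported on Windows
  arr <- simplify2array(per_fold)             # rows x features x resamples
  as.data.frame(apply(arr, c(1, 2), median))
}

set.seed(1)
d <- data.frame(x1 = rnorm(30), x2 = rnorm(30))
d$y <- 2 * d$x1 - d$x2 + rnorm(30, sd = 0.1)
folds <- list(1:20, 11:30, c(1:10, 21:30))   # overlapping toy "training folds"
shap_med <- shap_over_resamples(d, "y", folds)
head(shap_med)
```

Taking the median rather than the mean across resamples makes the summary less sensitive to a single resample whose refitted model behaves unusually.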
Resamples with trivial predictions (the same predicted probability for every sample) are skipped, and a warning is issued if SHAP values cannot be computed.
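A minimal sketch of this skip-and-warn behaviour, assuming "trivial" means that all predicted probabilities in a resample lie within a small tolerance of one another. The helper names and the tolerance are hypothetical; `compute_shap` is any zero-argument function that returns the SHAP values for one resample.

```r
# TRUE when all finite predictions are (numerically) the same value
is_trivial_prediction <- function(pred, tol = 1e-8) {
  pred <- pred[is.finite(pred)]
  length(pred) == 0 || diff(range(pred)) <= tol
}

# Return SHAP values for one resample, or NULL (with a warning) if the
# predictions are trivial or the SHAP computation itself fails
safe_resample_shap <- function(pred, compute_shap) {
  if (is_trivial_prediction(pred)) {
    warning("Constant predictions in this resample; SHAP values skipped.")
    return(NULL)
  }
  tryCatch(compute_shap(), error = function(e) {
    warning("SHAP values could not be computed: ", conditionMessage(e))
    NULL
  })
}

results <- list(
  suppressWarnings(safe_resample_shap(rep(0.5, 5), function() 1)),
  safe_resample_shap(c(0.2, 0.7, 0.9), function() 2)
)
# Drop skipped resamples before combining
results <- Filter(Negate(is.null), results)
length(results)  # 1
```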