shapfire.shapfire
This module contains the main implementation of the ShapFire method for feature ranking and selection.
- shapfire.shapfire.DEFAULT_REPEATS: int = 1
The default number of times the division of a dataset into folds is repeated in a cross-validation.
- shapfire.shapfire.DEFAULT_SPLITS: int = 2
The default number of folds a dataset is divided into in a cross-validation.
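To make the two defaults concrete: a repeated k-fold cross-validation divides the observations into `n_splits` folds, uses each fold once as the test set, and repeats the whole division `n_repeats` times with a fresh shuffle, giving `n_splits * n_repeats` train/test splits (2 × 1 = 2 with the defaults above). The sketch below is a minimal standard-library illustration of that idea under this assumption. The function `repeated_kfold_indices` is hypothetical and not part of ShapFire, which may use scikit-learn's `RepeatedKFold` or a different splitter internally.

```python
from __future__ import annotations

import random
from collections.abc import Iterator


def repeated_kfold_indices(
    n_obs: int, n_splits: int = 2, n_repeats: int = 1, seed: int = 0
) -> Iterator[tuple[list[int], list[int]]]:
    """Yield (train, test) index lists for a repeated k-fold split.

    Hypothetical helper for illustration only; not part of ShapFire.
    """
    if not 2 <= n_splits <= n_obs:
        raise ValueError("n_splits must be between 2 and n_obs")
    rng = random.Random(seed)
    indices = list(range(n_obs))
    base, extra = divmod(n_obs, n_splits)
    for _ in range(n_repeats):
        rng.shuffle(indices)  # a fresh division into folds on every repeat
        start = 0
        for k in range(n_splits):
            # Spread the remainder over the first folds so sizes differ by at most one.
            stop = start + base + (1 if k < extra else 0)
            test = sorted(indices[start:stop])
            train = sorted(indices[:start] + indices[stop:])
            yield train, test
            start = stop


splits = list(repeated_kfold_indices(n_obs=10, n_splits=2, n_repeats=1))
print(len(splits))  # 2
```

Within one repeat every observation lands in exactly one test fold; raising `n_repeats` only adds further, independently shuffled divisions.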
- class shapfire.ShapFire(estimator_class: LGBMClassifier | LGBMRegressor | RandomForestClassifier | RandomForestRegressor, scoring: str, estimator_params: None | dict[str, Any] = None, n_splits: int = DEFAULT_SPLITS, n_repeats: int = DEFAULT_REPEATS, random_seed: int = utils.DEFAULT_RANDOM_SEED, iterations: None | int = None, reference: str = 'mean', n_samples: int = 1000, n_batches: int = 250)
The main class used for applying SHAP feature importance rank ensembling for feature selection.
- Parameters:
- estimator_class: LGBMClassifier | LGBMRegressor | RandomForestClassifier | RandomForestRegressor
The scikit-learn or Microsoft LightGBM tree-based estimator class to use. The estimator can be either a classifier or a regressor.
- scoring: str
The specification of a scoring function to use for model evaluation, i.e., a function used to assess the prediction error of a trained model on a test set.
- estimator_params: None | dict[str, Any] = None
The estimator hyperparameters and the corresponding values to either search over or use directly. If only a single value is provided for each hyperparameter, then only cross-validation is performed and no hyperparameter search takes place. Defaults to None.
- n_splits: int = DEFAULT_SPLITS
The number of folds to generate in the outer loop of a nested cross-validation. Defaults to shapfire.shapfire.DEFAULT_SPLITS.
- n_repeats: int = DEFAULT_REPEATS
The number of times the outer loop of a nested cross-validation is repeated, each time with a new division of the dataset into folds. Defaults to shapfire.shapfire.DEFAULT_REPEATS.
- random_seed: int = utils.DEFAULT_RANDOM_SEED
The random seed to use for reproducibility purposes. Defaults to shapfire.utils.DEFAULT_RANDOM_SEED.
- iterations: None | int = None
The number of feature subsets to sample and subsequently use for model training such that SHAP feature importance values can be extracted. Defaults to None, which in turn sets the number of iterations to the size of the largest cluster of highly associated features found.
- reference: str = 'mean'
The data fusion method to use for producing a reference vector. Defaults to “mean”.
- n_samples: int = 1000
The number of random samples of rank permutations to use in a batch. Several batches of random samples are used to estimate a “ranking distribution”, which in turn is used to determine a feature importance cut-off threshold. Defaults to 1000.
- n_batches: int = 250
The number of batches of random samples used to estimate a “ranking distribution”, which in turn is used to determine a feature importance cut-off threshold. Defaults to 250.
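The rule stated for `estimator_params` (one value per hyperparameter means plain cross-validation, several values mean a hyperparameter search) can be illustrated by expanding a parameter dictionary into candidate configurations. This sketch assumes that candidate values are supplied as lists and that the search is an exhaustive grid; both are assumptions for illustration, and ShapFire's accepted format and search strategy may differ. `expand_param_grid` and `needs_search` are hypothetical helpers.

```python
from __future__ import annotations

from itertools import product
from typing import Any


def expand_param_grid(params: dict[str, Any] | None) -> list[dict[str, Any]]:
    """Expand {name: value or [values]} into every combination (hypothetical helper)."""
    if not params:
        return [{}]  # no hyperparameters given: a single run with estimator defaults
    names = list(params)
    # Assumption of this sketch: a non-list value counts as a single fixed choice.
    choices = [v if isinstance(v, list) else [v] for v in params.values()]
    return [dict(zip(names, combo)) for combo in product(*choices)]


def needs_search(params: dict[str, Any] | None) -> bool:
    """A search is only needed when more than one configuration exists."""
    return len(expand_param_grid(params)) > 1


fixed = {"n_estimators": [200], "max_depth": [6]}
grid = {"n_estimators": [100, 200], "max_depth": [4, 6, 8]}
print(needs_search(fixed), needs_search(grid))  # False True
print(len(expand_param_grid(grid)))  # 6
```

With `fixed`, only one configuration exists, so only cross-validation would be needed; `grid` yields 2 × 3 = 6 candidates to compare.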
- ranked_differences
A ShapFire attribute (a pandas DataFrame) that specifies the final importance values associated with each feature in the given input dataset.
- selected_features
A ShapFire attribute (a list) that specifies the final feature subset selected by ShapFire, which is expected to achieve the best possible model performance.
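How final importance values, the `reference` fusion method, and the `n_samples`/`n_batches` cut-off could fit together can be illustrated with a toy permutation-style procedure: fuse several feature rankings into a mean-rank reference vector, estimate the distribution of a feature's mean rank when rankings are purely random permutations, and keep the features whose fused rank beats a low quantile of that distribution. This is a hypothetical interpretation written for illustration, not ShapFire's actual algorithm; the function names, the 5 % quantile and the "lower mean rank is better" convention are all assumptions.

```python
from __future__ import annotations

import random


def fuse_mean(rankings: list[dict[str, int]]) -> dict[str, float]:
    """Fuse rankings (feature -> rank, 1 = best) into a mean-rank reference vector."""
    return {f: sum(r[f] for r in rankings) / len(rankings) for f in rankings[0]}


def random_rank_cutoff(
    n_features: int,
    n_rankings: int,
    n_samples: int = 1000,
    n_batches: int = 250,
    quantile: float = 0.05,
    seed: int = 0,
) -> float:
    """Hypothetical cut-off: low quantile of a feature's mean rank under random rankings."""
    rng = random.Random(seed)
    ranks = list(range(1, n_features + 1))
    batch_cutoffs = []
    for _ in range(n_batches):
        samples = []
        for _ in range(n_samples):
            # Draw n_rankings random rank permutations and track one feature's rank.
            total = 0
            for _ in range(n_rankings):
                rng.shuffle(ranks)
                total += ranks[0]
            samples.append(total / n_rankings)
        samples.sort()
        batch_cutoffs.append(samples[min(int(quantile * n_samples), n_samples - 1)])
    # Averaging the per-batch quantiles smooths out sampling noise.
    return sum(batch_cutoffs) / n_batches


rankings = [{"a": 1, "b": 2, "c": 3, "d": 4}, {"a": 2, "b": 1, "c": 4, "d": 3}] * 5
fused = fuse_mean(rankings)
cutoff = random_rank_cutoff(n_features=4, n_rankings=len(rankings), n_samples=200, n_batches=20)
selected = [f for f, r in sorted(fused.items(), key=lambda kv: kv[1]) if r <= cutoff]
print(fused)  # {'a': 1.5, 'b': 1.5, 'c': 3.5, 'd': 3.5}
print(selected)  # ['a', 'b']
```

The defaults of `random_rank_cutoff` mirror the documented `n_samples=1000` and `n_batches=250`; the usage above passes smaller values only to keep this pure-Python sketch fast.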
- fit(X: ndarray | DataFrame, y: ndarray | DataFrame) → ShapFire
Perform SHAP feature importance rank ensembling for the purpose of ranking and selecting the features that can be said to be the most important for a certain prediction task at hand.
- transform(X: ndarray | DataFrame) → ndarray | DataFrame
Reduce the input dataset X, containing features (columns) and corresponding observations (rows), to only the columns of the features selected by ShapFire.
- Parameters:
- X: ndarray | DataFrame
The input dataset to reduce.
- Raises:
ValueError – If the fit() method has not yet been called.
- Returns:
A reduced dataset that only contains the most important features (columns).
- fit_transform(X: ndarray | DataFrame, y: ndarray | DataFrame) → ndarray | DataFrame
Perform SHAP feature importance rank ensembling for the purpose of selecting the features that are the most important. Subsequently, reduce the input dataset ‘X’ to only the columns of the selected features.
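The contract described for `fit()`, `transform()` and `fit_transform()` (the selection is learned in `fit()`, `transform()` raises `ValueError` before fitting and otherwise keeps only the selected columns, and `fit_transform()` chains the two) follows the familiar scikit-learn transformer pattern. The standard-library sketch below mirrors that contract on a list-of-rows dataset with column names. `ToySelector` is hypothetical, and its selection rule (drop constant columns) is a deliberately simple stand-in for ShapFire's SHAP-based ranking, not a description of it.

```python
from __future__ import annotations


class ToySelector:
    """Hypothetical selector mirroring a fit/transform/fit_transform contract."""

    def __init__(self, columns: list[str]) -> None:
        self.columns = columns
        self.selected_features: list[str] | None = None  # set by fit()

    def fit(self, X: list[list[float]], y: list[float]) -> ToySelector:
        # Stand-in rule, not SHAP-based: keep columns whose values are not all identical.
        # y is unused here; a real feature selector would use it.
        self.selected_features = [
            name for j, name in enumerate(self.columns) if len({row[j] for row in X}) > 1
        ]
        return self

    def transform(self, X: list[list[float]]) -> list[list[float]]:
        if self.selected_features is None:
            raise ValueError("fit() must be called before transform().")
        keep = [self.columns.index(name) for name in self.selected_features]
        return [[row[j] for j in keep] for row in X]

    def fit_transform(self, X: list[list[float]], y: list[float]) -> list[list[float]]:
        return self.fit(X, y).transform(X)


X = [[1.0, 5.0, 0.0], [2.0, 5.0, 1.0], [3.0, 5.0, 0.0]]
y = [0.0, 1.0, 0.0]
selector = ToySelector(["x1", "x2", "x3"])
print(selector.fit_transform(X, y))  # [[1.0, 0.0], [2.0, 1.0], [3.0, 0.0]]
print(selector.selected_features)  # ['x1', 'x3']
```

Returning `self` from `fit()` is what allows `fit_transform()` to be written as a simple chain.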
- plot_ranking(groupby: str = 'cluster', rcParams: None | dict[str, str] = None, figsize: None | tuple[float, float] = None, fontsize: int = 10, with_text: bool = True, with_overlay: bool = True, ax: None | Axes = None) → tuple[matplotlib.figure.Figure, matplotlib.axes._axes.Axes]
Plot the feature importance scores associated with each feature. The features are ordered in the figure from best to worst and, depending on groupby, according to the cluster each of them belongs to.
- Parameters:
- groupby: str = 'cluster'
A string value indicating how the feature importance ranking should be displayed in the figure. If ‘cluster’ is chosen, the features are grouped in the figure by their assigned cluster, and the clusters are ordered by the importance rank of the best feature in each cluster. If ‘feature’ is chosen, the features are shown purely according to their global rank, without regard to which cluster each feature belongs to.
- figsize: None | tuple[float, float] = None
The width and height of the figure in inches. Defaults to None.
- fontsize: int = 10
The size of the font used in the figure. Defaults to 10.
- with_text: bool = True
If the input argument groupby is set to ‘cluster’, then with_text determines whether the features that have been grouped in the figure by the cluster they belong to should also be annotated with a text label. Defaults to True.
- with_overlay: bool = True
Depending on whether groupby is set to ‘cluster’ or ‘feature’, groups of features or individual features are assigned a gray-scale overlay, creating a visual grouping or delimitation of the features. Defaults to True.
- ax: None | Axes = None
A Matplotlib Axes object. Defaults to None.
- Returns:
A Matplotlib Figure and Axes object.
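The two `groupby` orderings described for `plot_ranking()` can be expressed without any plotting: with ‘feature’ the features follow their global rank, while with ‘cluster’ the clusters are ordered by the rank of their best feature and the features inside each cluster follow their own ranks. The helper `plot_order` and its input format (dictionaries mapping feature names to a rank, where 1 is best, and to a cluster label) are assumptions made for this sketch; it is not ShapFire's plotting code.

```python
from __future__ import annotations


def plot_order(
    ranks: dict[str, int], clusters: dict[str, int], groupby: str = "cluster"
) -> list[str]:
    """Hypothetical helper: features in display order (rank 1 = most important)."""
    by_rank = sorted(ranks, key=lambda f: ranks[f])
    if groupby == "feature":
        return by_rank  # purely global rank, clusters ignored
    if groupby != "cluster":
        raise ValueError("groupby must be 'cluster' or 'feature'")
    # The first feature met in rank order is the best one of its cluster.
    best: dict[int, int] = {}
    for f in by_rank:
        best.setdefault(clusters[f], ranks[f])
    # Order clusters by their best feature, keep each cluster contiguous,
    # then order features within a cluster by their own rank.
    return sorted(by_rank, key=lambda f: (best[clusters[f]], clusters[f], ranks[f]))


ranks = {"a": 1, "b": 4, "c": 2, "d": 3}
clusters = {"a": 0, "b": 0, "c": 1, "d": 1}
print(plot_order(ranks, clusters, "cluster"))  # ['a', 'b', 'c', 'd']
print(plot_order(ranks, clusters, "feature"))  # ['a', 'c', 'd', 'b']
```

Feature "b" moves up under ‘cluster’ because it shares a cluster with the top-ranked feature "a", even though its own global rank is the worst.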