InterpretML supports training interpretable models (glassbox), as well as explaining existing ML pipelines (blackbox).
Let’s walk through an example of each using the UCI adult income classification dataset.
Download and Prepare Data
First, we will load the data into a standard pandas dataframe or a numpy array, and create a train / test split. There’s no special preprocessing necessary to use your data with InterpretML.
Explain the Glassbox
Glassbox models can provide explanations at both a global (overall behavior) and a local (individual predictions) level.
Global explanations are useful for understanding what a model finds important, as well as for identifying potential flaws in its decision making (e.g., racial bias).
The inline visualizations embedded here are exactly what is produced in the notebook.
For this global explanation, the initial summary page shows the most important features overall. You can use the dropdown to search, filter, and select individual features to drill into.
Try looking at the “Age” feature to see how the probability of high income varies with Age, or the “Race” or “Gender” features to observe potential bias the model may have learned.
from interpret import show
show(ebm.explain_global())
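To make "most important features overall" a little more concrete, the sketch below ranks features by the mean absolute value of their per-row contributions, which is one reasonable way to summarize an additive model. The contribution numbers and the helper name `mean_abs_importance` are invented for illustration; this is an assumption about a plausible ranking rule, not a description of what `ebm.explain_global()` computes internally.

```python
# Hypothetical per-row log-odds contributions for three features.
# These values are invented for illustration, not taken from a trained EBM.
contributions = [
    {"Age": 0.40, "CapitalGain": 3.00, "Race": -0.05},
    {"Age": -0.60, "CapitalGain": -0.20, "Race": 0.05},
    {"Age": 0.20, "CapitalGain": -0.20, "Race": -0.05},
    {"Age": -0.40, "CapitalGain": -0.20, "Race": 0.05},
]


def mean_abs_importance(rows):
    """Average the absolute contribution of each feature across all rows."""
    totals = {}
    for row in rows:
        for name, value in row.items():
            totals[name] = totals.get(name, 0.0) + abs(value)
    return {name: total / len(rows) for name, total in totals.items()}


importance = mean_abs_importance(contributions)
# Sort from most to least important, as a summary chart would.
ranking = sorted(importance, key=importance.get, reverse=True)
for name in ranking:
    print(f"{name}: {importance[name]:.3f}")
```

Taking absolute values means a feature that pushes predictions strongly in either direction ranks high, even if its positive and negative contributions would cancel out in a plain average.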
Local explanations show how a single prediction is made. For glassbox models, these explanations are exact – they perfectly describe how the model made its decision.
These explanations are useful for describing to end users which factors were most influential for a prediction.
In the local explanation below for instance “2”, the probability of high income was 0.93, largely due to having a high value for the CapitalGains feature.
The values shown here are log-odds scores from the EBM, which are added and passed through a logistic-link function to get the final prediction, just like logistic regression.
show(ebm.explain_local(X_test[:5], y_test[:5]), 0)
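The "add the scores, then apply the logistic link" step can be sketched in a few lines of plain Python. The intercept and per-feature scores below are made-up numbers, chosen only so that the result lands near the 0.93 quoted above; they are not read from the trained EBM, and `predict_probability` is a hypothetical helper name.

```python
import math

# Hypothetical log-odds contributions for one person; invented numbers,
# chosen only so the result lands near the 0.93 quoted in the text.
intercept = -1.0
term_scores = {"CapitalGain": 3.0, "Age": 0.4, "Education": 0.2}


def predict_probability(intercept, term_scores):
    """Sum the additive scores, then apply the logistic (sigmoid) function."""
    log_odds = intercept + sum(term_scores.values())
    return 1.0 / (1.0 + math.exp(-log_odds))


probability = predict_probability(intercept, term_scores)
print(f"log-odds = {intercept + sum(term_scores.values()):.2f}")
print(f"P(high income) = {probability:.3f}")
```

Because the scores are simply added before the link function is applied, each bar in a local explanation can be read as that feature's push toward or away from the positive class.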
Build a Blackbox Pipeline
Blackbox interpretability methods can extract explanations from any machine learning pipeline. This includes model ensembles, pre-processing steps, and complex models such as deep neural nets.
Let’s start by training a random forest on data that is first pre-processed with principal component analysis (PCA).
from sklearn.ensemble import RandomForestClassifier
from sklearn.decomposition import PCA
from sklearn.pipeline import Pipeline

# We have to transform categorical variables to use sklearn models
X = pd.get_dummies(X, prefix_sep='.').astype(float)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.20, random_state=seed)

# Blackbox system can include preprocessing, not just a classifier!
pca = PCA()
rf = RandomForestClassifier(random_state=seed)
blackbox_model = Pipeline([('pca', pca), ('rf', rf)])
blackbox_model.fit(X_train, y_train)
Explain the Blackbox
All you need for a blackbox interpretability method is a predict function from the target ML pipeline.
Blackbox interpretability methods generally work by repeatedly perturbing the input data, passing it through the pipeline, and observing how the final prediction changes.
As a result, both global and local explanations are approximate and may sometimes be inaccurate. Be cautious when relying on these results in high-stakes environments.
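The perturb-and-observe idea can be illustrated with a small, library-free sketch that needs nothing but a predict function. It shuffles one feature column at a time and records how much the predictions move on average. The toy `predict_fn`, the synthetic data, and the helper `perturbation_sensitivity` are all hypothetical stand-ins; this is a permutation-style illustration of the general approach, not the algorithm used by any particular blackbox explainer.

```python
import random


def predict_fn(row):
    """Stand-in for a pipeline's predict function (hypothetical toy model).

    Only the first two features influence the output; the third is ignored.
    """
    return 0.7 * row[0] + 0.3 * row[1] + 0.0 * row[2]


def perturbation_sensitivity(predict, rows, n_features, seed=0):
    """Shuffle one feature column at a time and measure the prediction shift."""
    rng = random.Random(seed)
    baseline = [predict(row) for row in rows]
    scores = []
    for j in range(n_features):
        # Perturb feature j by shuffling its values across rows.
        column = [row[j] for row in rows]
        rng.shuffle(column)
        total_change = 0.0
        for row, new_value, base in zip(rows, column, baseline):
            perturbed = list(row)
            perturbed[j] = new_value
            total_change += abs(predict(perturbed) - base)
        scores.append(total_change / len(rows))
    return scores


# Synthetic inputs: 200 rows of three uniform random features.
data_rng = random.Random(42)
rows = [[data_rng.random() for _ in range(3)] for _ in range(200)]
scores = perturbation_sensitivity(predict_fn, rows, n_features=3)
print([round(s, 3) for s in scores])
```

The scores depend on the random shuffle and on the sample of rows used, which is one reason explanations produced this way are estimates rather than exact descriptions of the model.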