Throughout this interpretability series, I’ll mainly focus on model-agnostic methods, which can be applied to any type of model, from linear regression to multilayer perceptrons. The method in this post, the partial dependence plot, is one of them.
The idea is very simple. In fact, I remember using it before I knew of any interpretability methods, although I wasn’t aware of its limitations at the time. For this post, I use a model trained on Euroleague play-by-play data. It takes cumulative three-point attempts (cum_P3A), cumulative three-point makes (cum_P3M), shot angle and shot zone area, and predicts whether the upcoming three-point attempt will go in.
Let’s say we want to know the effect of cum_P3A on the prediction. We create a grid of values for cum_P3A. For each observation in our dataset, we set cum_P3A to the first value of the grid and keep all the other variables the same. We then make predictions with our model for each row and average them.
For example, for the value 5, we go through each row of the dataframe, set cum_P3A to 5, make predictions and average them.
When we repeat the process for each value in the grid, we get the partial dependence function for cum_P3A:
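Below is a minimal NumPy sketch of this procedure. It makes several assumptions that are not part of the post:

* The data is a hypothetical stand-in made of random cum_P3A, cum_P3M and shot angle columns, not the Euroleague data.
* A made-up logistic `predict_proba` replaces the trained model.
* `partial_dependence_1d` is an illustrative helper, not a library function.

```python
import numpy as np

# Hypothetical stand-in data (not the Euroleague data): cum_P3A, cum_P3M, shot_angle.
rng = np.random.default_rng(0)
n = 200
cum_p3a = rng.integers(0, 60, size=n)
cum_p3m = rng.binomial(cum_p3a, 0.35)          # makes never exceed attempts
shot_angle = rng.uniform(-90, 90, size=n)
X = np.column_stack([cum_p3a, cum_p3m, shot_angle]).astype(float)

def predict_proba(X):
    """Hypothetical model: a fixed logistic function standing in for any fitted model."""
    logit = -0.8 + 0.01 * X[:, 0] + 0.02 * X[:, 1] - 0.005 * np.abs(X[:, 2])
    return 1.0 / (1.0 + np.exp(-logit))

def partial_dependence_1d(predict, X, feature, grid):
    """Average prediction when `feature` is set to each grid value in every row."""
    pd_values = []
    for value in grid:
        X_mod = X.copy()            # all other variables stay as they are
        X_mod[:, feature] = value   # overwrite the feature of interest in every row
        pd_values.append(predict(X_mod).mean())
    return np.array(pd_values)

# Grid over the observed range of cum_P3A (column 0).
grid = np.linspace(X[:, 0].min(), X[:, 0].max(), 20)
pdp = partial_dependence_1d(predict_proba, X, feature=0, grid=grid)
```

Calling `partial_dependence_1d(predict_proba, X, 0, [5.0])` reproduces the single-value example: every row gets cum_P3A = 5 and the predictions are averaged.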
So, that’s the partial dependence plot (PDP) for cum_P3A. There are two problems that should worry you:
1. We miss out on heterogeneous effects. A relatively flat line might mean that the variable has no effect at all. It might also mean that heterogeneous effects cancel each other out. Individual conditional expectation (ICE) plots, which I show below, can help with that.
2. PDP doesn’t take the joint distribution into account. As a result, we may end up with counterfactual rows that are unlikely or, as in this case, impossible. I didn’t choose cum_P3A arbitrarily. For a row with cum_P3M = 10, any cum_P3A below 10 makes no sense at all, yet the PDP allows it. I’ll show the joint distribution, but before that, let’s start with ICEs.
In the plot, the black lines are ICE curves, one per row, and the orange line is the partial dependence curve. The two are directly related: the PD curve is the pointwise average of the ICE curves.
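ICE curves come from the same loop as the PDP, only without the final averaging. The sketch below makes these assumptions:

* The data and model are hypothetical stand-ins.
* The interaction term is chosen so that cum_P3A raises the prediction for some rows and lowers it for others. This is the kind of heterogeneity a flat PD curve can hide.

```python
import numpy as np

# Hypothetical stand-in data (not the Euroleague data): cum_P3A, cum_P3M.
rng = np.random.default_rng(2)
n = 50
cum_p3a = rng.integers(0, 60, n)
X = np.column_stack([cum_p3a, rng.binomial(cum_p3a, 0.35)]).astype(float)

def predict_proba(X):
    """Hypothetical model with an interaction: cum_P3A helps when cum_P3M <= 7, hurts when >= 8."""
    logit = -1.0 + 0.03 * X[:, 0] - 0.004 * X[:, 0] * X[:, 1]
    return 1.0 / (1.0 + np.exp(-logit))

def ice_curves(predict, X, feature, grid):
    """One curve per row: that row's prediction with `feature` set to each grid value."""
    curves = np.empty((X.shape[0], len(grid)))
    for j, value in enumerate(grid):
        X_mod = X.copy()
        X_mod[:, feature] = value
        curves[:, j] = predict(X_mod)
    return curves

grid = np.linspace(0, 59, 15)
ice = ice_curves(predict_proba, X, feature=0, grid=grid)  # shape (n_rows, n_grid)
pdp = ice.mean(axis=0)                                    # PD curve = pointwise average of ICEs
ice_centered = ice - ice[:, [0]]                          # centered ICE: every curve starts at 0
```

Rows with low cum_P3M get rising curves and rows with high cum_P3M get falling ones, while `pdp` averages that difference away.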
Luckily, we don’t have to write that code ourselves (I only did so for demonstration purposes), since sklearn has a built-in function for it.
from sklearn.inspection import PartialDependenceDisplay
Centering the ICE curves (centered=True) makes it easier to compare a variable’s effect across rows, because every curve starts from the same point. The rug marks at the bottom show the deciles of the variable’s distribution. They let you see where the data is dense and interpret the curves accordingly. However, they only reflect the marginal distribution, not the joint distribution:
The marginal distribution of cum_P3A covers a wide range, as the x-axis of the scatter plot shows. However, when we condition on, say, cum_P3M = 30, things change:
Some values of cum_P3A don’t make any sense at all, such as having fewer attempts than made shots. Others are highly unlikely, such as making 30 of 30 attempts. I’ll cover Accumulated Local Effects (ALE), which address this correlation issue, in a different post.
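The same point can be checked numerically. The sketch uses hypothetical data in which makes are drawn as binomial successes out of cum_P3A attempts, so makes never exceed attempts. This is an assumption, not the real shot log. The sketch does two things:

* It compares the marginal range of cum_P3A with its range given cum_P3M = 30.
* It counts how many counterfactual rows a PDP grid over the marginal range would create with fewer attempts than makes.

```python
import numpy as np

# Hypothetical stand-in data (not the Euroleague data): makes never exceed attempts.
rng = np.random.default_rng(4)
n = 5000
cum_p3a = rng.integers(0, 120, size=n)
cum_p3m = rng.binomial(cum_p3a, 0.35)

# Marginal vs conditional view of cum_P3A.
marginal_min, marginal_max = cum_p3a.min(), cum_p3a.max()
attempts_given_30_makes = cum_p3a[cum_p3m == 30]

# A PDP grid over the marginal range pairs every row with every grid value.
grid = np.arange(0, 120, 10)
# A counterfactual row is impossible when the forced attempts fall below the row's makes.
impossible = grid[:, None] < cum_p3m[None, :]   # shape (len(grid), n)
share_impossible_by_grid = impossible.mean(axis=1)
share_impossible_overall = impossible.mean()
```

`attempts_given_30_makes.min()` is never below 30, while the PDP grid still starts at 0. Low grid values therefore produce mostly impossible rows.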
Despite their limitations, I like PDPs because the idea is very intuitive: track the average prediction as you vary a particular variable. You can also apply them to the whole dataset or to specific portions of it. ICEs are nice as well, although the plot gets extremely crowded with relatively large data. For that reason, I used a small sample for demonstration.
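A PDP for one portion of the data simply averages over that subset’s rows. In the sketch below, the data, the model and the "wide-angle" subset (|shot angle| > 60) are all hypothetical. The last line shows the random row sample that keeps an ICE plot readable.

```python
import numpy as np

# Hypothetical stand-in data (not the Euroleague data): columns are cum_P3A, shot_angle.
rng = np.random.default_rng(5)
n = 1000
X = np.column_stack([rng.integers(0, 60, n), rng.uniform(-90, 90, n)]).astype(float)

def predict_proba(X):
    """Hypothetical model: wider shot angles lower the predicted make probability."""
    logit = -0.7 + 0.01 * X[:, 0] - 0.01 * np.abs(X[:, 1])
    return 1.0 / (1.0 + np.exp(-logit))

def partial_dependence_1d(predict, X, feature, grid):
    """Average prediction over the rows of X with `feature` fixed at each grid value."""
    values = []
    for value in grid:
        X_mod = X.copy()
        X_mod[:, feature] = value
        values.append(predict(X_mod).mean())
    return np.array(values)

grid = np.linspace(0, 59, 10)
pd_all = partial_dependence_1d(predict_proba, X, 0, grid)  # whole dataset

# PDP for one portion of the data: average over that subset's rows only.
wide_angle = np.abs(X[:, 1]) > 60
pd_wide = partial_dependence_1d(predict_proba, X[wide_angle], 0, grid)

# For ICE plots, draw a random subset of rows so the figure stays readable.
ice_rows = X[rng.choice(n, size=50, replace=False)]
```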
As long as you keep the limitations in mind, you should be fine :)