partial_dependence {effectplots}    R Documentation

Partial Dependence

Description

Calculates partial dependence (PD) for one or multiple X variables.

PD was introduced by Friedman (2001) to study the (main) effects of an ML model. The PD of a model f and variable X at a value g is obtained by replacing the X values in a reference dataset by g and then averaging the predictions of f over this modified data. Repeating this for different values of g shows how the average prediction of f changes with X while all other feature values are kept fixed (ceteris paribus).
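
The definition can be written out directly in base R. The sketch below is an illustration under our own assumptions, not the package's .pd(): pd_sketch() is a hypothetical helper, the model is an lm() on iris, and optional case weights enter through a weighted mean.

```r
# pd_sketch() is a hypothetical helper, not part of effectplots.
# For each grid value g: overwrite column v with g in every row of the
# reference data, predict, and average the predictions (optionally weighted).
pd_sketch <- function(fit, v, data, grid, w = NULL, pred_fun = stats::predict) {
  vapply(grid, function(g) {
    data_g <- data
    data_g[[v]] <- g                  # all other columns stay unchanged
    pred <- pred_fun(fit, data_g)
    if (is.null(w)) mean(pred) else stats::weighted.mean(pred, w)
  }, numeric(1))
}

fit <- lm(Sepal.Length ~ ., data = iris)
grid <- c(2.5, 3, 3.5, 4)
pd <- pd_sketch(fit, v = "Sepal.Width", data = iris, grid = grid)
pd
```

Because this linear model has no interactions, consecutive values of pd differ by exactly the Sepal.Width coefficient times the grid step of 0.5, which makes a convenient sanity check.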

This function is a convenience wrapper around feature_effects(), which calls the barebone implementation .pd() to calculate PD. As grid points, it uses the arithmetic mean of X per bin (as specified by breaks), optionally weighted by w.
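
The grid construction can be illustrated as follows. This is a sketch under stated assumptions: bin_means() is a hypothetical helper, bins come from graphics::hist() breaks, and empty bins are simply dropped; the package may handle edge cases differently.

```r
# bin_means() is a hypothetical helper, not part of effectplots.
# It returns the (optionally weighted) mean of x within each non-empty bin.
bin_means <- function(x, breaks = "Sturges", w = NULL, right = TRUE) {
  br <- graphics::hist(x, breaks = breaks, right = right, plot = FALSE)$breaks
  # left.open = right gives right-closed bins (a, b]. With rightmost.closed,
  # the first (right = TRUE) or last (right = FALSE) bin is fully closed,
  # matching hist(include.lowest = TRUE).
  bin <- findInterval(x, br, rightmost.closed = TRUE, left.open = right)
  if (is.null(w)) w <- rep(1, length(x))
  sums <- tapply(x * w, bin, sum)   # groups are sorted bin indices
  wsum <- tapply(w, bin, sum)
  as.numeric(sums / wsum)
}

bin_means(iris$Sepal.Width)
bin_means(c(1, 2, 3, 4), breaks = c(0, 2, 4), w = c(3, 1, 1, 1))
```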

Usage

partial_dependence(object, ...)

## Default S3 method:
partial_dependence(
  object,
  v,
  data,
  pred_fun = stats::predict,
  trafo = NULL,
  which_pred = NULL,
  w = NULL,
  breaks = "Sturges",
  right = TRUE,
  discrete_m = 5L,
  outlier_iqr = 2,
  pd_n = 500L,
  seed = NULL,
  ...
)

## S3 method for class 'ranger'
partial_dependence(
  object,
  v,
  data,
  pred_fun = NULL,
  trafo = NULL,
  which_pred = NULL,
  w = NULL,
  breaks = "Sturges",
  right = TRUE,
  discrete_m = 5L,
  outlier_iqr = 2,
  pd_n = 500L,
  seed = NULL,
  ...
)

## S3 method for class 'explainer'
partial_dependence(
  object,
  v = colnames(data),
  data = object$data,
  pred_fun = object$predict_function,
  trafo = NULL,
  which_pred = NULL,
  w = object$weights,
  breaks = "Sturges",
  right = TRUE,
  discrete_m = 5L,
  outlier_iqr = 2,
  pd_n = 500L,
  seed = NULL,
  ...
)

Arguments

object

Fitted model.

...

Further arguments passed to pred_fun(), e.g., type = "response" in a glm() or (typically) prob = TRUE in classification models.

v

Vector of variable names for which to calculate statistics.

data

Matrix or data.frame.

pred_fun

Prediction function, by default stats::predict. The function takes three arguments (names irrelevant): object, data, and ....
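
Any function with this (model, data, ...) shape works as pred_fun. A minimal sketch, assuming a logistic glm() fitted on iris; pred_prob() is a hypothetical name. Passing type = "response" through ... to the default stats::predict would have the same effect.

```r
# pred_prob() is a hypothetical prediction function with the documented
# three-argument shape: model, data, and further arguments.
pred_prob <- function(m, X, ...) {
  stats::predict(m, newdata = X, type = "response", ...)
}

fit_glm <- glm(I(Species == "virginica") ~ Sepal.Length + Sepal.Width,
               data = iris, family = binomial())
head(pred_prob(fit_glm, iris))   # probabilities between 0 and 1
```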

trafo

How should predictions be transformed? A function or NULL (default). Examples are log (e.g., to switch to the link scale of a model with log link) or exp (to switch back from that link scale to the original scale).
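
For a model with log link, log and exp translate between the two scales. A small check, assuming a Poisson glm() on iris with a rounded response chosen purely for illustration:

```r
# Poisson GLM with log link; the rounded response is for illustration only.
fit_pois <- glm(round(Sepal.Length) ~ Sepal.Width + Species,
                data = iris, family = poisson())
link <- predict(fit_pois, iris)                     # link (log) scale
resp <- predict(fit_pois, iris, type = "response")  # original scale
all.equal(exp(link), resp)  # exp maps link-scale predictions to the original scale
all.equal(log(resp), link)  # log maps them back
```

Since the mean of exp() values is not exp() of their mean, a PD curve computed on one scale is generally not a simple transformation of the curve computed on the other.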

which_pred

If the predictions are multivariate: which column to pick (integer or column name). The default, NULL, picks the last column.
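
The documented selection rule can be mimicked in a few lines. pick_pred() is a hypothetical helper, and probs is a made-up two-row probability matrix used only for illustration.

```r
# pick_pred() is a hypothetical helper mirroring the which_pred rule:
# a column index or name, or the last column if which_pred is NULL.
pick_pred <- function(pred, which_pred = NULL) {
  if (is.null(dim(pred))) return(pred)            # univariate predictions
  if (is.null(which_pred)) which_pred <- ncol(pred)
  pred[, which_pred]
}

probs <- matrix(c(0.7, 0.1, 0.2, 0.6, 0.1, 0.3), nrow = 2,
                dimnames = list(NULL, c("a", "b", "c")))
pick_pred(probs)        # last column "c"
pick_pred(probs, "b")   # by name
pick_pred(probs, 1L)    # by position
```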

w

Optional vector with case weights. Can also be a column name in data.
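
Resolving w when it is given as a column name might look like this; resolve_w() is a hypothetical helper and the toy data are made up.

```r
# resolve_w() is a hypothetical helper: a single string naming a column of
# data is replaced by that column; anything else is returned unchanged.
resolve_w <- function(w, data) {
  if (is.character(w) && length(w) == 1L && w %in% colnames(data)) {
    data[, w]  # works for matrices and data.frames
  } else {
    w
  }
}

df <- data.frame(x = 1:3, wt = c(0.5, 1, 2))
resolve_w("wt", df)
```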

breaks

An integer, vector, string or function specifying the bins of the numeric X variables as in graphics::hist(). The default is "Sturges". To allow varying values of breaks across variables, it can be a list of the same length as v, or a named list with breaks for certain variables.
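
Two aspects of breaks can be illustrated in base R: the specifications accepted by graphics::hist(), and how a list might be expanded to one entry per variable. expand_breaks() is a hypothetical helper; in particular, that variables missing from a named list fall back to "Sturges" is our assumption.

```r
x <- iris$Sepal.Length
# Specifications understood by graphics::hist()
graphics::hist(x, breaks = "Sturges", plot = FALSE)$breaks          # rule name
graphics::hist(x, breaks = 4, plot = FALSE)$breaks                  # suggested bin count
graphics::hist(x, breaks = seq(4, 8, by = 1), plot = FALSE)$breaks  # explicit breaks

# expand_breaks() is hypothetical: one breaks entry per variable in v.
expand_breaks <- function(breaks, v, default = "Sturges") {
  if (!is.list(breaks)) {
    return(stats::setNames(rep(list(breaks), length(v)), v))  # same for all
  }
  if (is.null(names(breaks))) {
    stopifnot(length(breaks) == length(v))                    # positional list
    return(stats::setNames(breaks, v))
  }
  out <- stats::setNames(rep(list(default), length(v)), v)   # named list
  keep <- intersect(names(breaks), v)
  out[keep] <- breaks[keep]
  out
}

str(expand_breaks(list(Sepal.Width = 5), v = c("Sepal.Length", "Sepal.Width")))
```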

right

Should bins be right-closed? The default is TRUE. Vectorized over v. Only relevant for numeric X.
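
The effect of right on a value that sits exactly on a break can be seen with base R's cut(), which uses the same right/include.lowest convention as graphics::hist(); the toy vector is made up.

```r
x <- c(1, 2, 2, 3, 4)
br <- c(0, 2, 4)
# right = TRUE: bins [0,2] and (2,4]; the two 2s fall into the first bin
table(cut(x, br, right = TRUE, include.lowest = TRUE))
# right = FALSE: bins [0,2) and [2,4]; the two 2s fall into the second bin
table(cut(x, br, right = FALSE, include.lowest = TRUE))
```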

discrete_m

Numeric X variables with up to this number of unique values are not binned but treated as a factor (after calculating partial dependence). The default is 5. Vectorized over v.
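
The rule can be expressed as a one-line check. is_discrete() is a hypothetical helper; it only illustrates the documented threshold.

```r
# is_discrete() is hypothetical: non-numeric variables, and numeric ones with
# at most discrete_m distinct values, are treated as discrete (not binned).
is_discrete <- function(x, discrete_m = 5L) {
  !is.numeric(x) || length(unique(x)) <= discrete_m
}

is_discrete(mtcars$cyl)     # 3 distinct values: not binned
is_discrete(mtcars$mpg)     # many distinct values: binned
is_discrete(iris$Species)   # factor
```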

outlier_iqr

Outliers of a numeric X are capped via the boxplot rule, i.e., values further than outlier_iqr * IQR from the quartiles are capped. The default of 2 is more conservative than the usual rule of 1.5 to account for right-skewed distributions. Set to 0 or Inf for no capping. Note that at most 10k observations are sampled to calculate the quartiles. Vectorized over v.
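
The capping rule can be sketched as follows. cap_outliers() is a hypothetical helper that follows the description above, including the sample of at most 10k observations for the quartiles; it is not the package's code.

```r
# cap_outliers() is hypothetical: values further than outlier_iqr * IQR below
# the first or above the third quartile are set to that boundary.
cap_outliers <- function(x, outlier_iqr = 2, n_max = 10000L) {
  if (outlier_iqr <= 0 || is.infinite(outlier_iqr)) return(x)  # no capping
  xs <- if (length(x) > n_max) sample(x, n_max) else x          # quartile sample
  q <- stats::quantile(xs, probs = c(0.25, 0.75), names = FALSE, na.rm = TRUE)
  r <- outlier_iqr * (q[2] - q[1])
  pmin(pmax(x, q[1] - r), q[2] + r)
}

x <- c(1:10, 100)
cap_outliers(x)        # quartiles 3.5 and 8.5, so 100 is capped at 8.5 + 2 * 5
cap_outliers(x, 1.5)   # usual boxplot rule: capped at 8.5 + 1.5 * 5
```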

pd_n

Size of the data used for calculating partial dependence. The default is 500. For larger data, pd_n rows (together with the corresponding entries of w) are randomly sampled. Each variable specified by v uses the same subsample. Set to 0 to omit.
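
Drawing one shared background sample, together with the matching weights, might look like this. pd_background() is a hypothetical helper; its seed argument mirrors the documented seed behavior, and set.seed() changes the global random number stream.

```r
# pd_background() is hypothetical: if data has more than pd_n rows, draw one
# random subsample (and the matching weights) to be reused for every v.
pd_background <- function(data, w = NULL, pd_n = 500L, seed = NULL) {
  n <- nrow(data)
  if (n <= pd_n) return(list(data = data, w = w))
  if (!is.null(seed)) set.seed(seed)
  idx <- sample(n, pd_n)
  list(data = data[idx, , drop = FALSE], w = if (!is.null(w)) w[idx])
}

big <- data.frame(x = seq_len(2000), y = sin(seq_len(2000)))
bg <- pd_background(big, w = big$x, pd_n = 500L, seed = 1L)
nrow(bg$data)
```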

seed

Optional random seed (an integer) used for:

  • Partial dependence: select background data if n > pd_n.

  • Capping X: select the 10k observations used to calculate quartiles if n > 10k.

Value

A list (of class "EffectData") with a data.frame of statistics per feature. Use single bracket subsetting to select part of the output.

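Methods (by class)

  • partial_dependence(default): Default method.

  • partial_dependence(ranger): Method for "ranger" models.

  • partial_dependence(explainer): Method for DALEX "explainer".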

References

Friedman, Jerome H. 2001. "Greedy Function Approximation: A Gradient Boosting Machine." Annals of Statistics 29 (5): 1189-1232. doi:10.1214/aos/1013203451.

See Also

feature_effects(), .pd(), ale().

Examples

fit <- lm(Sepal.Length ~ ., data = iris)
M <- partial_dependence(fit, v = "Species", data = iris)
M |> plot()

M2 <- partial_dependence(fit, v = colnames(iris)[-1], data = iris)
plot(M2, share_y = "all")

[Package effectplots version 0.1.0]