
Visualization of the ensemble of trees as a single collective unit.

Usage

xgb.plot.multi.trees(
  model,
  features_keep = 5,
  plot_width = NULL,
  plot_height = NULL,
  render = TRUE,
  ...
)

Arguments

model

Object of class xgb.Booster. If it contains feature names (they can be set through setinfo()), they will be used in the output of this function.

features_keep

Number of features to keep at each position of the combined multi-tree. Defaults to 5.

plot_width, plot_height

Width and height of the graph in pixels. The values are passed to DiagrammeR::render_graph().

render

Should the graph be rendered or not? The default is TRUE.

...

Not used.

Some arguments that were part of this function in previous XGBoost versions are now deprecated or have been renamed. If a deprecated or renamed argument is passed, the function will throw a warning (by default) and use its current equivalent instead. This warning becomes an error when the 'strict mode' option is enabled.

If any other argument is passed that is neither a current function argument nor a deprecated or renamed one, a warning or an error is thrown, depending on the 'strict mode' option.

Important: ... will be removed in a future version, and all the current deprecation warnings will become errors. Please use only arguments that form part of the function signature.

Value

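Rendered graph object, an htmlwidget of class grViz. Like "ggplot" objects, it needs to be printed explicitly when not running from the command line. When render = FALSE, the unrendered DiagrammeR graph object is returned instead, which can then be passed to DiagrammeR functions such as export_graph() (see the examples).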

Details

Note that this function does not work with models that were fitted to categorical data.

This function tries to capture the complexity of a gradient boosted tree model in a cohesive way by compressing an ensemble of trees into a single tree-graph representation. The goal is to improve the interpretability of a model that is generally seen as a black box.

Note: this function is applicable to tree booster-based models only.

It takes advantage of the fact that the shape of a full binary tree is defined only by its depth (therefore, in a boosting model, all trees have a similar shape).

Moreover, the trees tend to reuse the same features.

The function projects all trees onto a single tree and, for each position, keeps the top features_keep features ranked by the Gain per feature measure.
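The projection-and-ranking step can be illustrated with a short base R sketch on a toy table of split nodes. This is a simplified illustration under stated assumptions, not the package's implementation: the data frame `splits`, its columns, the "root-L"/"root-R" path encoding of positions, the gain values, and the helper `keep_top_features()` are all hypothetical. In practice, per-node feature and gain information would come from the fitted model.

```r
# Hypothetical toy input: one row per split node, identified by its tree and
# its path from the root ("root", "root-L" = left child, "root-R" = right child).
splits <- data.frame(
  tree     = c(0, 0, 0, 1, 1, 2, 2),
  position = c("root", "root-L", "root-R", "root", "root-L", "root", "root-R"),
  feature  = c("odor", "gill-size", "spore", "odor", "gill-size", "ring", "spore"),
  gain     = c(10, 4, 3, 8, 5, 2, 1),
  stringsAsFactors = FALSE
)

# Step 1: project every tree onto one shape by summing the gain of each
# (position, feature) pair across all trees.
agg <- aggregate(gain ~ position + feature, data = splits, FUN = sum)

# Step 2 (hypothetical helper): within each position, keep the
# `features_keep` features with the largest total gain.
keep_top_features <- function(agg, features_keep) {
  per_position <- split(agg, agg$position)
  kept <- lapply(per_position, function(d) {
    d <- d[order(-d$gain), ]
    head(d, features_keep)
  })
  out <- do.call(rbind, kept)
  rownames(out) <- NULL
  out
}

top <- keep_top_features(agg, features_keep = 1)
print(top)
```

Here "odor" wins the root position because its gains from two trees add up (10 + 8), while "ring" appears at the root in only one tree with a small gain. Summing across trees is what lets a feature that is used consistently outrank one that appears only occasionally.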

This function is inspired by this blog post: https://wellecks.wordpress.com/2015/02/21/peering-into-the-black-box-visualizing-lambdamart/

Examples


data(agaricus.train, package = "xgboost")

## Keep the number of threads to 2 for examples
nthread <- 2
data.table::setDTthreads(nthread)

model <- xgboost(
  agaricus.train$data, factor(agaricus.train$label),
  nrounds = 30,
  verbosity = 0L,
  nthreads = nthread,
  max_depth = 15,
  learning_rate = 1,
  min_child_weight = 50
)

p <- xgb.plot.multi.trees(model, features_keep = 3)
print(p)

# Below is an example of how to save this plot to a file.
if (require("DiagrammeR") && require("DiagrammeRsvg") && require("rsvg")) {
  fname <- file.path(tempdir(), "tree.pdf")
  gr <- xgb.plot.multi.trees(model, features_keep = 3, render = FALSE)
  export_graph(gr, fname, width = 1500, height = 600)
}