Visualization of the ensemble of trees as a single collective unit.
Usage
xgb.plot.multi.trees(
model,
features_keep = 5,
plot_width = NULL,
plot_height = NULL,
render = TRUE,
...
)
Arguments
- model
Object of class xgb.Booster. If it contains feature names (they can be set through setinfo()), they will be used in the output from this function.
- features_keep
Number of features to keep in each position of the multi trees, by default 5.
- plot_width, plot_height
Width and height of the graph in pixels. The values are passed to DiagrammeR::render_graph().
- render
Should the graph be rendered or not? The default is TRUE.
- ...
Not used.
Some arguments that were part of this function in previous XGBoost versions are currently deprecated or have been renamed. If a deprecated or renamed argument is passed, the function will throw a warning (by default) and use its current equivalent instead. This warning will become an error if using the 'strict mode' option.
If some additional argument is passed that is neither a current function argument nor a deprecated or renamed argument, a warning or error will be thrown depending on the 'strict mode' option.
Important: the ... argument will be removed in a future version, and all the current deprecation warnings will become errors. Please use only arguments that form part of the function signature.
Value
Rendered graph object, which is an htmlwidget of class grViz. Similar to "ggplot" objects, it needs to be printed when not running from the command line.
Details
Note that this function does not work with models that were fitted to categorical data.
This function tries to capture the complexity of a gradient boosted tree model in a cohesive way by compressing an ensemble of trees into a single tree-graph representation. The goal is to improve the interpretability of a model generally seen as black box.
Note: this function is applicable to tree booster-based models only.
It takes advantage of the fact that the shape of a binary tree is defined only by its depth (therefore, in a boosting model, all trees have a similar shape).
Moreover, the trees tend to reuse the same features.
The function projects each tree onto one, and keeps for each position the features_keep first features (based on the Gain per feature measure).
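The projection described above can be illustrated with a minimal base R sketch. It is not the package's actual implementation: the nodes table, its column names, the position labels ("0" for the root, "0-0" and "0-1" for its children) and the gain values are all hypothetical and chosen only for illustration. The sketch sums the gain of each feature at each tree position across all trees, then keeps the features_keep features with the highest total gain at every position.

```r
# Hypothetical node table: one row per split node of a toy 4-tree model.
# 'position' is the node's slot in the shared binary-tree layout.
nodes <- data.frame(
  tree     = rep(0:3, each = 3),
  position = rep(c("0", "0-0", "0-1"), times = 4),
  feature  = c("odor",  "cap",  "gill",
               "odor",  "gill", "gill",
               "spore", "cap",  "ring",
               "cap",   "cap",  "ring"),
  gain     = c(10, 4, 3,
                8, 2, 5,
                1, 6, 2,
                3, 1, 1),
  stringsAsFactors = FALSE
)

# Project all trees onto one: total gain per (position, feature) pair.
agg <- aggregate(gain ~ position + feature, data = nodes, FUN = sum)

# Rank features within each position by decreasing total gain.
agg <- agg[order(agg$position, -agg$gain), ]
rank_in_position <- ave(agg$gain, agg$position, FUN = seq_along)

# Keep only the top 'features_keep' features at each position.
features_keep <- 2
top <- agg[rank_in_position <= features_keep, ]
rownames(top) <- NULL
print(top)
```

In this toy table three different features are used at the root across the four trees, so with features_keep = 2 the one with the lowest total gain is dropped from the root position, while the two child positions keep both of their features.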
This function is inspired by this blog post: https://wellecks.wordpress.com/2015/02/21/peering-into-the-black-box-visualizing-lambdamart/
Examples
data(agaricus.train, package = "xgboost")
## Keep the number of threads to 2 for examples
nthread <- 2
data.table::setDTthreads(nthread)
model <- xgboost(
  agaricus.train$data, factor(agaricus.train$label),
  nrounds = 30,
  verbosity = 0L,
  nthreads = nthread,
  max_depth = 15,
  learning_rate = 1,
  min_child_weight = 50
)
p <- xgb.plot.multi.trees(model, features_keep = 3)
print(p)
# Below is an example of how to save this plot to a file.
if (require("DiagrammeR") && require("DiagrammeRsvg") && require("rsvg")) {
  fname <- file.path(tempdir(), "tree.pdf")
  # render = FALSE returns the graph without rendering it, so it can be exported
  gr <- xgb.plot.multi.trees(model, features_keep = 3, render = FALSE)
  export_graph(gr, fname, width = 1500, height = 600)
}