
Predict values on new data based on an XGBoost model.

Usage

# S3 method for class 'xgboost'
predict(
  object,
  newdata,
  type = "response",
  base_margin = NULL,
  iteration_range = NULL,
  validate_features = TRUE,
  ...
)

Arguments

object

An XGBoost model object of class xgboost, as produced by function xgboost().

Note that there is also a lower-level predict.xgb.Booster() method for models of class xgb.Booster, as produced by xgb.train(). It can also be used with xgboost class models as an alternative that performs fewer validations and less post-processing.

newdata

Data on which to compute predictions from the model passed in object. Supported input classes are:

  • Data Frames (class data.frame from base R and subclasses like data.table).

  • Matrices (class matrix from base R).

  • Sparse matrices from package Matrix, either as class dgRMatrix (CSR) or dgCMatrix (CSC).

  • Sparse vectors from package Matrix, which will be interpreted as containing a single observation.

In the case of data frames, if there are any categorical features, they should be of class factor and should have the same levels as the factor columns of the data from which the model was constructed. Any columns with type other than factor will be interpreted as numeric.

If there are named columns and the model was fitted to data with named columns, they will be matched by name by default (see validate_features).
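
For illustration, the following is a minimal sketch of passing the same (numeric-only) features through several of the supported input classes; the small ToothGrowth-based model mirrors the one in the Examples section below, and all object names are illustrative.

library(xgboost)
library(Matrix)
data("ToothGrowth")
x <- ToothGrowth[, -2L]  # 'len' and 'dose', both numeric
model <- xgboost(x, ToothGrowth$supp, nthreads = 1L, nrounds = 3L)

predict(model, x[1:5, ])                                    # data.frame
predict(model, as.matrix(x[1:5, ]))                         # base matrix
predict(model, Matrix(as.matrix(x[1:5, ]), sparse = TRUE))  # dgCMatrix (CSC)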

type

Type of prediction to make. Supported options are listed below (a short sketch illustrating the output shapes follows the list):

  • "response": will output model predictions on the scale of the response variable (e.g. probabilities of belonging to the last class in the case of binary classification). Result will be either a numeric vector with length matching to rows in newdata, or a numeric matrix with shape [nrows(newdata), nscores] (for objectives that produce more than one score per observation such as multi-class classification or multi-quantile regression).

  • "raw": will output the unprocessed boosting scores (e.g. log-odds in the case of objective binary:logistic). Same output shape and type as for "response".

  • "class": will output the class with the highest predicted probability, returned as a factor (only applicable to classification objectives) with length matching to rows in newdata.

  • "leaf": will output the terminal node indices of each observation across each tree, as an integer matrix of shape [nrows(newdata), ntrees], or as an integer array with an extra one or two dimensions, up to [nrows(newdata), ntrees, nscores, n_parallel_trees] for models that produce more than one score per tree and/or which have more than one parallel tree (e.g. random forests).

    Only applicable to tree-based boosters (not gblinear).

  • "contrib": will produce per-feature contribution estimates towards the model score for a given observation, based on SHAP values. The contribution values are on the scale of untransformed margin (e.g., for binary classification, the values are log-odds deviations from the baseline).

    Output will be a numeric matrix with shape [nrows, nfeatures+1], with the intercept being the last feature, or a numeric array with shape [nrows, nscores, nfeatures+1] if the model produces more than one score per observation.

  • "interaction": similar to "contrib", but computing SHAP values of contributions of interaction of each pair of features. Note that this operation might be rather expensive in terms of compute and memory.

    Since the computation grows quadratically with the number of features, it is recommended to select the most important features first.

    Output will be a numeric array of shape [nrows, nfeatures+1, nfeatures+1], or shape [nrows, nscores, nfeatures+1, nfeatures+1] (for objectives that produce more than one score per observation).
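
The sketch below (reusing a small binary-classification model like the one in the Examples section) illustrates the output shapes of some of these options; the sum of the SHAP contributions matching the raw margin is shown only as a sanity check, up to numerical precision.

library(xgboost)
data("ToothGrowth")
model <- xgboost(ToothGrowth[, -2L], ToothGrowth$supp, nthreads = 1L, nrounds = 3L)
newx <- ToothGrowth[1:5, -2L]

dim(predict(model, newx, type = "leaf"))     # [5, ntrees] for this model
dim(predict(model, newx, type = "contrib"))  # [5, nfeatures + 1], intercept last
# Per-row sums of the contributions should match the raw margin
all.equal(
  unname(rowSums(predict(model, newx, type = "contrib"))),
  unname(predict(model, newx, type = "raw"))
)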

base_margin

Base margin used for boosting from an existing model (a raw score that gets added to all observations independently of the trees in the model).

If supplied, it should be either a vector with length equal to the number of rows in newdata (for objectives that produce a single score per observation), or a matrix with the number of rows matching the number of rows in newdata and the number of columns matching the number of scores estimated by the model (e.g. the number of classes for multi-class classification).
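
As a minimal sketch of the expected shape, the following supplies a constant per-observation offset on the raw (log-odds) scale to a single-score model like the one in the Examples section; the names offset and newx are illustrative.

library(xgboost)
data("ToothGrowth")
model <- xgboost(ToothGrowth[, -2L], ToothGrowth$supp, nthreads = 1L, nrounds = 3L)
newx <- ToothGrowth[1:5, -2L]

offset <- rep(0.5, nrow(newx))  # one raw-scale value per row of newdata
predict(model, newx, type = "raw", base_margin = offset)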

iteration_range

Sequence of rounds/iterations from the model to use for prediction, specified by passing a two-element vector with the start and end numbers of the sequence (same format as R's seq() - i.e. base-1 indexing, and inclusive of both ends).

For example, passing c(1,20) will predict using the first twenty iterations, while passing c(1,1) will predict using only the first one.

If passing NULL, will either stop at the best iteration if the model used early stopping, or use all of the iterations (rounds) otherwise.

If passing "all", will use all of the rounds regardless of whether the model had early stopping or not.

Not applicable to gblinear booster.
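
A short illustrative sketch, reusing the Examples model (trained without early stopping, so the NULL default uses all rounds):

library(xgboost)
data("ToothGrowth")
model <- xgboost(ToothGrowth[, -2L], ToothGrowth$supp, nthreads = 1L, nrounds = 3L)
newx <- ToothGrowth[1:5, -2L]

p_first <- predict(model, newx, iteration_range = c(1, 1))  # first round only
p_all   <- predict(model, newx, iteration_range = "all")    # all rounds
p_null  <- predict(model, newx)                             # NULL: all rounds here (no early stopping)
all.equal(p_all, p_null)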

validate_features

Validate that the feature names in the data match the feature names in the model, and reorder them in the data otherwise.

If passing FALSE, it is assumed that the feature names and types are the same, and come in the same order as in the training data.

Be aware that this only applies to column names and not to factor levels in categorical columns.

Note that this check might add some sizable latency to the predictions, so it's recommended to disable it for performance-sensitive applications.
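
A minimal sketch of disabling the check when the prediction data is already known to have the same columns, in the same order, as the training data (model and newx names are illustrative, as in the sketches above):

library(xgboost)
data("ToothGrowth")
x <- ToothGrowth[, -2L]
model <- xgboost(x, ToothGrowth$supp, nthreads = 1L, nrounds = 3L)

newx <- x[1:5, ]  # same column names and order as the training data
predict(model, newx, validate_features = FALSE)  # skips the name check and reordering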

...

Not used.

Value

Either a numeric vector (for 1D outputs), numeric matrix (for 2D outputs), numeric array (for 3D and higher), or factor (for class predictions). See documentation for parameter type for details about what the output type and shape will be.

Examples

data("ToothGrowth")
y <- ToothGrowth$supp
x <- ToothGrowth[, -2L]
model <- xgboost(x, y, nthreads = 1L, nrounds = 3L, max_depth = 2L)
pred_prob <- predict(model, x[1:5, ], type = "response")
pred_raw <- predict(model, x[1:5, ], type = "raw")
pred_class <- predict(model, x[1:5, ], type = "class")

# Relationships between these
manual_probs <- 1 / (1 + exp(-pred_raw))
manual_class <- ifelse(manual_probs < 0.5, levels(y)[1], levels(y)[2])

# They should match up to numerical precision
round(pred_prob, 6) == round(manual_probs, 6)
pred_class == manual_class