Note
Go to the end to download the full example code
Demo for using and defining callback functions
New in version 1.3.0.
import argparse
import os
import tempfile
import numpy as np
from matplotlib import pyplot as plt
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
import xgboost as xgb
class Plotting(xgb.callback.TrainingCallback):
    """Plot evaluation results live while training runs.

    Only for demonstration purposes: redrawing a matplotlib figure after every
    boosting round is quite slow.

    Parameters
    ----------
    rounds :
        Total number of boosting rounds that will be run. It fixes the length of
        the x-axis and of every plotted series.
    """

    def __init__(self, rounds):
        super().__init__()
        self.fig = plt.figure()
        self.ax = self.fig.add_subplot(111)
        self.rounds = rounds
        # One line artist per "<data>-<metric>" key, created lazily the first
        # time that key shows up in `after_iteration`.
        self.lines = {}
        self.fig.show()
        # One x position per boosting round (0, 1, ..., rounds - 1) so the i-th
        # logged value is drawn at x == i.
        self.x = np.arange(self.rounds)
        plt.ion()

    def _get_key(self, data, metric):
        """Return the legend/lookup key for one (data name, metric name) pair."""
        return f"{data}-{metric}"

    def _expand(self, log):
        """Return `log` as a list of exactly `self.rounds` values.

        Rounds that have not been run yet are filled with 0; anything beyond
        `self.rounds` is dropped so the series always matches `self.x` in length.
        """
        seen = list(log[: self.rounds])
        return seen + [0] * (self.rounds - len(seen))

    def after_iteration(self, model, epoch, evals_log):
        """Update the plot with the metrics logged so far.

        `evals_log` is iterated as ``{data_name: {metric_name: [values, ...]}}``.

        Returns
        -------
        bool
            Always False, to indicate training should not stop.
        """
        added_line = False
        for data, metric in evals_log.items():
            for metric_name, log in metric.items():
                key = self._get_key(data, metric_name)
                expanded = self._expand(log)
                line = self.lines.get(key)
                if line is None:
                    # First time this series is seen: create its line artist.
                    (self.lines[key],) = self.ax.plot(self.x, expanded, label=key)
                    added_line = True
                else:
                    # https://pythonspot.com/matplotlib-update-plot/
                    line.set_ydata(expanded)
        if added_line:
            self.ax.legend()
        # `set_ydata` does not touch the axis limits; recompute them so values
        # larger than those seen in the first round stay visible.
        self.ax.relim()
        self.ax.autoscale_view()
        self.fig.canvas.draw()
        # Let the GUI event loop process the redraw while training keeps running.
        self.fig.canvas.flush_events()
        # False to indicate training should not stop.
        return False
def custom_callback():
    """Demo of a user-defined callback: train on the breast cancer dataset while
    a `Plotting` instance draws the evaluation results during training."""
    features, labels = load_breast_cancer(return_X_y=True)
    train_x, valid_x, train_y, valid_y = train_test_split(
        features, labels, random_state=0
    )
    dtrain = xgb.DMatrix(train_x, train_y)
    dvalid = xgb.DMatrix(valid_x, valid_y)

    n_rounds = 100
    params = {
        "objective": "binary:logistic",
        "eval_metric": ["error", "rmse"],
        "tree_method": "hist",
        "device": "cuda",
    }
    # Both splits are evaluated every round; their names become part of the
    # keys the plotting callback receives.
    watchlist = [(dtrain, "Train"), (dvalid, "Valid")]

    # Callbacks are handed over as a list through the `callbacks` parameter.
    xgb.train(
        params,
        dtrain,
        evals=watchlist,
        num_boost_round=n_rounds,
        callbacks=[Plotting(n_rounds)],
    )
def check_point_callback():
    """Demo of `xgb.callback.TrainingCheckPoint`.

    Trains twice into one temporary directory, once with the default checkpoint
    settings and once with ``as_pickle=True``, and after each run asserts that
    the expected ``model_<i>`` files exist.
    """
    # Demo-only interval; set a larger value (like 100) in practice as
    # checkpointing is quite slow.
    rounds = 2

    features, labels = load_breast_cancer(return_X_y=True)
    dmat = xgb.DMatrix(features, labels)

    def fit(callback):
        """Run the 10-round demo training with `callback` attached."""
        xgb.train(
            {"objective": "binary:logistic"},
            dmat,
            num_boost_round=10,
            verbose_eval=False,
            callbacks=[callback],
        )

    def check(as_pickle):
        """Assert a checkpoint file exists for i = rounds, 2 * rounds, ... < 10."""
        suffix = ".pkl" if as_pickle else ".json"
        for i in range(rounds, 10, rounds):
            assert os.path.exists(os.path.join(tmpdir, f"model_{i}{suffix}"))

    # Check point to a temporary directory for demo
    with tempfile.TemporaryDirectory() as tmpdir:
        # Use callback class from xgboost.callback
        # Feel free to subclass/customize it to suit your need.
        fit(
            xgb.callback.TrainingCheckPoint(
                directory=tmpdir, iterations=rounds, name="model"
            )
        )
        check(False)

        # This version of checkpoint saves everything including parameters and
        # model. See: doc/tutorials/saving_model.rst
        fit(
            xgb.callback.TrainingCheckPoint(
                directory=tmpdir, iterations=rounds, as_pickle=True, name="model"
            )
        )
        check(True)
if __name__ == "__main__":
    # `--plot` is an integer switch (default 1): the checkpoint demo always
    # runs, the plotting demo only when the value is non-zero.
    cli = argparse.ArgumentParser()
    cli.add_argument("--plot", default=1, type=int)
    options = cli.parse_args()

    check_point_callback()
    if options.plot:
        custom_callback()
Total running time of the script: (0 minutes 0.000 seconds)