Example of training a survival model with Dask on CPU

This example fits an accelerated failure time (AFT) survival model on the
Veterans' Administration lung cancer data using XGBoost's Dask interface.
import os

import dask.array as da
import dask.dataframe as dd
from dask.distributed import Client, LocalCluster

from xgboost import dask as dxgb
from xgboost.dask import DaskDMatrix


def main(client: Client) -> da.Array:
    # Load the example survival data from CSV into a Dask data frame.
    # The data come from the Veterans' Administration Lung Cancer Trial,
    # published in "The Statistical Analysis of Failure Time Data" by
    # Kalbfleisch, J. and Prentice, R. (1980).
    CURRENT_DIR = os.path.dirname(__file__)
    df = dd.read_csv(
        os.path.join(CURRENT_DIR, os.pardir, "data", "veterans_lung_cancer.csv")
    )
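
    # dd.read_csv builds a partitioned, lazily evaluated data frame; each
    # partition can later be turned into a local DMatrix on whichever worker
    # holds it.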
    # DaskDMatrix acts like a normal DMatrix: it works as a proxy for the
    # local DMatrix pieces scattered across the workers.
    # For AFT survival, extract the lower and upper bounds of the label
    # and pass them as arguments to DaskDMatrix.
    y_lower_bound = df["Survival_label_lower_bound"]
    y_upper_bound = df["Survival_label_upper_bound"]
    X = df.drop(["Survival_label_lower_bound", "Survival_label_upper_bound"], axis=1)
    dtrain = DaskDMatrix(
        client, X, label_lower_bound=y_lower_bound, label_upper_bound=y_upper_bound
    )
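
    # AFT encodes censoring through the label bounds: uncensored rows have
    # equal lower and upper bounds, while right-censored rows use +inf as the
    # upper bound. As a quick illustration (a sketch, assuming right-censored
    # rows in this dataset are stored with an infinite upper bound), count
    # the right-censored rows:
    n_censored = (y_upper_bound == float("inf")).sum().compute()
    print("Right-censored rows:", n_censored)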
    # Use the train method from xgboost.dask instead of the one from xgboost.
    # This distributed version of train returns a dictionary containing the
    # resulting booster and the evaluation history obtained from the
    # evaluation metrics.
    params = {
        "verbosity": 1,
        "objective": "survival:aft",
        "eval_metric": "aft-nloglik",
        "learning_rate": 0.05,
        "aft_loss_distribution_scale": 1.20,
        "aft_loss_distribution": "normal",
        "max_depth": 6,
        "lambda": 0.01,
        "alpha": 0.02,
    }
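    # Note: "aft_loss_distribution" also accepts "logistic" and "extreme",
    # and "aft_loss_distribution_scale" is the scale (sigma) of that error
    # distribution, so the two parameters are usually tuned together.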
    output = dxgb.train(
        client, params, dtrain, num_boost_round=100, evals=[(dtrain, "train")]
    )
    bst = output["booster"]
    history = output["history"]
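
    # The history is a nested mapping keyed by evaluation-set name and then
    # metric name, e.g. history["train"]["aft-nloglik"] holds one value per
    # boosting round, so the last entry is the final training loss:
    print("Final train aft-nloglik:", history["train"]["aft-nloglik"][-1])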
    # You can also pass `output` directly into `predict` in place of the
    # booster.
    prediction = dxgb.predict(client, bst, dtrain)
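
    # `predict` returns a lazy dask.array; nothing is materialized until you
    # call `.compute()` or hand the array to another Dask operation, e.g.:
    print("First few predictions:", prediction[:5].compute())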
print("Evaluation history: ", history)
# Uncomment the following line to save the model to the disk
# bst.save_model('survival_model.json')
return prediction


if __name__ == "__main__":
    # LocalCluster is used here for demonstration; any other Dask cluster
    # works for scaling out.
    with LocalCluster(n_workers=7, threads_per_worker=4) as cluster:
        with Client(cluster) as client:
            main(client)
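
# To scale beyond one machine, connect the Client to an existing scheduler
# instead of a LocalCluster, e.g. Client("tcp://scheduler-host:8786") (the
# address here is a placeholder); the training code itself is unchanged.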