.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "python/survival-examples/aft_survival_demo.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_python_survival-examples_aft_survival_demo.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_python_survival-examples_aft_survival_demo.py:


Demo for survival analysis (regression).
========================================

Demo for survival analysis (regression) using the Accelerated Failure Time (AFT) model.

.. GENERATED FROM PYTHON SOURCE LINES 7-63

.. code-block:: Python


    import os

    import numpy as np
    import pandas as pd
    from sklearn.model_selection import ShuffleSplit

    import xgboost as xgb

    # Dataset: The Veterans' Administration Lung Cancer Trial
    # Source: The Statistical Analysis of Failure Time Data
    # by Kalbfleisch J. and Prentice R (1980)
    CURRENT_DIR = os.path.dirname(__file__)
    df = pd.read_csv(os.path.join(CURRENT_DIR, '../data/veterans_lung_cancer.csv'))
    print('Training data:')
    print(df)

    # Split features and labels
    y_lower_bound = df['Survival_label_lower_bound']
    y_upper_bound = df['Survival_label_upper_bound']
    X = df.drop(['Survival_label_lower_bound', 'Survival_label_upper_bound'], axis=1)

    # Split data into training and validation sets
    # (test_size=.7 puts 70% of the rows in the validation set)
    rs = ShuffleSplit(n_splits=2, test_size=.7, random_state=0)
    train_index, valid_index = next(rs.split(X))
    dtrain = xgb.DMatrix(X.values[train_index, :])
    dtrain.set_float_info('label_lower_bound', y_lower_bound[train_index])
    dtrain.set_float_info('label_upper_bound', y_upper_bound[train_index])
    dvalid = xgb.DMatrix(X.values[valid_index, :])
    dvalid.set_float_info('label_lower_bound', y_lower_bound[valid_index])
    dvalid.set_float_info('label_upper_bound', y_upper_bound[valid_index])

    # Train gradient boosted trees using AFT loss and metric
    params = {'verbosity': 0,
              'objective': 'survival:aft',
              'eval_metric': 'aft-nloglik',
              'tree_method': 'hist',
              'learning_rate': 0.05,
              'aft_loss_distribution': 'normal',
              'aft_loss_distribution_scale': 1.20,
              'max_depth': 6,
              'lambda': 0.01,
              'alpha': 0.02}
    bst = xgb.train(params, dtrain, num_boost_round=10000,
                    evals=[(dtrain, 'train'), (dvalid, 'valid')],
                    early_stopping_rounds=50)

    # Run prediction on the validation set
    df = pd.DataFrame({'Label (lower bound)': y_lower_bound[valid_index],
                       'Label (upper bound)': y_upper_bound[valid_index],
                       'Predicted label': bst.predict(dvalid)})
    print(df)
    # Show only data points with right-censored labels
    # (their upper bound is +inf)
    print(df[np.isinf(df['Label (upper bound)'])])

    # Save trained model
    bst.save_model('aft_model.json')


.. _sphx_glr_download_python_survival-examples_aft_survival_demo.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: aft_survival_demo.ipynb <aft_survival_demo.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: aft_survival_demo.py <aft_survival_demo.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: aft_survival_demo.zip <aft_survival_demo.zip>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
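
The demo reads ``Survival_label_lower_bound`` and ``Survival_label_upper_bound`` straight from the CSV file and does not show how such columns are built. The following is a minimal sketch of one way they might be derived from raw follow-up data. It assumes the convention that an uncensored observation has equal lower and upper bounds, while a right-censored one has an upper bound of ``+inf``, which is consistent with the ``np.isinf`` filter used in the demo. The function name ``make_aft_bounds``, its arguments, and the sample values are hypothetical and are not taken from the dataset.

.. code-block:: python

    import math


    def make_aft_bounds(times, observed):
        """Return (lower, upper) label bounds for right-censored data.

        times:    follow-up time of each subject (hypothetical input).
        observed: True if the event was seen, False if the subject
                  was right-censored (hypothetical input).
        """
        lower, upper = [], []
        for t, seen in zip(times, observed):
            # The event is known to happen no earlier than the follow-up time.
            lower.append(float(t))
            # Observed event: exact time. Censored: no finite upper bound.
            upper.append(float(t) if seen else math.inf)
        return lower, upper


    # Made-up values: two observed events and one right-censored subject.
    lower, upper = make_aft_bounds([10, 30, 25], [True, True, False])
    print(lower)
    print(upper)

Lists built this way could then be passed to ``set_float_info`` under the ``label_lower_bound`` and ``label_upper_bound`` keys, as the demo does with the CSV columns.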