xgboost
survival_util.h
Go to the documentation of this file.
1
8 #ifndef XGBOOST_COMMON_SURVIVAL_UTIL_H_
9 #define XGBOOST_COMMON_SURVIVAL_UTIL_H_
10
11 /*
12  * For the derivation of the loss, gradient, and hessian for the Accelerated Failure Time model,
13  * refer to the paper "Survival regression with accelerated failure time model in XGBoost"
14  * at https://arxiv.org/abs/2006.04920.
15  */
16
17 #include <xgboost/parameter.h>
18 #include <memory>
19 #include <algorithm>
20 #include <limits>
22
24
25 namespace xgboost {
26 namespace common {
27
// When __CUDACC__ is not defined, bring std::log and std::fmax into
// xgboost::common so the unqualified log() / fmax() calls in AFTLoss resolve.
// NOTE(review): <cmath> is not among the includes visible in this listing;
// presumably it arrives through another header -- confirm.
28 #ifndef __CUDACC__
29
30 using std::log;
31 using std::fmax;
32
33 #endif // __CUDACC__
34
// Kind of censoring attached to one label interval; stored in a uint8_t.
// NOTE(review): the enumerator list (listing line 36) is missing from this
// extract. The code below refers to kUncensored, kRightCensored,
// kLeftCensored and kIntervalCensored; their declaration order cannot be
// confirmed from here.
35 enum class CensoringType : uint8_t {
37 };
38
39 namespace aft {
40
41 // Allowable range for gradient and hessian. Used for regularization
// kMinGradient / kMaxGradient appear as return values of the
// GetLimitGradAtInfPred specializations further down this file.
42 constexpr double kMinGradient = -15.0;
43 constexpr double kMaxGradient = 15.0;
// kMinHessian / kMaxHessian are the bounds AFTLoss::Hessian hands to Clip().
44 constexpr double kMinHessian = 1e-16; // Ensure that no data point gets zero hessian
45 constexpr double kMaxHessian = 15.0;
46
// kEps is the floor applied to the argument of log() in AFTLoss::Loss and the
// "denominator too small" threshold checked in AFTLoss::Hessian.
47 constexpr double kEps = 1e-12; // A denominator in a fraction should not be too small
48
// Restrict x to the closed interval [x_min, x_max].
//   x < x_min  -> x_min
//   x > x_max  -> x_max
//   otherwise  -> x   (a NaN x fails both comparisons and is returned unchanged)
// Precondition: x_min < x_max. The result for any other ordering of the
// bounds is not part of this function's contract.
inline double Clip(double x, double x_min, double x_max) {
  return (x < x_min) ? x_min : ((x > x_max) ? x_max : x);
}
62
// Primary template, declared here without a body; explicit specializations
// for NormalDistribution, LogisticDistribution and ExtremeDistribution follow
// later in this file. Takes the censoring type, a boolean `sign`, and sigma.
// NOTE(review): the call site in AFTLoss::Gradient is not visible in this
// listing; presumably it mirrors the Hessian fallback -- confirm.
63 template<typename Distribution>
64 XGBOOST_DEVICE inline double
65 GetLimitGradAtInfPred(CensoringType censor_type, bool sign, double sigma);
66
// Primary template, declared here without a body; specialized per
// distribution later in this file. AFTLoss::Hessian calls it, passing its
// z_sign flag as `sign`, when the hessian's denominator is below kEps and the
// quotient is NaN or infinite.
67 template<typename Distribution>
68 XGBOOST_DEVICE inline double
69 GetLimitHessAtInfPred(CensoringType censor_type, bool sign, double sigma);
70
71 } // namespace aft
72
// Parameter struct for the AFT loss; derives from XGBoostParameter<AFTParam>.
// It registers two fields: aft_loss_distribution and
// aft_loss_distribution_scale (the latter defaulting to 1.0f).
// NOTE(review): listing lines 75-79 and 81-84 are missing from this extract.
// The symbol index at the end of the listing names
// `ProbabilityDistributionType aft_loss_distribution`,
// `float aft_loss_distribution_scale` and `DMLC_DECLARE_PARAMETER(AFTParam)`
// for this struct; the default and accepted values of aft_loss_distribution
// cannot be read from here.
74 struct AFTParam : public XGBoostParameter<AFTParam> {
80  DMLC_DECLARE_FIELD(aft_loss_distribution)
85  .describe("Choice of distribution for the noise term in "
86  "Accelerated Failure Time model");
87  DMLC_DECLARE_FIELD(aft_loss_distribution_scale)
88  .set_default(1.0f)
89  .describe("Scaling factor used to scale the distribution in "
90  "Accelerated Failure Time model");
91  }
92 };
93
95 template<typename Distribution>
96 struct AFTLoss {
97  XGBOOST_DEVICE inline static
98  double Loss(double y_lower, double y_upper, double y_pred, double sigma) {
99  const double log_y_lower = log(y_lower);
100  const double log_y_upper = log(y_upper);
101
102  double cost;
103
104  if (y_lower == y_upper) { // uncensored
105  const double z = (log_y_lower - y_pred) / sigma;
106  const double pdf = Distribution::PDF(z);
107  // Regularize the denominator with eps, to avoid INF or NAN
108  cost = -log(fmax(pdf / (sigma * y_lower), aft::kEps));
109  } else { // censored; now check what type of censorship we have
110  double z_u, z_l, cdf_u, cdf_l;
111  if (isinf(y_upper)) { // right-censored
112  cdf_u = 1;
113  } else { // left-censored or interval-censored
114  z_u = (log_y_upper - y_pred) / sigma;
115  cdf_u = Distribution::CDF(z_u);
116  }
117  if (y_lower <= 0.0) { // left-censored
118  cdf_l = 0;
119  } else { // right-censored or interval-censored
120  z_l = (log_y_lower - y_pred) / sigma;
121  cdf_l = Distribution::CDF(z_l);
122  }
123  // Regularize the denominator with eps, to avoid INF or NAN
124  cost = -log(fmax(cdf_u - cdf_l, aft::kEps));
125  }
126
127  return cost;
128  }
129
// Gradient of the AFT loss for one observation, formed as numerator / denominator.
// Arguments have the same meaning as in Loss(): label bounds, prediction, scale.
// NOTE(review): listing lines 134, 141, 143, 172-173 and 176 are missing from
// this extract. They cover the declarations of numerator / denominator /
// gradient, the uncensored numerator, whatever follows the division inside the
// block closed at listing line 174, and the return statement. Presumably the
// tail mirrors Hessian() (limit fallback, then Clip) -- confirm against the
// real header before relying on this.
130  XGBOOST_DEVICE inline static
131  double Gradient(double y_lower, double y_upper, double y_pred, double sigma) {
132  const double log_y_lower = log(y_lower);
133  const double log_y_upper = log(y_upper);
135  CensoringType censor_type;
136  bool z_sign; // sign of z-score
137
138  if (y_lower == y_upper) { // uncensored
139  const double z = (log_y_lower - y_pred) / sigma;
140  const double pdf = Distribution::PDF(z);
142  censor_type = CensoringType::kUncensored;
144  denominator = sigma * pdf;
145  z_sign = (z > 0);
146  } else { // censored; now check what type of censorship we have
// z_u / z_l start at 0.0 so the z_sign test below is well-defined even when
// one bound is skipped (infinite upper bound or non-positive lower bound).
147  double z_u = 0.0, z_l = 0.0, pdf_u, pdf_l, cdf_u, cdf_l;
// Start from interval-censored; the two checks below may overwrite it.
148  censor_type = CensoringType::kIntervalCensored;
149  if (isinf(y_upper)) { // right-censored
150  pdf_u = 0;
151  cdf_u = 1;
152  censor_type = CensoringType::kRightCensored;
153  } else { // interval-censored or left-censored
154  z_u = (log_y_upper - y_pred) / sigma;
155  pdf_u = Distribution::PDF(z_u);
156  cdf_u = Distribution::CDF(z_u);
157  }
// This check runs second, so if y_upper is infinite AND y_lower <= 0 the
// recorded type ends up as kLeftCensored.
158  if (y_lower <= 0.0) { // left-censored
159  pdf_l = 0;
160  cdf_l = 0;
161  censor_type = CensoringType::kLeftCensored;
162  } else { // interval-censored or right-censored
163  z_l = (log_y_lower - y_pred) / sigma;
164  pdf_l = Distribution::PDF(z_l);
165  cdf_l = Distribution::CDF(z_l);
166  }
// True when either bound's z-score is positive.
167  z_sign = (z_u > 0 || z_l > 0);
168  numerator = pdf_u - pdf_l;
169  denominator = sigma * (cdf_u - cdf_l);
170  }
171  gradient = numerator / denominator;
174  }
175
177  }
178
// Hessian of the AFT loss for one observation.
// Computes numerator / denominator; if the denominator is below aft::kEps and
// the quotient is NaN or infinite, the quotient is replaced by
// aft::GetLimitHessAtInfPred<Distribution>(censor_type, z_sign, sigma).
// The value is finally passed through aft::Clip with bounds
// [aft::kMinHessian, aft::kMaxHessian].
// Arguments have the same meaning as in Loss(): label bounds, prediction, scale.
// NOTE(review): listing lines 190, 193, 202, 208, 213, 219 and 223 are missing
// from this extract. The uncensored numerator, the assignments to grad_pdf_u /
// grad_pdf_l, and the definition of grad_diff are therefore not visible here.
179  XGBOOST_DEVICE inline static
180  double Hessian(double y_lower, double y_upper, double y_pred, double sigma) {
181  const double log_y_lower = log(y_lower);
182  const double log_y_upper = log(y_upper);
183  double numerator, denominator, hessian; // numerator and denominator of hessian
184  CensoringType censor_type;
185  bool z_sign; // sign of z-score
186
187  if (y_lower == y_upper) { // uncensored
188  const double z = (log_y_lower - y_pred) / sigma;
189  const double pdf = Distribution::PDF(z);
191  const double hess_pdf = Distribution::HessPDF(z);
192  censor_type = CensoringType::kUncensored;
194  denominator = sigma * sigma * pdf * pdf;
195  z_sign = (z > 0);
196  } else { // censored; now check what type of censorship we have
// z_u / z_l start at 0.0 so the z_sign test below is well-defined even when
// one bound is skipped.
197  double z_u = 0.0, z_l = 0.0, grad_pdf_u, grad_pdf_l, pdf_u, pdf_l, cdf_u, cdf_l;
// Start from interval-censored; the two checks below may overwrite it.
198  censor_type = CensoringType::kIntervalCensored;
199  if (isinf(y_upper)) { // right-censored
200  pdf_u = 0;
201  cdf_u = 1;
203  censor_type = CensoringType::kRightCensored;
204  } else { // interval-censored or left-censored
205  z_u = (log_y_upper - y_pred) / sigma;
206  pdf_u = Distribution::PDF(z_u);
207  cdf_u = Distribution::CDF(z_u);
209  }
// This check runs second, so it wins when both conditions hold.
210  if (y_lower <= 0.0) { // left-censored
211  pdf_l = 0;
212  cdf_l = 0;
214  censor_type = CensoringType::kLeftCensored;
215  } else { // interval-censored or right-censored
216  z_l = (log_y_lower - y_pred) / sigma;
217  pdf_l = Distribution::PDF(z_l);
218  cdf_l = Distribution::CDF(z_l);
220  }
221  const double cdf_diff = cdf_u - cdf_l;
222  const double pdf_diff = pdf_u - pdf_l;
// The denominator is (sigma * cdf_diff)^2; its square root is named first.
224  const double sqrt_denominator = sigma * cdf_diff;
225  z_sign = (z_u > 0 || z_l > 0);
226  numerator = -(cdf_diff * grad_diff - pdf_diff * pdf_diff);
227  denominator = sqrt_denominator * sqrt_denominator;
228  }
229  hessian = numerator / denominator;
// Only a tiny denominator combined with a non-finite quotient triggers the
// fallback; a finite quotient is kept even if the denominator is small.
230  if (denominator < aft::kEps && (isnan(hessian) || isinf(hessian))) {
231  hessian = aft::GetLimitHessAtInfPred<Distribution>(censor_type, z_sign, sigma);
232  }
233
234  return aft::Clip(hessian, aft::kMinHessian, aft::kMaxHessian);
235  }
236 };
237
238 namespace aft {
239
// NormalDistribution specialization of GetLimitGradAtInfPred.
// Selects a value by censor_type and sign; sigma is not used. If no switch
// branch returns, the function yields a quiet NaN.
// NOTE(review): listing lines 247-249, 251 and 253-254 (the case labels and
// any further returns) are missing from this extract, so which censoring type
// each visible return belongs to cannot be determined from here.
240 template <>
241 XGBOOST_DEVICE inline double
242 GetLimitGradAtInfPred<NormalDistribution>(CensoringType censor_type, bool sign, double sigma) {
243  // Remove unused parameter compiler warning.
244  (void) sigma;
245
246  switch (censor_type) {
250  return sign ? kMinGradient : 0.0;
252  return sign ? 0.0 : kMaxGradient;
255  }
256  return std::numeric_limits<double>::quiet_NaN();
257 }
258
// NormalDistribution specialization of GetLimitHessAtInfPred.
// Every visible return is either 1 / sigma^2 or kMinHessian, chosen by
// censor_type and sign. If no switch branch returns, the function yields a
// quiet NaN.
// NOTE(review): the case labels (listing lines 263, 265, 267, 269) are
// missing from this extract, so the mapping from censoring type to return
// statement cannot be determined from here.
259 template <>
260 XGBOOST_DEVICE inline double
261 GetLimitHessAtInfPred<NormalDistribution>(CensoringType censor_type, bool sign, double sigma) {
262  switch (censor_type) {
264  return 1.0 / (sigma * sigma);
266  return sign ? (1.0 / (sigma * sigma)) : kMinHessian;
268  return sign ? kMinHessian : (1.0 / (sigma * sigma));
270  return 1.0 / (sigma * sigma);
271  }
272  return std::numeric_limits<double>::quiet_NaN();
273 }
274
// LogisticDistribution specialization of GetLimitGradAtInfPred.
// Every visible return is -1/sigma, +1/sigma or 0.0, chosen by censor_type
// and sign. If no switch branch returns, the function yields a quiet NaN.
// NOTE(review): the case labels (listing lines 279, 281, 283, 285) are
// missing from this extract, so the mapping from censoring type to return
// statement cannot be determined from here.
275 template <>
276 XGBOOST_DEVICE inline double
277 GetLimitGradAtInfPred<LogisticDistribution>(CensoringType censor_type, bool sign, double sigma) {
278  switch (censor_type) {
280  return sign ? (-1.0 / sigma) : (1.0 / sigma);
282  return sign ? (-1.0 / sigma) : 0.0;
284  return sign ? 0.0 : (1.0 / sigma);
286  return sign ? (-1.0 / sigma) : (1.0 / sigma);
287  }
288  return std::numeric_limits<double>::quiet_NaN();
289 }
290
// LogisticDistribution specialization of GetLimitHessAtInfPred.
// Neither sign nor sigma is used. The only return visible inside the switch
// is kMinHessian; if no switch branch returns, the function yields a quiet NaN.
// NOTE(review): listing lines 299-302 (presumably the case labels sharing
// that single return) are missing from this extract -- confirm.
291 template <>
292 XGBOOST_DEVICE inline double
293 GetLimitHessAtInfPred<LogisticDistribution>(CensoringType censor_type, bool sign, double sigma) {
294  // Remove unused parameter compiler warning.
295  (void) sign;
296  (void) sigma;
297
298  switch (censor_type) {
303  return kMinHessian;
304  }
305  return std::numeric_limits<double>::quiet_NaN();
306 }
307
// ExtremeDistribution specialization of GetLimitGradAtInfPred.
// Every visible return is kMinGradient, 1/sigma or 0.0, chosen by censor_type
// and sign. If no switch branch returns, the function yields a quiet NaN.
// NOTE(review): the case labels (listing lines 312, 314, 316, 318) are
// missing from this extract, so the mapping from censoring type to return
// statement cannot be determined from here.
308 template <>
309 XGBOOST_DEVICE inline double
310 GetLimitGradAtInfPred<ExtremeDistribution>(CensoringType censor_type, bool sign, double sigma) {
311  switch (censor_type) {
313  return sign ? kMinGradient : (1.0 / sigma);
315  return sign ? kMinGradient : 0.0;
317  return sign ? 0.0 : (1.0 / sigma);
319  return sign ? kMinGradient : (1.0 / sigma);
320  }
321  return std::numeric_limits<double>::quiet_NaN();
322 }
323
// ExtremeDistribution specialization of GetLimitHessAtInfPred.
// sigma is not used. Every visible return is kMaxHessian or kMinHessian,
// chosen by censor_type and sign. If no switch branch returns, the function
// yields a quiet NaN.
// NOTE(review): listing lines 331-332, 334 and 336 (the case labels) are
// missing from this extract, so the mapping from censoring type to return
// statement cannot be determined from here.
324 template <>
325 XGBOOST_DEVICE inline double
326 GetLimitHessAtInfPred<ExtremeDistribution>(CensoringType censor_type, bool sign, double sigma) {
327  // Remove unused parameter compiler warning.
328  (void) sigma;
329
330  switch (censor_type) {
333  return sign ? kMaxHessian : kMinHessian;
335  return kMinHessian;
337  return sign ? kMaxHessian : kMinHessian;
338  }
339  return std::numeric_limits<double>::quiet_NaN();
340 }
341
342 } // namespace aft
343
344 } // namespace common
345 } // namespace xgboost
346
347 #endif // XGBOOST_COMMON_SURVIVAL_UTIL_H_
Definition: survival_util.h:43
xgboost::common::CensoringType::kLeftCensored
@ kLeftCensored
static XGBOOST_DEVICE double Gradient(double y_lower, double y_upper, double y_pred, double sigma)
Definition: survival_util.h:131
xgboost::common::CensoringType::kIntervalCensored
@ kIntervalCensored
xgboost::common::aft::kMaxHessian
constexpr double kMaxHessian
Definition: survival_util.h:45
xgboost::common::ProbabilityDistributionType
ProbabilityDistributionType
Enum encoding possible choices of probability distribution.
Definition: probability_distribution.h:31
xgboost::common::AFTLoss
The AFT loss function.
Definition: survival_util.h:96
xgboost::common::aft::GetLimitHessAtInfPred< NormalDistribution >
XGBOOST_DEVICE double GetLimitHessAtInfPred< NormalDistribution >(CensoringType censor_type, bool sign, double sigma)
Definition: survival_util.h:261
XGBOOST_DEVICE double GetLimitGradAtInfPred< NormalDistribution >(CensoringType censor_type, bool sign, double sigma)
Definition: survival_util.h:242
xgboost::common::AFTParam::aft_loss_distribution
ProbabilityDistributionType aft_loss_distribution
Choice of probability distribution for the noise term in AFT.
Definition: survival_util.h:76
parameter.h
macro for using C++11 enum class as DMLC parameter
xgboost::common::AFTParam::aft_loss_distribution_scale
float aft_loss_distribution_scale
Scaling factor to be applied to the distribution.
Definition: survival_util.h:78
xgboost::common::aft::GetLimitHessAtInfPred
XGBOOST_DEVICE double GetLimitHessAtInfPred(CensoringType censor_type, bool sign, double sigma)
xgboost::common::aft::Clip
XGBOOST_DEVICE double Clip(double x, double x_min, double x_max)
Definition: survival_util.h:53
xgboost::common::CensoringType::kRightCensored
@ kRightCensored
Definition: survival_util.h:42
xgboost::common::AFTParam
Parameter structure for AFT loss and metric.
Definition: survival_util.h:74
xgboost::common::CensoringType::kUncensored
@ kUncensored
xgboost::common::AFTParam::DMLC_DECLARE_PARAMETER
DMLC_DECLARE_PARAMETER(AFTParam)
Definition: survival_util.h:79
probability_distribution.h
Implementation of a few useful probability distributions.
xgboost::XGBoostParameter
Definition: parameter.h:84
XGBOOST_DEVICE double GetLimitGradAtInfPred< ExtremeDistribution >(CensoringType censor_type, bool sign, double sigma)
Definition: survival_util.h:310
xgboost::common::AFTLoss::Loss
static XGBOOST_DEVICE double Loss(double y_lower, double y_upper, double y_pred, double sigma)
Definition: survival_util.h:98
xgboost::common::AFTLoss::Hessian
static XGBOOST_DEVICE double Hessian(double y_lower, double y_upper, double y_pred, double sigma)
Definition: survival_util.h:180
xgboost::common::aft::kMinHessian
constexpr double kMinHessian
Definition: survival_util.h:44
xgboost::common::aft::GetLimitHessAtInfPred< LogisticDistribution >
XGBOOST_DEVICE double GetLimitHessAtInfPred< LogisticDistribution >(CensoringType censor_type, bool sign, double sigma)
Definition: survival_util.h:293
xgboost::common::aft::GetLimitHessAtInfPred< ExtremeDistribution >
XGBOOST_DEVICE double GetLimitHessAtInfPred< ExtremeDistribution >(CensoringType censor_type, bool sign, double sigma)
Definition: survival_util.h:326
DECLARE_FIELD_ENUM_CLASS
DECLARE_FIELD_ENUM_CLASS(xgboost::common::ProbabilityDistributionType)
XGBOOST_DEVICE double GetLimitGradAtInfPred< LogisticDistribution >(CensoringType censor_type, bool sign, double sigma)
Definition: survival_util.h:277
xgboost::common::aft::kEps
constexpr double kEps
Definition: survival_util.h:47