xgboost
survival_util.h
Go to the documentation of this file.
1 
8 #ifndef XGBOOST_COMMON_SURVIVAL_UTIL_H_
9 #define XGBOOST_COMMON_SURVIVAL_UTIL_H_
10 
11 /*
12  * For the derivation of the loss, gradient, and hessian for the Accelerated Failure Time model,
13  * refer to the paper "Survival regression with accelerated failure time model in XGBoost"
14  * at https://arxiv.org/abs/2006.04920.
15  */
16 
#include <xgboost/parameter.h>
#include <memory>
#include <algorithm>
#include <cmath>
#include <limits>
22 
24 
25 namespace xgboost {
26 namespace common {
27 
28 #ifndef __CUDACC__
29 
30 using std::log;
31 using std::fmax;
32 
33 #endif // __CUDACC__
34 
35 enum class CensoringType : uint8_t {
37 };
38 
39 namespace aft {
40 
41 // Allowable range for gradient and hessian. Used for regularization
42 constexpr double kMinGradient = -15.0;
43 constexpr double kMaxGradient = 15.0;
44 constexpr double kMinHessian = 1e-16; // Ensure that no data point gets zero hessian
45 constexpr double kMaxHessian = 15.0;
46 
47 constexpr double kEps = 1e-12; // A denominator in a fraction should not be too small
48 
// Restrict x to the closed interval [x_min, x_max].
//   x below x_min  -> x_min
//   x above x_max  -> x_max
//   anything else  -> x itself (a NaN fails both comparisons and is returned as-is)
// Callers are expected to pass x_min < x_max; the result is not meaningful otherwise.
inline double Clip(double x, double x_min, double x_max) {
  return (x < x_min) ? x_min
                     : ((x > x_max) ? x_max : x);
}
62 
63 template<typename Distribution>
64 XGBOOST_DEVICE inline double
65 GetLimitGradAtInfPred(CensoringType censor_type, bool sign, double sigma);
66 
67 template<typename Distribution>
68 XGBOOST_DEVICE inline double
69 GetLimitHessAtInfPred(CensoringType censor_type, bool sign, double sigma);
70 
71 } // namespace aft
72 
74 struct AFTParam : public XGBoostParameter<AFTParam> {
80  DMLC_DECLARE_FIELD(aft_loss_distribution)
82  .add_enum("normal", ProbabilityDistributionType::kNormal)
83  .add_enum("logistic", ProbabilityDistributionType::kLogistic)
84  .add_enum("extreme", ProbabilityDistributionType::kExtreme)
85  .describe("Choice of distribution for the noise term in "
86  "Accelerated Failure Time model");
87  DMLC_DECLARE_FIELD(aft_loss_distribution_scale)
88  .set_default(1.0f)
89  .describe("Scaling factor used to scale the distribution in "
90  "Accelerated Failure Time model");
91  }
92 };
93 
95 template<typename Distribution>
96 struct AFTLoss {
97  XGBOOST_DEVICE inline static
98  double Loss(double y_lower, double y_upper, double y_pred, double sigma) {
99  const double log_y_lower = log(y_lower);
100  const double log_y_upper = log(y_upper);
101 
102  double cost;
103 
104  if (y_lower == y_upper) { // uncensored
105  const double z = (log_y_lower - y_pred) / sigma;
106  const double pdf = Distribution::PDF(z);
107  // Regularize the denominator with eps, to avoid INF or NAN
108  cost = -log(fmax(pdf / (sigma * y_lower), aft::kEps));
109  } else { // censored; now check what type of censorship we have
110  double z_u, z_l, cdf_u, cdf_l;
111  if (isinf(y_upper)) { // right-censored
112  cdf_u = 1;
113  } else { // left-censored or interval-censored
114  z_u = (log_y_upper - y_pred) / sigma;
115  cdf_u = Distribution::CDF(z_u);
116  }
117  if (y_lower <= 0.0) { // left-censored
118  cdf_l = 0;
119  } else { // right-censored or interval-censored
120  z_l = (log_y_lower - y_pred) / sigma;
121  cdf_l = Distribution::CDF(z_l);
122  }
123  // Regularize the denominator with eps, to avoid INF or NAN
124  cost = -log(fmax(cdf_u - cdf_l, aft::kEps));
125  }
126 
127  return cost;
128  }
129 
130  XGBOOST_DEVICE inline static
131  double Gradient(double y_lower, double y_upper, double y_pred, double sigma) {
132  const double log_y_lower = log(y_lower);
133  const double log_y_upper = log(y_upper);
134  double numerator, denominator, gradient; // numerator and denominator of gradient
135  CensoringType censor_type;
136  bool z_sign; // sign of z-score
137 
138  if (y_lower == y_upper) { // uncensored
139  const double z = (log_y_lower - y_pred) / sigma;
140  const double pdf = Distribution::PDF(z);
141  const double grad_pdf = Distribution::GradPDF(z);
142  censor_type = CensoringType::kUncensored;
143  numerator = grad_pdf;
144  denominator = sigma * pdf;
145  z_sign = (z > 0);
146  } else { // censored; now check what type of censorship we have
147  double z_u = 0.0, z_l = 0.0, pdf_u, pdf_l, cdf_u, cdf_l;
148  censor_type = CensoringType::kIntervalCensored;
149  if (isinf(y_upper)) { // right-censored
150  pdf_u = 0;
151  cdf_u = 1;
152  censor_type = CensoringType::kRightCensored;
153  } else { // interval-censored or left-censored
154  z_u = (log_y_upper - y_pred) / sigma;
155  pdf_u = Distribution::PDF(z_u);
156  cdf_u = Distribution::CDF(z_u);
157  }
158  if (y_lower <= 0.0) { // left-censored
159  pdf_l = 0;
160  cdf_l = 0;
161  censor_type = CensoringType::kLeftCensored;
162  } else { // interval-censored or right-censored
163  z_l = (log_y_lower - y_pred) / sigma;
164  pdf_l = Distribution::PDF(z_l);
165  cdf_l = Distribution::CDF(z_l);
166  }
167  z_sign = (z_u > 0 || z_l > 0);
168  numerator = pdf_u - pdf_l;
169  denominator = sigma * (cdf_u - cdf_l);
170  }
171  gradient = numerator / denominator;
172  if (denominator < aft::kEps && (isnan(gradient) || isinf(gradient))) {
173  gradient = aft::GetLimitGradAtInfPred<Distribution>(censor_type, z_sign, sigma);
174  }
175 
176  return aft::Clip(gradient, aft::kMinGradient, aft::kMaxGradient);
177  }
178 
179  XGBOOST_DEVICE inline static
180  double Hessian(double y_lower, double y_upper, double y_pred, double sigma) {
181  const double log_y_lower = log(y_lower);
182  const double log_y_upper = log(y_upper);
183  double numerator, denominator, hessian; // numerator and denominator of hessian
184  CensoringType censor_type;
185  bool z_sign; // sign of z-score
186 
187  if (y_lower == y_upper) { // uncensored
188  const double z = (log_y_lower - y_pred) / sigma;
189  const double pdf = Distribution::PDF(z);
190  const double grad_pdf = Distribution::GradPDF(z);
191  const double hess_pdf = Distribution::HessPDF(z);
192  censor_type = CensoringType::kUncensored;
193  numerator = -(pdf * hess_pdf - grad_pdf * grad_pdf);
194  denominator = sigma * sigma * pdf * pdf;
195  z_sign = (z > 0);
196  } else { // censored; now check what type of censorship we have
197  double z_u = 0.0, z_l = 0.0, grad_pdf_u, grad_pdf_l, pdf_u, pdf_l, cdf_u, cdf_l;
198  censor_type = CensoringType::kIntervalCensored;
199  if (isinf(y_upper)) { // right-censored
200  pdf_u = 0;
201  cdf_u = 1;
202  grad_pdf_u = 0;
203  censor_type = CensoringType::kRightCensored;
204  } else { // interval-censored or left-censored
205  z_u = (log_y_upper - y_pred) / sigma;
206  pdf_u = Distribution::PDF(z_u);
207  cdf_u = Distribution::CDF(z_u);
208  grad_pdf_u = Distribution::GradPDF(z_u);
209  }
210  if (y_lower <= 0.0) { // left-censored
211  pdf_l = 0;
212  cdf_l = 0;
213  grad_pdf_l = 0;
214  censor_type = CensoringType::kLeftCensored;
215  } else { // interval-censored or right-censored
216  z_l = (log_y_lower - y_pred) / sigma;
217  pdf_l = Distribution::PDF(z_l);
218  cdf_l = Distribution::CDF(z_l);
219  grad_pdf_l = Distribution::GradPDF(z_l);
220  }
221  const double cdf_diff = cdf_u - cdf_l;
222  const double pdf_diff = pdf_u - pdf_l;
223  const double grad_diff = grad_pdf_u - grad_pdf_l;
224  const double sqrt_denominator = sigma * cdf_diff;
225  z_sign = (z_u > 0 || z_l > 0);
226  numerator = -(cdf_diff * grad_diff - pdf_diff * pdf_diff);
227  denominator = sqrt_denominator * sqrt_denominator;
228  }
229  hessian = numerator / denominator;
230  if (denominator < aft::kEps && (isnan(hessian) || isinf(hessian))) {
231  hessian = aft::GetLimitHessAtInfPred<Distribution>(censor_type, z_sign, sigma);
232  }
233 
234  return aft::Clip(hessian, aft::kMinHessian, aft::kMaxHessian);
235  }
236 };
237 
238 namespace aft {
239 
240 template <>
241 XGBOOST_DEVICE inline double
242 GetLimitGradAtInfPred<NormalDistribution>(CensoringType censor_type, bool sign, double sigma) {
243  switch (censor_type) {
245  return sign ? kMinGradient : kMaxGradient;
247  return sign ? kMinGradient : 0.0;
249  return sign ? 0.0 : kMaxGradient;
251  return sign ? kMinGradient : kMaxGradient;
252  }
253  return std::numeric_limits<double>::quiet_NaN();
254 }
255 
256 template <>
257 XGBOOST_DEVICE inline double
258 GetLimitHessAtInfPred<NormalDistribution>(CensoringType censor_type, bool sign, double sigma) {
259  switch (censor_type) {
261  return 1.0 / (sigma * sigma);
263  return sign ? (1.0 / (sigma * sigma)) : kMinHessian;
265  return sign ? kMinHessian : (1.0 / (sigma * sigma));
267  return 1.0 / (sigma * sigma);
268  }
269  return std::numeric_limits<double>::quiet_NaN();
270 }
271 
272 template <>
273 XGBOOST_DEVICE inline double
274 GetLimitGradAtInfPred<LogisticDistribution>(CensoringType censor_type, bool sign, double sigma) {
275  switch (censor_type) {
277  return sign ? (-1.0 / sigma) : (1.0 / sigma);
279  return sign ? (-1.0 / sigma) : 0.0;
281  return sign ? 0.0 : (1.0 / sigma);
283  return sign ? (-1.0 / sigma) : (1.0 / sigma);
284  }
285  return std::numeric_limits<double>::quiet_NaN();
286 }
287 
288 template <>
289 XGBOOST_DEVICE inline double
290 GetLimitHessAtInfPred<LogisticDistribution>(CensoringType censor_type, bool sign, double sigma) {
291  switch (censor_type) {
296  return kMinHessian;
297  }
298  return std::numeric_limits<double>::quiet_NaN();
299 }
300 
301 template <>
302 XGBOOST_DEVICE inline double
303 GetLimitGradAtInfPred<ExtremeDistribution>(CensoringType censor_type, bool sign, double sigma) {
304  switch (censor_type) {
306  return sign ? kMinGradient : (1.0 / sigma);
308  return sign ? kMinGradient : 0.0;
310  return sign ? 0.0 : (1.0 / sigma);
312  return sign ? kMinGradient : (1.0 / sigma);
313  }
314  return std::numeric_limits<double>::quiet_NaN();
315 }
316 
317 template <>
318 XGBOOST_DEVICE inline double
319 GetLimitHessAtInfPred<ExtremeDistribution>(CensoringType censor_type, bool sign, double sigma) {
320  switch (censor_type) {
323  return sign ? kMaxHessian : kMinHessian;
325  return kMinHessian;
327  return sign ? kMaxHessian : kMinHessian;
328  }
329  return std::numeric_limits<double>::quiet_NaN();
330 }
331 
332 } // namespace aft
333 
334 } // namespace common
335 } // namespace xgboost
336 
337 #endif // XGBOOST_COMMON_SURVIVAL_UTIL_H_
The AFT loss function.
Definition: survival_util.h:96
static XGBOOST_DEVICE double Loss(double y_lower, double y_upper, double y_pred, double sigma)
Definition: survival_util.h:98
static XGBOOST_DEVICE double Gradient(double y_lower, double y_upper, double y_pred, double sigma)
Definition: survival_util.h:131
Definition: parameter.h:84
XGBOOST_DEVICE double GetLimitHessAtInfPred< LogisticDistribution >(CensoringType censor_type, bool sign, double sigma)
Definition: survival_util.h:290
constexpr double kMinHessian
Definition: survival_util.h:44
constexpr double kEps
Definition: survival_util.h:47
XGBOOST_DEVICE double GetLimitHessAtInfPred< ExtremeDistribution >(CensoringType censor_type, bool sign, double sigma)
Definition: survival_util.h:319
float aft_loss_distribution_scale
Scaling factor to be applied to the distribution.
Definition: survival_util.h:78
constexpr double kMaxGradient
Definition: survival_util.h:43
XGBOOST_DEVICE double GetLimitGradAtInfPred< LogisticDistribution >(CensoringType censor_type, bool sign, double sigma)
Definition: survival_util.h:274
ProbabilityDistributionType
Enum encoding possible choices of probability distribution.
Definition: probability_distribution.h:31
DMLC_DECLARE_PARAMETER(AFTParam)
Definition: survival_util.h:79
Parameter structure for AFT loss and metric.
Definition: survival_util.h:74
XGBOOST_DEVICE double GetLimitGradAtInfPred(CensoringType censor_type, bool sign, double sigma)
DECLARE_FIELD_ENUM_CLASS(xgboost::common::ProbabilityDistributionType)
constexpr double kMaxHessian
Definition: survival_util.h:45
XGBOOST_DEVICE double GetLimitHessAtInfPred(CensoringType censor_type, bool sign, double sigma)
XGBOOST_DEVICE double GetLimitHessAtInfPred< NormalDistribution >(CensoringType censor_type, bool sign, double sigma)
Definition: survival_util.h:258
XGBOOST_DEVICE double Clip(double x, double x_min, double x_max)
Definition: survival_util.h:53
constexpr double kMinGradient
Definition: survival_util.h:42
#define XGBOOST_DEVICE
Tag function as usable by device.
Definition: base.h:84
CensoringType
Definition: survival_util.h:35
namespace of xgboost
Definition: base.h:102
Implementation of a few useful probability distributions.
XGBOOST_DEVICE double GetLimitGradAtInfPred< NormalDistribution >(CensoringType censor_type, bool sign, double sigma)
Definition: survival_util.h:242
XGBOOST_DEVICE double GetLimitGradAtInfPred< ExtremeDistribution >(CensoringType censor_type, bool sign, double sigma)
Definition: survival_util.h:303
static XGBOOST_DEVICE double Hessian(double y_lower, double y_upper, double y_pred, double sigma)
Definition: survival_util.h:180
macro for using C++11 enum class as DMLC parameter
ProbabilityDistributionType aft_loss_distribution
Choice of probability distribution for the noise term in AFT.
Definition: survival_util.h:76