8 #ifndef XGBOOST_COMMON_SURVIVAL_UTIL_H_
9 #define XGBOOST_COMMON_SURVIVAL_UTIL_H_
// Small positive floor used by the AFT loss/gradient/hessian code below:
// likelihood terms are clamped with fmax(..., aft::kEps) before taking the
// log, and gradient/hessian denominators below it trigger a limit-value
// fallback.
47 constexpr
double kEps = 1e-12;
// NOTE(review): only the signature of Clip is visible in this excerpt (its
// body is missing). Judging by its name and parameters it presumably clamps
// x to [x_min, x_max] -- confirm against the full file.
53 inline double Clip(
double x,
double x_min,
double x_max) {
// NOTE(review): the declarations introduced by these two template headers
// are not visible in this excerpt.
63 template<
typename Distribution>
67 template<
typename Distribution>
// Description text for the AFT noise-distribution choice. The parameter
// field this .describe() call belongs to is not visible in this excerpt.
85 .describe(
"Choice of distribution for the noise term in "
86 "Accelerated Failure Time model");
// Description text for the AFT distribution scaling factor. The parameter
// field is not visible in this excerpt.
89 .describe(
"Scaling factor used to scale the distribution in "
90 "Accelerated Failure Time model");
// NOTE(review): the declaration opened by this template header is not
// visible in this excerpt; Loss follows it.
95 template<
typename Distribution>
// Negative log-likelihood for one label interval [y_lower, y_upper], given
// the raw prediction y_pred and scale sigma. Bounds are moved to log space
// and standardized as z = (log(y) - y_pred) / sigma.
// NOTE(review): this excerpt is missing several lines of the original
// function (e.g. the declaration of `cost`, the bodies for infinite or
// non-positive bounds, the `else` header and the closing braces). The
// comments below describe only what is visible.
98 double Loss(
double y_lower,
double y_upper,
double y_pred,
double sigma) {
99 const double log_y_lower = log(y_lower);
100 const double log_y_upper = log(y_upper);
// Exact label (lower == upper): the loss comes from the density PDF(z),
// divided by sigma * y_lower.
104 if (y_lower == y_upper) {
105 const double z = (log_y_lower - y_pred) / sigma;
106 const double pdf = Distribution::PDF(z);
// Clamp at aft::kEps before the log so the loss stays finite.
108 cost = -log(fmax(pdf / (sigma * y_lower),
aft::kEps));
// Interval label: the loss comes from the probability mass
// CDF(z_u) - CDF(z_l).
110 double z_u, z_l, cdf_u, cdf_l;
// The handling of an infinite upper bound is not shown here; the visible
// assignment is presumably the finite-bound case.
111 if (isinf(y_upper)) {
114 z_u = (log_y_upper - y_pred) / sigma;
115 cdf_u = Distribution::CDF(z_u);
// The handling of a non-positive lower bound is not shown here; the
// visible assignment is presumably the positive-bound case.
117 if (y_lower <= 0.0) {
120 z_l = (log_y_lower - y_pred) / sigma;
121 cdf_l = Distribution::CDF(z_l);
// Clamp the interval probability at aft::kEps before the log.
124 cost = -log(fmax(cdf_u - cdf_l,
aft::kEps));
// First derivative of the AFT loss with respect to y_pred, computed as
// numerator / denominator:
//   exact label:    GradPDF(z) / (sigma * PDF(z))
//   interval label: (PDF(z_u) - PDF(z_l)) / (sigma * (CDF(z_u) - CDF(z_l)))
// NOTE(review): this excerpt omits several lines of the original function
// (e.g. the declarations of z_sign and censor_type, the branches for
// infinite or non-positive bounds, and the closing braces).
131 double Gradient(
double y_lower,
double y_upper,
double y_pred,
double sigma) {
132 const double log_y_lower = log(y_lower);
133 const double log_y_upper = log(y_upper);
134 double numerator, denominator, gradient;
138 if (y_lower == y_upper) {
139 const double z = (log_y_lower - y_pred) / sigma;
140 const double pdf = Distribution::PDF(z);
141 const double grad_pdf = Distribution::GradPDF(z);
143 numerator = grad_pdf;
144 denominator = sigma * pdf;
// z_u and z_l are zero-initialized; both feed z_sign below.
147 double z_u = 0.0, z_l = 0.0, pdf_u, pdf_l, cdf_u, cdf_l;
149 if (isinf(y_upper)) {
154 z_u = (log_y_upper - y_pred) / sigma;
155 pdf_u = Distribution::PDF(z_u);
156 cdf_u = Distribution::CDF(z_u);
158 if (y_lower <= 0.0) {
163 z_l = (log_y_lower - y_pred) / sigma;
164 pdf_l = Distribution::PDF(z_l);
165 cdf_l = Distribution::CDF(z_l);
// z_sign is true when either standardized bound is positive. It selects
// which limiting value the fallback below returns.
167 z_sign = (z_u > 0 || z_l > 0);
168 numerator = pdf_u - pdf_l;
169 denominator = sigma * (cdf_u - cdf_l);
171 gradient = numerator / denominator;
// If the denominator is below aft::kEps and the ratio came out NaN/inf,
// use the distribution's limiting gradient for this censoring type instead.
172 if (denominator <
aft::kEps && (isnan(gradient) || isinf(gradient))) {
173 gradient = aft::GetLimitGradAtInfPred<Distribution>(censor_type, z_sign, sigma);
// Second derivative of the AFT loss with respect to y_pred, computed as
// numerator / denominator:
//   exact label:    -(PDF*HessPDF - GradPDF^2) / (sigma^2 * PDF^2)
//   interval label: -(dCDF * dGradPDF - dPDF^2) / (sigma * dCDF)^2
// where dX = X(z_u) - X(z_l).
// NOTE(review): this excerpt omits several lines of the original function
// (e.g. the declarations of z_sign and censor_type, the branches for
// infinite or non-positive bounds, and the closing braces).
180 double Hessian(
double y_lower,
double y_upper,
double y_pred,
double sigma) {
181 const double log_y_lower = log(y_lower);
182 const double log_y_upper = log(y_upper);
183 double numerator, denominator, hessian;
187 if (y_lower == y_upper) {
188 const double z = (log_y_lower - y_pred) / sigma;
189 const double pdf = Distribution::PDF(z);
190 const double grad_pdf = Distribution::GradPDF(z);
191 const double hess_pdf = Distribution::HessPDF(z);
193 numerator = -(pdf * hess_pdf - grad_pdf * grad_pdf);
194 denominator = sigma * sigma * pdf * pdf;
// z_u and z_l are zero-initialized; both feed z_sign below.
197 double z_u = 0.0, z_l = 0.0, grad_pdf_u, grad_pdf_l, pdf_u, pdf_l, cdf_u, cdf_l;
199 if (isinf(y_upper)) {
205 z_u = (log_y_upper - y_pred) / sigma;
206 pdf_u = Distribution::PDF(z_u);
207 cdf_u = Distribution::CDF(z_u);
208 grad_pdf_u = Distribution::GradPDF(z_u);
210 if (y_lower <= 0.0) {
216 z_l = (log_y_lower - y_pred) / sigma;
217 pdf_l = Distribution::PDF(z_l);
218 cdf_l = Distribution::CDF(z_l);
219 grad_pdf_l = Distribution::GradPDF(z_l);
221 const double cdf_diff = cdf_u - cdf_l;
222 const double pdf_diff = pdf_u - pdf_l;
223 const double grad_diff = grad_pdf_u - grad_pdf_l;
// The denominator is the square of sigma * cdf_diff.
224 const double sqrt_denominator = sigma * cdf_diff;
// z_sign is true when either standardized bound is positive. It selects
// which limiting value the fallback below returns.
225 z_sign = (z_u > 0 || z_l > 0);
226 numerator = -(cdf_diff * grad_diff - pdf_diff * pdf_diff);
227 denominator = sqrt_denominator * sqrt_denominator;
229 hessian = numerator / denominator;
// If the denominator is below aft::kEps and the ratio came out NaN/inf,
// use the distribution's limiting hessian for this censoring type instead.
230 if (denominator <
aft::kEps && (isnan(hessian) || isinf(hessian))) {
231 hessian = aft::GetLimitHessAtInfPred<Distribution>(censor_type, z_sign, sigma);
// Fragments of switch statements over censor_type that return closed-form
// limiting values, used by the Gradient/Hessian fallbacks. Most case labels
// are not visible in this excerpt. Each visible switch ends in a quiet-NaN
// return, presumably the result for an unrecognized censor_type -- confirm
// against the full file.
246 switch (censor_type) {
256 return std::numeric_limits<double>::quiet_NaN();
// Presumably hessian limits (note kMinHessian): 1 / sigma^2, or kMinHessian
// depending on sign.
262 switch (censor_type) {
264 return 1.0 / (sigma * sigma);
266 return sign ? (1.0 / (sigma * sigma)) :
kMinHessian;
268 return sign ?
kMinHessian : (1.0 / (sigma * sigma));
270 return 1.0 / (sigma * sigma);
272 return std::numeric_limits<double>::quiet_NaN();
// Presumably gradient limits: -1/sigma, 1/sigma or 0.0 depending on sign.
278 switch (censor_type) {
280 return sign ? (-1.0 / sigma) : (1.0 / sigma);
282 return sign ? (-1.0 / sigma) : 0.0;
284 return sign ? 0.0 : (1.0 / sigma);
286 return sign ? (-1.0 / sigma) : (1.0 / sigma);
288 return std::numeric_limits<double>::quiet_NaN();
298 switch (censor_type) {
305 return std::numeric_limits<double>::quiet_NaN();
311 switch (censor_type) {
317 return sign ? 0.0 : (1.0 / sigma);
321 return std::numeric_limits<double>::quiet_NaN();
330 switch (censor_type) {
339 return std::numeric_limits<double>::quiet_NaN();
347 #endif // XGBOOST_COMMON_SURVIVAL_UTIL_H_