xgboost
survival_util.h
Go to the documentation of this file.
1 
8 #ifndef XGBOOST_COMMON_SURVIVAL_UTIL_H_
9 #define XGBOOST_COMMON_SURVIVAL_UTIL_H_
10 
11 #include <xgboost/parameter.h>
12 #include <memory>
14 
16 
17 namespace xgboost {
18 namespace common {
19 
// Parameter structure for the AFT (Accelerated Failure Time) loss and metric,
// derived from XGBoostParameter<AFTParam>.
// NOTE(review): this is a rendered listing with embedded line numbers, and
// listing lines 22-26 and 28 are absent from this extract. The symbol index at
// the end of the page places `aft_loss_distribution` at :23,
// `aft_loss_distribution_scale` at :25 and `DMLC_DECLARE_PARAMETER(AFTParam)`
// at :26 -- confirm the member declarations against the real header.
21 struct AFTParam : public XGBoostParameter<AFTParam> {
// Field registration for the noise distribution: the names "normal",
// "logistic" and "extreme" are associated with the corresponding
// ProbabilityDistributionType enumerators.
// NOTE(review): listing line 28 is missing here; whether a default value is
// set for this field cannot be told from this extract.
27  DMLC_DECLARE_FIELD(aft_loss_distribution)
29  .add_enum("normal", ProbabilityDistributionType::kNormal)
30  .add_enum("logistic", ProbabilityDistributionType::kLogistic)
31  .add_enum("extreme", ProbabilityDistributionType::kExtreme)
32  .describe("Choice of distribution for the noise term in "
33  "Accelerated Failure Time model");
// Field registration for the distribution scaling factor; default is 1.0f.
34  DMLC_DECLARE_FIELD(aft_loss_distribution_scale)
35  .set_default(1.0f)
36  .describe("Scaling factor used to scale the distribution in "
37  "Accelerated Failure Time model");
// Closes the block opened on the missing listing line 26 (presumably the
// DMLC_DECLARE_PARAMETER(AFTParam) body -- confirm).
38  }
39 };
40 
// The AFT loss function. Owns a ProbabilityDistribution instance selected by a
// ProbabilityDistributionType and declares loss, gradient and hessian
// evaluators; their definitions are not part of this listing.
42 class AFTLoss {
43  private:
// Distribution object produced by ProbabilityDistribution::Create(dist_type)
// in the constructor's initializer list.
44  std::unique_ptr<ProbabilityDistribution> dist_;
// The distribution type the object was constructed with.
45  ProbabilityDistributionType dist_type_;
46 
47  public:
// NOTE(review): listing lines 48-52 (constructor comment and signature) are
// missing from this extract. The symbol index at the end of the page lists
// `AFTLoss(ProbabilityDistributionType dist_type)` at :52 -- confirm the exact
// declaration (e.g. any `explicit` specifier) against the real header.
// The initializer list below creates the distribution and records its type.
53  : dist_(ProbabilityDistribution::Create(dist_type)),
54  dist_type_(dist_type) {}
55 
56  public:
// The three members below share one parameter list and each return a double.
// Presumably y_lower/y_upper are the lower and upper bounds of the label
// interval, y_pred is the model prediction and sigma is the distribution
// scale -- the original parameter docs are missing from this extract; confirm.
// Loss value for one data point.
64  double Loss(double y_lower, double y_upper, double y_pred, double sigma);
// Gradient of the loss (presumably with respect to y_pred -- confirm).
72  double Gradient(double y_lower, double y_upper, double y_pred, double sigma);
// Hessian of the loss (presumably with respect to y_pred -- confirm).
80  double Hessian(double y_lower, double y_upper, double y_pred, double sigma);
81 };
82 
83 } // namespace common
84 } // namespace xgboost
85 
86 #endif // XGBOOST_COMMON_SURVIVAL_UTIL_H_
AFTLoss(ProbabilityDistributionType dist_type)
Constructor for AFT loss function.
Definition: survival_util.h:52
Definition: parameter.h:84
float aft_loss_distribution_scale
Scaling factor to be applied to the distribution.
Definition: survival_util.h:25
Interface for a probability distribution.
Definition: probability_distribution.h:29
ProbabilityDistributionType
Enum encoding possible choices of probability distribution.
Definition: probability_distribution.h:24
DMLC_DECLARE_PARAMETER(AFTParam)
Definition: survival_util.h:26
xgboost::common::AFTLoss
The AFT loss function.
Definition: survival_util.h:42
xgboost::common::AFTParam
Parameter structure for AFT loss and metric.
Definition: survival_util.h:21
DECLARE_FIELD_ENUM_CLASS(xgboost::common::ProbabilityDistributionType)
namespace of xgboost
Definition: base.h:102
Implementation of a few useful probability distributions.
macro for using C++11 enum class as DMLC parameter
ProbabilityDistributionType aft_loss_distribution
Choice of probability distribution for the noise term in AFT.
Definition: survival_util.h:23