16 #include <unordered_map> 35 std::shared_ptr<DMatrix>
data;
61 std::shared_ptr<std::unordered_map<DMatrix*, PredictionCacheEntry>>
cache_;
63 std::unordered_map<DMatrix*, PredictionCacheEntry>::iterator
FindCache(
DMatrix const* dmat) {
64 auto cache_emtry = std::find_if(
65 cache_->begin(), cache_->end(),
66 [dmat](std::pair<DMatrix *, PredictionCacheEntry const &>
const &kv) {
67 return kv.second.data.get() == dmat;
74 std::shared_ptr<std::unordered_map<DMatrix*, PredictionCacheEntry>> cache) :
75 generic_param_{generic_param}, cache_{cache} {}
83 virtual void Configure(
const std::vector<std::pair<std::string, std::string>>& cfg);
98 const gbm::GBTreeModel& model,
int tree_begin,
99 unsigned ntree_limit = 0) = 0;
116 virtual void UpdatePredictionCache(
117 const gbm::GBTreeModel& model,
118 std::vector<std::unique_ptr<TreeUpdater>>* updaters,
119 int num_new_trees) = 0;
137 std::vector<bst_float>* out_preds,
138 const gbm::GBTreeModel& model,
139 unsigned ntree_limit = 0) = 0;
155 virtual void PredictLeaf(
DMatrix* dmat, std::vector<bst_float>* out_preds,
156 const gbm::GBTreeModel& model,
157 unsigned ntree_limit = 0) = 0;
178 virtual void PredictContribution(
DMatrix* dmat,
179 std::vector<bst_float>* out_contribs,
180 const gbm::GBTreeModel& model,
181 unsigned ntree_limit = 0,
182 std::vector<bst_float>* tree_weights =
nullptr,
183 bool approximate =
false,
185 unsigned condition_feature = 0) = 0;
187 virtual void PredictInteractionContributions(
DMatrix* dmat,
188 std::vector<bst_float>* out_contribs,
189 const gbm::GBTreeModel& model,
190 unsigned ntree_limit = 0,
191 std::vector<bst_float>* tree_weights =
nullptr,
192 bool approximate =
false) = 0;
204 std::shared_ptr<std::unordered_map<DMatrix*, PredictionCacheEntry>> cache);
211 :
public dmlc::FunctionRegEntryBase<
212 PredictorReg, std::function<Predictor*(
213 GenericParameter const*,
214 std::shared_ptr<std::unordered_map<DMatrix*, PredictionCacheEntry>>)>> {};
216 #define XGBOOST_REGISTER_PREDICTOR(UniqueId, Name) \ 217 static DMLC_ATTRIBUTE_UNUSED ::xgboost::PredictorReg& \ 218 __make_##PredictorReg##_##UniqueId##__ = \ 219 ::dmlc::Registry<::xgboost::PredictorReg>::Get()->__REGISTER__(Name) Predictor(GenericParameter const *generic_param, std::shared_ptr< std::unordered_map< DMatrix *, PredictionCacheEntry >> cache)
Definition: predictor.h:73
Performs prediction on individual training instances or batches of instances for GBTree. The predictor also manages a prediction cache associated with input matrices. If possible, it will use previously calculated predictions instead of calculating new predictions. Prediction functions all take a GBTreeModel and a DMatrix as input and output a vector of predictions. The predictor does not modify any state of the model itself.
Definition: predictor.h:51
The input data structure of xgboost.
std::shared_ptr< std::unordered_map< DMatrix *, PredictionCacheEntry > > cache_
Map of matrices and associated cached predictions to facilitate storing and looking up predictions...
Definition: predictor.h:61
Definition: generic_parameters.h:14
Internal data structure used by XGBoost during training. There are two ways to create a customized D...
Definition: data.h:428
Registry entry for predictor.
Definition: predictor.h:210
A device-and-host vector abstraction layer.
GenericParameter const * generic_param_
Definition: predictor.h:56
span class implementation, based on ISO C++20 span<T>. The interface should be the same.
Definition: span.h:115
std::unordered_map< DMatrix *, PredictionCacheEntry >::iterator FindCache(DMatrix const *dmat)
Definition: predictor.h:63
HostDeviceVector< bst_float > predictions
Definition: predictor.h:36
namespace of xgboost
Definition: base.h:102
defines configuration macros of xgboost.
std::shared_ptr< DMatrix > data
Definition: predictor.h:35
Contains pointer to input matrix and associated cached predictions.
Definition: predictor.h:34