xgboost
gradient.h
Go to the documentation of this file.
1 
4 #pragma once
5 
6 #include <xgboost/base.h> // for GradientPair
7 #include <xgboost/linalg.h> // for Matrix
8 #include <xgboost/logging.h>
9 
10 #include <cstddef> // for size_t
11 
12 namespace xgboost {
21 
22  [[nodiscard]] bool HasValueGrad() const noexcept { return !value_gpair.Empty(); }
23 
24  [[nodiscard]] std::size_t NumSplitTargets() const noexcept { return gpair.Shape(1); }
25  [[nodiscard]] std::size_t NumTargets() const noexcept {
26  return HasValueGrad() ? value_gpair.Shape(1) : this->gpair.Shape(1);
27  }
28 
30  if (HasValueGrad()) {
31  return this->value_gpair.View(ctx->Device());
32  }
33  return this->gpair.View(ctx->Device());
34  }
35 
36  [[nodiscard]] linalg::Matrix<GradientPair> const* Grad() const { return &gpair; }
37  [[nodiscard]] linalg::Matrix<GradientPair>* Grad() { return &gpair; }
38 
39  [[nodiscard]] linalg::Matrix<GradientPair> const* FullGradOnly() const {
40  if (this->HasValueGrad()) {
41  LOG(FATAL) << "Reduced gradient is not yet supported.";
42  }
43  return this->Grad();
44  }
46  if (this->HasValueGrad()) {
47  LOG(FATAL) << "Reduced gradient is not yet supported.";
48  }
49  return this->Grad();
50  }
51 };
52 } // namespace xgboost
Defines configuration macros and basic types for xgboost.
A tensor view with static type and dimension. It implements indexing and slicing.
Definition: linalg.h:278
A tensor storage. To use it for other functionality like slicing one needs to obtain a view first....
Definition: linalg.h:760
auto View(DeviceOrd device)
Get a TensorView for this tensor.
Definition: linalg.h:855
auto Shape() const
Definition: linalg.h:882
Linear algebra related utilities.
Learner interface that integrates objective, gbm and evaluation together. This is the user facing XGB...
Definition: base.h:89
Runtime context for XGBoost. Contains information like threads and device.
Definition: context.h:142
DeviceOrd Device() const
Get the current device and ordinal.
Definition: context.h:207
Container for the gradient produced by the objective function.
Definition: gradient.h:16
linalg::Matrix< GradientPair > * Grad()
Definition: gradient.h:37
bool HasValueGrad() const noexcept
Definition: gradient.h:22
linalg::Matrix< GradientPair > const * Grad() const
Definition: gradient.h:36
linalg::Matrix< GradientPair > value_gpair
Optional gradient used for computing tree leaf values.
Definition: gradient.h:20
linalg::MatrixView< GradientPair const > ValueGrad(Context const *ctx) const
Definition: gradient.h:29
std::size_t NumSplitTargets() const noexcept
Definition: gradient.h:24
linalg::Matrix< GradientPair > gpair
Gradient used for multi-target tree splits and for linear models.
Definition: gradient.h:18
std::size_t NumTargets() const noexcept
Definition: gradient.h:25
linalg::Matrix< GradientPair > const * FullGradOnly() const
Definition: gradient.h:39
linalg::Matrix< GradientPair > * FullGradOnly()
Definition: gradient.h:45