xgboost
quantile.h
Go to the documentation of this file.
1 
7 #ifndef XGBOOST_COMMON_QUANTILE_H_
8 #define XGBOOST_COMMON_QUANTILE_H_
9 
10 #include <dmlc/base.h>
11 #include <xgboost/logging.h>
12 #include <xgboost/data.h>
13 #include <cmath>
14 #include <vector>
15 #include <cstring>
16 #include <algorithm>
17 #include <iostream>
18 #include <set>
19 
20 #include "timer.h"
21 
22 namespace xgboost {
23 namespace common {
29 template<typename DType, typename RType>
30 struct WQSummary {
32  struct Entry {
34  RType rmin;
36  RType rmax;
38  RType wmin;
40  DType value;
41  // constructor
42  XGBOOST_DEVICE Entry() {} // NOLINT
43  // constructor
44  XGBOOST_DEVICE Entry(RType rmin, RType rmax, RType wmin, DType value)
45  : rmin(rmin), rmax(rmax), wmin(wmin), value(value) {}
50  inline void CheckValid(RType eps = 0) const {
51  CHECK(rmin >= 0 && rmax >= 0 && wmin >= 0) << "nonneg constraint";
52  CHECK(rmax- rmin - wmin > -eps) << "relation constraint: min/max";
53  }
55  XGBOOST_DEVICE inline RType RMinNext() const {
56  return rmin + wmin;
57  }
59  XGBOOST_DEVICE inline RType RMaxPrev() const {
60  return rmax - wmin;
61  }
62 
63  friend std::ostream& operator<<(std::ostream& os, Entry const& e) {
64  os << "rmin: " << e.rmin << ", "
65  << "rmax: " << e.rmax << ", "
66  << "wmin: " << e.wmin << ", "
67  << "value: " << e.value;
68  return os;
69  }
70  };
72  struct Queue {
73  // entry in the queue
74  struct QEntry {
75  // value of the instance
76  DType value;
77  // weight of instance
78  RType weight;
79  // default constructor
80  QEntry() = default;
81  // constructor
82  QEntry(DType value, RType weight)
83  : value(value), weight(weight) {}
84  // comparator on value
85  inline bool operator<(const QEntry &b) const {
86  return value < b.value;
87  }
88  };
89  // the input queue
90  std::vector<QEntry> queue;
91  // end of the queue
92  size_t qtail;
93  // push data to the queue
94  inline void Push(DType x, RType w) {
95  if (qtail == 0 || queue[qtail - 1].value != x) {
96  queue[qtail++] = QEntry(x, w);
97  } else {
98  queue[qtail - 1].weight += w;
99  }
100  }
101  inline void MakeSummary(WQSummary *out) {
102  std::sort(queue.begin(), queue.begin() + qtail);
103  out->size = 0;
104  // start update sketch
105  RType wsum = 0;
106  // construct data with unique weights
107  for (size_t i = 0; i < qtail;) {
108  size_t j = i + 1;
109  RType w = queue[i].weight;
110  while (j < qtail && queue[j].value == queue[i].value) {
111  w += queue[j].weight; ++j;
112  }
113  out->data[out->size++] = Entry(wsum, wsum + w, w, queue[i].value);
114  wsum += w; i = j;
115  }
116  }
117  };
121  size_t size;
122  // constructor
124  : data(data), size(size) {}
128  inline RType MaxError() const {
129  RType res = data[0].rmax - data[0].rmin - data[0].wmin;
130  for (size_t i = 1; i < size; ++i) {
131  res = std::max(data[i].RMaxPrev() - data[i - 1].RMinNext(), res);
132  res = std::max(data[i].rmax - data[i].rmin - data[i].wmin, res);
133  }
134  return res;
135  }
141  inline Entry Query(DType qvalue, size_t &istart) const { // NOLINT(*)
142  while (istart < size && qvalue > data[istart].value) {
143  ++istart;
144  }
145  if (istart == size) {
146  RType rmax = data[size - 1].rmax;
147  return Entry(rmax, rmax, 0.0f, qvalue);
148  }
149  if (qvalue == data[istart].value) {
150  return data[istart];
151  } else {
152  if (istart == 0) {
153  return Entry(0.0f, 0.0f, 0.0f, qvalue);
154  } else {
155  return Entry(data[istart - 1].RMinNext(),
156  data[istart].RMaxPrev(),
157  0.0f, qvalue);
158  }
159  }
160  }
162  inline RType MaxRank() const {
163  return data[size - 1].rmax;
164  }
169  inline void CopyFrom(const WQSummary &src) {
170  if (!src.data) {
171  CHECK_EQ(src.size, 0);
172  size = 0;
173  return;
174  }
175  if (!data) {
176  CHECK_EQ(this->size, 0);
177  CHECK_EQ(src.size, 0);
178  return;
179  }
180  size = src.size;
181  std::memcpy(data, src.data, sizeof(Entry) * size);
182  }
183  inline void MakeFromSorted(const Entry* entries, size_t n) {
184  size = 0;
185  for (size_t i = 0; i < n;) {
186  size_t j = i + 1;
187  // ignore repeated values
188  for (; j < n && entries[j].value == entries[i].value; ++j) {}
189  data[size++] = Entry(entries[i].rmin, entries[i].rmax, entries[i].wmin,
190  entries[i].value);
191  i = j;
192  }
193  }
200  inline void CheckValid(RType eps) const {
201  for (size_t i = 0; i < size; ++i) {
202  data[i].CheckValid(eps);
203  if (i != 0) {
204  CHECK(data[i].rmin >= data[i - 1].rmin + data[i - 1].wmin) << "rmin range constraint";
205  CHECK(data[i].rmax >= data[i - 1].rmax + data[i].wmin) << "rmax range constraint";
206  }
207  }
208  }
209 
216  void SetPrune(const WQSummary &src, size_t maxsize) {
217  if (src.size <= maxsize) {
218  this->CopyFrom(src); return;
219  }
220  const RType begin = src.data[0].rmax;
221  const RType range = src.data[src.size - 1].rmin - src.data[0].rmax;
222  const size_t n = maxsize - 1;
223  data[0] = src.data[0];
224  this->size = 1;
225  // lastidx is used to avoid duplicated records
226  size_t i = 1, lastidx = 0;
227  for (size_t k = 1; k < n; ++k) {
228  RType dx2 = 2 * ((k * range) / n + begin);
229  // find first i such that d < (rmax[i+1] + rmin[i+1]) / 2
230  while (i < src.size - 1
231  && dx2 >= src.data[i + 1].rmax + src.data[i + 1].rmin) ++i;
232  if (i == src.size - 1) break;
233  if (dx2 < src.data[i].RMinNext() + src.data[i + 1].RMaxPrev()) {
234  if (i != lastidx) {
235  data[size++] = src.data[i]; lastidx = i;
236  }
237  } else {
238  if (i + 1 != lastidx) {
239  data[size++] = src.data[i + 1]; lastidx = i + 1;
240  }
241  }
242  }
243  if (lastidx != src.size - 1) {
244  data[size++] = src.data[src.size - 1];
245  }
246  }
252  inline void SetCombine(const WQSummary &sa,
253  const WQSummary &sb) {
254  if (sa.size == 0) {
255  this->CopyFrom(sb); return;
256  }
257  if (sb.size == 0) {
258  this->CopyFrom(sa); return;
259  }
260  CHECK(sa.size > 0 && sb.size > 0);
261  const Entry *a = sa.data, *a_end = sa.data + sa.size;
262  const Entry *b = sb.data, *b_end = sb.data + sb.size;
263  // extended rmin value
264  RType aprev_rmin = 0, bprev_rmin = 0;
265  Entry *dst = this->data;
266  while (a != a_end && b != b_end) {
267  // duplicated value entry
268  if (a->value == b->value) {
269  *dst = Entry(a->rmin + b->rmin,
270  a->rmax + b->rmax,
271  a->wmin + b->wmin, a->value);
272  aprev_rmin = a->RMinNext();
273  bprev_rmin = b->RMinNext();
274  ++dst; ++a; ++b;
275  } else if (a->value < b->value) {
276  *dst = Entry(a->rmin + bprev_rmin,
277  a->rmax + b->RMaxPrev(),
278  a->wmin, a->value);
279  aprev_rmin = a->RMinNext();
280  ++dst; ++a;
281  } else {
282  *dst = Entry(b->rmin + aprev_rmin,
283  b->rmax + a->RMaxPrev(),
284  b->wmin, b->value);
285  bprev_rmin = b->RMinNext();
286  ++dst; ++b;
287  }
288  }
289  if (a != a_end) {
290  RType brmax = (b_end - 1)->rmax;
291  do {
292  *dst = Entry(a->rmin + bprev_rmin, a->rmax + brmax, a->wmin, a->value);
293  ++dst; ++a;
294  } while (a != a_end);
295  }
296  if (b != b_end) {
297  RType armax = (a_end - 1)->rmax;
298  do {
299  *dst = Entry(b->rmin + aprev_rmin, b->rmax + armax, b->wmin, b->value);
300  ++dst; ++b;
301  } while (b != b_end);
302  }
303  this->size = dst - data;
304  const RType tol = 10;
305  RType err_mingap, err_maxgap, err_wgap;
306  this->FixError(&err_mingap, &err_maxgap, &err_wgap);
307  if (err_mingap > tol || err_maxgap > tol || err_wgap > tol) {
308  LOG(INFO) << "mingap=" << err_mingap
309  << ", maxgap=" << err_maxgap
310  << ", wgap=" << err_wgap;
311  }
312  CHECK(size <= sa.size + sb.size) << "bug in combine";
313  }
314  // helper function to print the current content of sketch
315  inline void Print() const {
316  for (size_t i = 0; i < this->size; ++i) {
317  LOG(CONSOLE) << "[" << i << "] rmin=" << data[i].rmin
318  << ", rmax=" << data[i].rmax
319  << ", wmin=" << data[i].wmin
320  << ", v=" << data[i].value;
321  }
322  }
323  // try to fix rounding error
324  // and re-establish invariance
325  inline void FixError(RType *err_mingap,
326  RType *err_maxgap,
327  RType *err_wgap) const {
328  *err_mingap = 0;
329  *err_maxgap = 0;
330  *err_wgap = 0;
331  RType prev_rmin = 0, prev_rmax = 0;
332  for (size_t i = 0; i < this->size; ++i) {
333  if (data[i].rmin < prev_rmin) {
334  data[i].rmin = prev_rmin;
335  *err_mingap = std::max(*err_mingap, prev_rmin - data[i].rmin);
336  } else {
337  prev_rmin = data[i].rmin;
338  }
339  if (data[i].rmax < prev_rmax) {
340  data[i].rmax = prev_rmax;
341  *err_maxgap = std::max(*err_maxgap, prev_rmax - data[i].rmax);
342  }
343  RType rmin_next = data[i].RMinNext();
344  if (data[i].rmax < rmin_next) {
345  data[i].rmax = rmin_next;
346  *err_wgap = std::max(*err_wgap, data[i].rmax - rmin_next);
347  }
348  prev_rmax = data[i].rmax;
349  }
350  }
351  // check consistency of the summary
352  inline bool Check(const char *msg) const {
353  const float tol = 10.0f;
354  for (size_t i = 0; i < this->size; ++i) {
355  if (data[i].rmin + data[i].wmin > data[i].rmax + tol ||
356  data[i].rmin < -1e-6f || data[i].rmax < -1e-6f) {
357  LOG(INFO) << "---------- WQSummary::Check did not pass ----------";
358  this->Print();
359  return false;
360  }
361  }
362  return true;
363  }
364 };
365 
367 template<typename DType, typename RType>
368 struct WXQSummary : public WQSummary<DType, RType> {
369  // redefine entry type
371  // constructor
373  : WQSummary<DType, RType>(data, size) {}
374  // check if the block is large chunk
375  inline static bool CheckLarge(const Entry &e, RType chunk) {
376  return e.RMinNext() > e.RMaxPrev() + chunk;
377  }
378  // set prune
379  inline void SetPrune(const WQSummary<DType, RType> &src, size_t maxsize) {
380  if (src.size <= maxsize) {
381  this->CopyFrom(src); return;
382  }
383  RType begin = src.data[0].rmax;
384  // n is number of points exclude the min/max points
385  size_t n = maxsize - 2, nbig = 0;
386  // these is the range of data exclude the min/max point
387  RType range = src.data[src.size - 1].rmin - begin;
388  // prune off zero weights
389  if (range == 0.0f || maxsize <= 2) {
390  // special case, contain only two effective data pts
391  this->data[0] = src.data[0];
392  this->data[1] = src.data[src.size - 1];
393  this->size = 2;
394  return;
395  } else {
396  range = std::max(range, static_cast<RType>(1e-3f));
397  }
398  // Get a big enough chunk size, bigger than range / n
399  // (multiply by 2 is a safe factor)
400  const RType chunk = 2 * range / n;
401  // minimized range
402  RType mrange = 0;
403  {
404  // first scan, grab all the big chunk
405  // moving block index, exclude the two ends.
406  size_t bid = 0;
407  for (size_t i = 1; i < src.size - 1; ++i) {
408  // detect big chunk data point in the middle
409  // always save these data points.
410  if (CheckLarge(src.data[i], chunk)) {
411  if (bid != i - 1) {
412  // accumulate the range of the rest points
413  mrange += src.data[i].RMaxPrev() - src.data[bid].RMinNext();
414  }
415  bid = i; ++nbig;
416  }
417  }
418  if (bid != src.size - 2) {
419  mrange += src.data[src.size-1].RMaxPrev() - src.data[bid].RMinNext();
420  }
421  }
422  // assert: there cannot be more than n big data points
423  if (nbig >= n) {
424  // see what was the case
425  LOG(INFO) << " check quantile stats, nbig=" << nbig << ", n=" << n;
426  LOG(INFO) << " srcsize=" << src.size << ", maxsize=" << maxsize
427  << ", range=" << range << ", chunk=" << chunk;
428  src.Print();
429  CHECK(nbig < n) << "quantile: too many large chunk";
430  }
431  this->data[0] = src.data[0];
432  this->size = 1;
433  // The counter on the rest of points, to be selected equally from small chunks.
434  n = n - nbig;
435  // find the rest of point
436  size_t bid = 0, k = 1, lastidx = 0;
437  for (size_t end = 1; end < src.size; ++end) {
438  if (end == src.size - 1 || CheckLarge(src.data[end], chunk)) {
439  if (bid != end - 1) {
440  size_t i = bid;
441  RType maxdx2 = src.data[end].RMaxPrev() * 2;
442  for (; k < n; ++k) {
443  RType dx2 = 2 * ((k * mrange) / n + begin);
444  if (dx2 >= maxdx2) break;
445  while (i < end &&
446  dx2 >= src.data[i + 1].rmax + src.data[i + 1].rmin) ++i;
447  if (i == end) break;
448  if (dx2 < src.data[i].RMinNext() + src.data[i + 1].RMaxPrev()) {
449  if (i != lastidx) {
450  this->data[this->size++] = src.data[i]; lastidx = i;
451  }
452  } else {
453  if (i + 1 != lastidx) {
454  this->data[this->size++] = src.data[i + 1]; lastidx = i + 1;
455  }
456  }
457  }
458  }
459  if (lastidx != end) {
460  this->data[this->size++] = src.data[end];
461  lastidx = end;
462  }
463  bid = end;
464  // shift base by the gap
465  begin += src.data[bid].RMinNext() - src.data[bid].RMaxPrev();
466  }
467  }
468  }
469 };
477 template<typename DType, typename RType, class TSummary>
479  public:
480  static float constexpr kFactor = 8.0;
481 
482  public:
484  using Summary = TSummary;
486  using Entry = typename Summary::Entry;
488  struct SummaryContainer : public Summary {
489  std::vector<Entry> space;
490  SummaryContainer(const SummaryContainer &src) : Summary(nullptr, src.size) {
491  this->space = src.space;
492  this->data = dmlc::BeginPtr(this->space);
493  }
494  SummaryContainer() : Summary(nullptr, 0) {
495  }
497  inline void Reserve(size_t size) {
498  if (size > space.size()) {
499  space.resize(size);
500  this->data = dmlc::BeginPtr(space);
501  }
502  }
509  inline void Reduce(const Summary &src, size_t max_nbyte) {
510  this->Reserve((max_nbyte - sizeof(this->size)) / sizeof(Entry));
512  temp.Reserve(this->size + src.size);
513  temp.SetCombine(*this, src);
514  this->SetPrune(temp, space.size());
515  }
517  inline static size_t CalcMemCost(size_t nentry) {
518  return sizeof(size_t) + sizeof(Entry) * nentry;
519  }
521  template<typename TStream>
522  inline void Save(TStream &fo) const { // NOLINT(*)
523  fo.Write(&(this->size), sizeof(this->size));
524  if (this->size != 0) {
525  fo.Write(this->data, this->size * sizeof(Entry));
526  }
527  }
529  template<typename TStream>
530  inline void Load(TStream &fi) { // NOLINT(*)
531  CHECK_EQ(fi.Read(&this->size, sizeof(this->size)), sizeof(this->size));
532  this->Reserve(this->size);
533  if (this->size != 0) {
534  CHECK_EQ(fi.Read(this->data, this->size * sizeof(Entry)),
535  this->size * sizeof(Entry));
536  }
537  }
538  };
544  inline void Init(size_t maxn, double eps) {
545  LimitSizeLevel(maxn, eps, &nlevel, &limit_size);
546  // lazy reserve the space, if there is only one value, no need to allocate space
547  inqueue.queue.resize(1);
548  inqueue.qtail = 0;
549  data.clear();
550  level.clear();
551  }
552 
553  inline static void LimitSizeLevel
554  (size_t maxn, double eps, size_t* out_nlevel, size_t* out_limit_size) {
555  size_t& nlevel = *out_nlevel;
556  size_t& limit_size = *out_limit_size;
557  nlevel = 1;
558  while (true) {
559  limit_size = static_cast<size_t>(ceil(nlevel / eps)) + 1;
560  limit_size = std::min(maxn, limit_size);
561  size_t n = (1ULL << nlevel);
562  if (n * limit_size >= maxn) break;
563  ++nlevel;
564  }
565  // check invariant
566  size_t n = (1ULL << nlevel);
567  CHECK(n * limit_size >= maxn) << "invalid init parameter";
568  CHECK(nlevel <= std::max(static_cast<size_t>(1), static_cast<size_t>(limit_size * eps)))
569  << "invalid init parameter";
570  }
571 
577  inline void Push(DType x, RType w = 1) {
578  if (w == static_cast<RType>(0)) return;
579  if (inqueue.qtail == inqueue.queue.size() && inqueue.queue[inqueue.qtail - 1].value != x) {
580  // jump from lazy one value to limit_size * 2
581  if (inqueue.queue.size() == 1) {
582  inqueue.queue.resize(limit_size * 2);
583  } else {
584  temp.Reserve(limit_size * 2);
585  inqueue.MakeSummary(&temp);
586  // cleanup queue
587  inqueue.qtail = 0;
588  this->PushTemp();
589  }
590  }
591  inqueue.Push(x, w);
592  }
593 
594  inline void PushSummary(const Summary& summary) {
595  temp.Reserve(limit_size * 2);
596  temp.SetPrune(summary, limit_size * 2);
597  PushTemp();
598  }
599 
601  inline void PushTemp() {
602  temp.Reserve(limit_size * 2);
603  for (size_t l = 1; true; ++l) {
604  this->InitLevel(l + 1);
605  // check if level l is empty
606  if (level[l].size == 0) {
607  level[l].SetPrune(temp, limit_size);
608  break;
609  } else {
610  // level 0 is actually temp space
611  level[0].SetPrune(temp, limit_size);
612  temp.SetCombine(level[0], level[l]);
613  if (temp.size > limit_size) {
614  // try next level
615  level[l].size = 0;
616  } else {
617  // if merged record is still smaller, no need to send to next level
618  level[l].CopyFrom(temp); break;
619  }
620  }
621  }
622  }
624  inline void GetSummary(SummaryContainer *out) {
625  if (level.size() != 0) {
626  out->Reserve(limit_size * 2);
627  } else {
628  out->Reserve(inqueue.queue.size());
629  }
630  inqueue.MakeSummary(out);
631  if (level.size() != 0) {
632  level[0].SetPrune(*out, limit_size);
633  for (size_t l = 1; l < level.size(); ++l) {
634  if (level[l].size == 0) continue;
635  if (level[0].size == 0) {
636  level[0].CopyFrom(level[l]);
637  } else {
638  out->SetCombine(level[0], level[l]);
639  level[0].SetPrune(*out, limit_size);
640  }
641  }
642  out->CopyFrom(level[0]);
643  } else {
644  if (out->size > limit_size) {
646  temp.SetPrune(*out, limit_size);
647  out->CopyFrom(temp);
648  }
649  }
650  }
651  // used for debug, check if the sketch is valid
652  inline void CheckValid(RType eps) const {
653  for (size_t l = 1; l < level.size(); ++l) {
654  level[l].CheckValid(eps);
655  }
656  }
657  // initialize level space to at least nlevel
658  inline void InitLevel(size_t nlevel) {
659  if (level.size() >= nlevel) return;
660  data.resize(limit_size * nlevel);
661  level.resize(nlevel, Summary(nullptr, 0));
662  for (size_t l = 0; l < level.size(); ++l) {
663  level[l].data = dmlc::BeginPtr(data) + l * limit_size;
664  }
665  }
666  // input data queue
667  typename Summary::Queue inqueue;
668  // number of levels
669  size_t nlevel;
670  // size of summary in each level
671  size_t limit_size;
672  // the level of each summaries
673  std::vector<Summary> level;
674  // content of the summary
675  std::vector<Entry> data;
676  // temporal summary, used for temp-merge
677  SummaryContainer temp;
678 };
679 
685 template<typename DType, typename RType = unsigned>
687  public QuantileSketchTemplate<DType, RType, WQSummary<DType, RType> > {
688 };
689 
695 template<typename DType, typename RType = unsigned>
697  public QuantileSketchTemplate<DType, RType, WXQSummary<DType, RType> > {
698 };
699 
700 class HistogramCuts;
701 
705 template <typename WQSketch>
707  protected:
708  std::vector<WQSketch> sketches_;
709  std::vector<std::set<float>> categories_;
710  std::vector<FeatureType> const feature_types_;
711 
712  std::vector<bst_row_t> columns_size_;
713  int32_t max_bins_;
714  bool use_group_ind_{false};
715  int32_t n_threads_;
716  bool has_categorical_{false};
718 
719  public:
720  /* \brief Initialize necessary info.
721  *
722  * \param columns_size Size of each column.
723  * \param max_bins maximum number of bins for each feature.
724  * \param use_group whether is assigned to group to data instance.
725  */
726  SketchContainerImpl(std::vector<bst_row_t> columns_size, int32_t max_bins,
727  common::Span<FeatureType const> feature_types, bool use_group,
728  int32_t n_threads);
729 
730  static bool UseGroup(MetaInfo const &info) {
731  size_t const num_groups =
732  info.group_ptr_.size() == 0 ? 0 : info.group_ptr_.size() - 1;
733  // Use group index for weights?
734  bool const use_group_ind =
735  num_groups != 0 && (info.weights_.Size() != info.num_row_);
736  return use_group_ind;
737  }
738 
739  static std::vector<bst_row_t> CalcColumnSize(SparsePage const &page,
740  bst_feature_t const n_columns,
741  size_t const nthreads);
742 
743  static std::vector<bst_feature_t> LoadBalance(SparsePage const &page,
744  bst_feature_t n_columns,
745  size_t const nthreads);
746 
747  static uint32_t SearchGroupIndFromRow(std::vector<bst_uint> const &group_ptr,
748  size_t const base_rowid) {
749  CHECK_LT(base_rowid, group_ptr.back())
750  << "Row: " << base_rowid << " is not found in any group.";
751  bst_group_t group_ind =
752  std::upper_bound(group_ptr.cbegin(), group_ptr.cend() - 1, base_rowid) -
753  group_ptr.cbegin() - 1;
754  return group_ind;
755  }
756  // Gather sketches from all workers.
757  void GatherSketchInfo(std::vector<typename WQSketch::SummaryContainer> const &reduced,
758  std::vector<bst_row_t> *p_worker_segments,
759  std::vector<bst_row_t> *p_sketches_scan,
760  std::vector<typename WQSketch::Entry> *p_global_sketches);
761  // Merge sketches from all workers.
762  void AllReduce(std::vector<typename WQSketch::SummaryContainer> *p_reduced,
763  std::vector<int32_t> *p_num_cuts);
764 
765  /* \brief Push a CSR matrix. */
766  void PushRowPage(SparsePage const &page, MetaInfo const &info, Span<float const> hessian = {});
767 
768  void MakeCuts(HistogramCuts* cuts);
769 };
770 
771 class HostSketchContainer : public SketchContainerImpl<WQuantileSketch<float, float>> {
772  public:
774 
775  public:
776  HostSketchContainer(int32_t max_bins, MetaInfo const &info, std::vector<size_t> columns_size,
777  bool use_group, Span<float const> hessian, int32_t n_threads);
778 };
779 
785  double sum_total{0.0};
787  double rmin, wmin;
791  double next_goal;
792  // pointer to the sketch to put things in
794  // initialize the space
795  inline void Init(unsigned max_size) {
796  next_goal = -1.0f;
797  rmin = wmin = 0.0f;
798  sketch->temp.Reserve(max_size + 1);
799  sketch->temp.size = 0;
800  }
807  inline void Push(bst_float fvalue, bst_float w, unsigned max_size) {
808  if (next_goal == -1.0f) {
809  next_goal = 0.0f;
810  last_fvalue = fvalue;
811  wmin = w;
812  return;
813  }
814  if (last_fvalue != fvalue) {
815  double rmax = rmin + wmin;
816  if (rmax >= next_goal && sketch->temp.size != max_size) {
817  if (sketch->temp.size == 0 ||
818  last_fvalue > sketch->temp.data[sketch->temp.size - 1].value) {
819  // push to sketch
820  sketch->temp.data[sketch->temp.size] =
822  static_cast<bst_float>(rmin), static_cast<bst_float>(rmax),
823  static_cast<bst_float>(wmin), last_fvalue);
824  CHECK_LT(sketch->temp.size, max_size) << "invalid maximum size max_size=" << max_size
825  << ", stemp.size" << sketch->temp.size;
826  ++sketch->temp.size;
827  }
828  if (sketch->temp.size == max_size) {
829  next_goal = sum_total * 2.0f + 1e-5f;
830  } else {
831  next_goal = static_cast<bst_float>(sketch->temp.size * sum_total / max_size);
832  }
833  } else {
834  if (rmax >= next_goal) {
835  LOG(DEBUG) << "INFO: rmax=" << rmax << ", sum_total=" << sum_total
836  << ", naxt_goal=" << next_goal << ", size=" << sketch->temp.size;
837  }
838  }
839  rmin = rmax;
840  wmin = w;
841  last_fvalue = fvalue;
842  } else {
843  wmin += w;
844  }
845  }
846 
848  inline void Finalize(unsigned max_size) {
849  double rmax = rmin + wmin;
850  if (sketch->temp.size == 0 || last_fvalue > sketch->temp.data[sketch->temp.size - 1].value) {
851  CHECK_LE(sketch->temp.size, max_size)
852  << "Finalize: invalid maximum size, max_size=" << max_size
853  << ", stemp.size=" << sketch->temp.size;
854  // push to sketch
856  static_cast<bst_float>(rmin), static_cast<bst_float>(rmax), static_cast<bst_float>(wmin),
857  last_fvalue);
858  ++sketch->temp.size;
859  }
860  sketch->PushTemp();
861  }
862 };
863 
864 class SortedSketchContainer : public SketchContainerImpl<WXQuantileSketch<float, float>> {
865  std::vector<SortedQuantile> sketches_;
867 
868  public:
869  explicit SortedSketchContainer(int32_t max_bins, MetaInfo const &info,
870  std::vector<size_t> columns_size, bool use_group,
871  Span<float const> hessian, int32_t n_threads)
872  : SketchContainerImpl{columns_size, max_bins, info.feature_types.ConstHostSpan(), use_group,
873  n_threads} {
874  monitor_.Init(__func__);
875  sketches_.resize(info.num_col_);
876  size_t i = 0;
877  for (auto &sketch : sketches_) {
878  sketch.sketch = &Super::sketches_[i];
879  sketch.Init(max_bins_);
880  auto eps = 2.0 / max_bins;
881  sketch.sketch->Init(columns_size_[i], eps);
882  ++i;
883  }
884  }
888  void PushColPage(SparsePage const &page, MetaInfo const &info, Span<float const> hessian);
889 };
890 } // namespace common
891 } // namespace xgboost
892 #endif // XGBOOST_COMMON_QUANTILE_H_
xgboost::common::WQSummary::CopyFrom
void CopyFrom(const WQSummary &src)
copy content from src
Definition: quantile.h:169
xgboost::common::WQSummary::Queue::MakeSummary
void MakeSummary(WQSummary *out)
Definition: quantile.h:101
xgboost::MetaInfo::num_row_
uint64_t num_row_
number of rows in the data
Definition: data.h:52
xgboost::common::WQSummary::Check
bool Check(const char *msg) const
Definition: quantile.h:352
xgboost::common::QuantileSketchTemplate::LimitSizeLevel
static void LimitSizeLevel(size_t maxn, double eps, size_t *out_nlevel, size_t *out_limit_size)
Definition: quantile.h:554
xgboost::common::WQSummary::FixError
void FixError(RType *err_mingap, RType *err_maxgap, RType *err_wgap) const
Definition: quantile.h:325
xgboost::common::QuantileSketchTemplate::SummaryContainer::Reserve
void Reserve(size_t size)
reserve space for summary
Definition: quantile.h:497
xgboost::common::QuantileSketchTemplate::CheckValid
void CheckValid(RType eps) const
Definition: quantile.h:652
xgboost::common::HostSketchContainer
Definition: quantile.h:771
xgboost::common::SketchContainerImpl::GatherSketchInfo
void GatherSketchInfo(std::vector< typename WQSketch::SummaryContainer > const &reduced, std::vector< bst_row_t > *p_worker_segments, std::vector< bst_row_t > *p_sketches_scan, std::vector< typename WQSketch::Entry > *p_global_sketches)
xgboost::common::WQSummary::MakeFromSorted
void MakeFromSorted(const Entry *entries, size_t n)
Definition: quantile.h:183
xgboost::SparsePage
In-memory storage unit of sparse batch, stored in CSR format.
Definition: data.h:271
xgboost::common::WXQSummary::WXQSummary
WXQSummary(Entry *data, size_t size)
Definition: quantile.h:372
xgboost::common::SketchContainerImpl::n_threads_
int32_t n_threads_
Definition: quantile.h:715
xgboost::common::QuantileSketchTemplate::SummaryContainer::Save
void Save(TStream &fo) const
save the data structure into stream
Definition: quantile.h:522
xgboost::common::SortedQuantile::wmin
double wmin
Definition: quantile.h:787
xgboost::common::WQSummary::Entry::RMaxPrev
XGBOOST_DEVICE RType RMaxPrev() const
Definition: quantile.h:59
xgboost::common::HistogramCuts
Definition: hist_util.h:38
xgboost::common::SketchContainerImpl::feature_types_
const std::vector< FeatureType > feature_types_
Definition: quantile.h:710
xgboost::common::QuantileSketchTemplate::limit_size
size_t limit_size
Definition: quantile.h:671
xgboost::common::WQSummary::Queue::qtail
size_t qtail
Definition: quantile.h:92
xgboost::common::WQSummary::Entry::rmax
RType rmax
maximum rank
Definition: quantile.h:36
xgboost::common::SortedSketchContainer::PushColPage
void PushColPage(SparsePage const &page, MetaInfo const &info, Span< float const > hessian)
Push a sorted CSC page.
xgboost::common::QuantileSketchTemplate::level
std::vector< Summary > level
Definition: quantile.h:673
xgboost::common::WXQSummary
try to do efficient pruning
Definition: quantile.h:368
xgboost::common::QuantileSketchTemplate::SummaryContainer::CalcMemCost
static size_t CalcMemCost(size_t nentry)
return the number of bytes this data structure cost in serialization
Definition: quantile.h:517
xgboost::common::QuantileSketchTemplate::GetSummary
void GetSummary(SummaryContainer *out)
get the summary after finalize
Definition: quantile.h:624
xgboost::common::SketchContainerImpl::categories_
std::vector< std::set< float > > categories_
Definition: quantile.h:709
xgboost::common::QuantileSketchTemplate::SummaryContainer::space
std::vector< Entry > space
Definition: quantile.h:489
xgboost::common::WQSummary::Entry::CheckValid
void CheckValid(RType eps=0) const
debug function, check Valid
Definition: quantile.h:50
xgboost::common::WQuantileSketch
Quantile sketch use WQSummary.
Definition: quantile.h:686
xgboost::HostDeviceVector::ConstHostSpan
common::Span< T const > ConstHostSpan() const
Definition: host_device_vector.h:114
xgboost::common::WQSummary::Queue::Push
void Push(DType x, RType w)
Definition: quantile.h:94
xgboost::common::SortedQuantile::rmin
double rmin
statistics used in the sketch
Definition: quantile.h:787
xgboost::common::WQSummary::Print
void Print() const
Definition: quantile.h:315
xgboost::common::QuantileSketchTemplate::SummaryContainer
same as summary, but use STL to backup the space
Definition: quantile.h:488
xgboost::common::SketchContainerImpl::CalcColumnSize
static std::vector< bst_row_t > CalcColumnSize(SparsePage const &page, bst_feature_t const n_columns, size_t const nthreads)
xgboost::common::WQSummary::Entry::Entry
XGBOOST_DEVICE Entry(RType rmin, RType rmax, RType wmin, DType value)
Definition: quantile.h:44
xgboost::MetaInfo::group_ptr_
std::vector< bst_group_t > group_ptr_
the index of begin and end of a group needed when the learning task is ranking.
Definition: data.h:63
xgboost::common::SketchContainerImpl::AllReduce
void AllReduce(std::vector< typename WQSketch::SummaryContainer > *p_reduced, std::vector< int32_t > *p_num_cuts)
xgboost::common::SketchContainerImpl::columns_size_
std::vector< bst_row_t > columns_size_
Definition: quantile.h:712
xgboost::common::SortedSketchContainer
Definition: quantile.h:864
xgboost::common::HostSketchContainer::HostSketchContainer
HostSketchContainer(int32_t max_bins, MetaInfo const &info, std::vector< size_t > columns_size, bool use_group, Span< float const > hessian, int32_t n_threads)
xgboost::common::SketchContainerImpl::use_group_ind_
bool use_group_ind_
Definition: quantile.h:714
xgboost::common::QuantileSketchTemplate::Push
void Push(DType x, RType w=1)
add an element to a sketch
Definition: quantile.h:577
xgboost::common::SortedQuantile::sum_total
double sum_total
total sum of amount to be met
Definition: quantile.h:785
xgboost::common::SketchContainerImpl
Definition: quantile.h:706
xgboost::common::SortedQuantile::Finalize
void Finalize(unsigned max_size)
push final unfinished value to the sketch
Definition: quantile.h:848
xgboost::common::WQSummary::Queue::QEntry::operator<
bool operator<(const QEntry &b) const
Definition: quantile.h:85
xgboost::common::WQSummary::Entry::Entry
XGBOOST_DEVICE Entry()
Definition: quantile.h:42
xgboost::common::SketchContainerImpl::monitor_
Monitor monitor_
Definition: quantile.h:717
xgboost::bst_feature_t
uint32_t bst_feature_t
Type for data column (feature) index.
Definition: base.h:123
xgboost::common::QuantileSketchTemplate::InitLevel
void InitLevel(size_t nlevel)
Definition: quantile.h:658
xgboost::common::WQSummary::Entry::operator<<
friend std::ostream & operator<<(std::ostream &os, Entry const &e)
Definition: quantile.h:63
xgboost::bst_group_t
uint32_t bst_group_t
Type for ranking group index.
Definition: base.h:134
xgboost::common::WQSummary::Entry::rmin
RType rmin
minimum rank
Definition: quantile.h:34
xgboost::common::SketchContainerImpl::SearchGroupIndFromRow
static uint32_t SearchGroupIndFromRow(std::vector< bst_uint > const &group_ptr, size_t const base_rowid)
Definition: quantile.h:747
xgboost::MetaInfo::feature_types
HostDeviceVector< FeatureType > feature_types
Definition: data.h:92
xgboost::common::WXQSummary::CheckLarge
static bool CheckLarge(const Entry &e, RType chunk)
Definition: quantile.h:375
timer.h
xgboost::common::WQSummary::Queue::QEntry::weight
RType weight
Definition: quantile.h:78
xgboost::common::SortedSketchContainer::SortedSketchContainer
SortedSketchContainer(int32_t max_bins, MetaInfo const &info, std::vector< size_t > columns_size, bool use_group, Span< float const > hessian, int32_t n_threads)
Definition: quantile.h:869
xgboost::common::WQSummary::MaxRank
RType MaxRank() const
Definition: quantile.h:162
xgboost::common::WQSummary::CheckValid
void CheckValid(RType eps) const
debug function: runs a consistency check to validate whether this is a valid summary
Definition: quantile.h:200
xgboost::common::Monitor
Timing utility used to measure total method execution time over the lifetime of the containing object...
Definition: timer.h:47
xgboost::common::WQSummary::Entry
an entry in the sketch summary
Definition: quantile.h:32
xgboost::common::QuantileSketchTemplate::kFactor
static constexpr float kFactor
Definition: quantile.h:480
xgboost::common::SketchContainerImpl::SketchContainerImpl
SketchContainerImpl(std::vector< bst_row_t > columns_size, int32_t max_bins, common::Span< FeatureType const > feature_types, bool use_group, int32_t n_threads)
xgboost::common::SketchContainerImpl::UseGroup
static bool UseGroup(MetaInfo const &info)
Definition: quantile.h:730
xgboost::common::WQSummary::Queue::QEntry::QEntry
QEntry(DType value, RType weight)
Definition: quantile.h:82
xgboost::common::SortedQuantile::sketch
common::WXQuantileSketch< bst_float, bst_float > * sketch
Definition: quantile.h:793
xgboost::common::QuantileSketchTemplate::Summary
TSummary Summary
type of summary type
Definition: quantile.h:484
xgboost::common::QuantileSketchTemplate< DType, unsigned, WQSummary< DType, unsigned > >::Entry
typename Summary::Entry Entry
the entry type
Definition: quantile.h:486
xgboost::HostDeviceVector::Size
size_t Size() const
xgboost::common::WQSummary::SetCombine
void SetCombine(const WQSummary &sa, const WQSummary &sb)
set current summary to be merged summary of sa and sb
Definition: quantile.h:252
xgboost::common::QuantileSketchTemplate::SummaryContainer::SummaryContainer
SummaryContainer(const SummaryContainer &src)
Definition: quantile.h:490
xgboost::common::WQSummary::WQSummary
WQSummary(Entry *data, size_t size)
Definition: quantile.h:123
xgboost::common::WQSummary::Entry::value
DType value
the value of data
Definition: quantile.h:40
xgboost::common::SortedQuantile::next_goal
double next_goal
current size of sketch
Definition: quantile.h:791
xgboost::common::SortedQuantile
Quantile structure that accepts sorted data; extracted from histmaker.
Definition: quantile.h:783
xgboost::MetaInfo::weights_
HostDeviceVector< bst_float > weights_
weights of each instance, optional
Definition: data.h:65
xgboost::common::SketchContainerImpl::MakeCuts
void MakeCuts(HistogramCuts *cuts)
xgboost::common::QuantileSketchTemplate
template for all quantile sketch algorithms that use the merge/prune scheme
Definition: quantile.h:478
xgboost::common::QuantileSketchTemplate::SummaryContainer::Reduce
void Reduce(const Summary &src, size_t max_nbyte)
do elementwise combination of summary arrays: this[i] = combine(this[i], src[i]) for each i
Definition: quantile.h:509
xgboost::common::WQSummary::Queue
input data queue before entering the summary
Definition: quantile.h:72
xgboost::common::Span
span class implementation, based on ISO C++20 span<T>. The interface should be the same.
Definition: span.h:148
xgboost::common::WXQSummary::SetPrune
void SetPrune(const WQSummary< DType, RType > &src, size_t maxsize)
Definition: quantile.h:379
xgboost::common::QuantileSketchTemplate::PushSummary
void PushSummary(const Summary &summary)
Definition: quantile.h:594
xgboost::common::WQSummary::Entry::wmin
RType wmin
maximum weight
Definition: quantile.h:38
data.h
The input data structure of xgboost.
xgboost::common::WQSummary::size
size_t size
number of elements in the summary
Definition: quantile.h:121
xgboost::common::QuantileSketchTemplate::temp
SummaryContainer temp
Definition: quantile.h:677
xgboost::common::WQSummary
experimental wsummary
Definition: quantile.h:30
xgboost::common::SortedQuantile::last_fvalue
bst_float last_fvalue
last seen feature value
Definition: quantile.h:789
xgboost::common::SketchContainerImpl::PushRowPage
void PushRowPage(SparsePage const &page, MetaInfo const &info, Span< float const > hessian={})
xgboost::common::WQSummary::MaxError
RType MaxError() const
Definition: quantile.h:128
xgboost::common::Monitor::Init
void Init(std::string label)
Definition: timer.h:80
xgboost::common::QuantileSketchTemplate::PushTemp
void PushTemp()
push up temp
Definition: quantile.h:601
xgboost::common::WQSummary::Query
Entry Query(DType qvalue, size_t &istart) const
query qvalue, start from istart
Definition: quantile.h:141
xgboost::common::QuantileSketchTemplate::data
std::vector< Entry > data
Definition: quantile.h:675
xgboost::common::WQSummary::data
Entry * data
data field
Definition: quantile.h:119
xgboost::common::SortedQuantile::Push
void Push(bst_float fvalue, bst_float w, unsigned max_size)
push a new element to sketch
Definition: quantile.h:807
xgboost::common::WXQuantileSketch
Quantile sketch that uses WXQSummary.
Definition: quantile.h:696
xgboost::common::QuantileSketchTemplate::inqueue
Summary::Queue inqueue
Definition: quantile.h:667
xgboost::MetaInfo
Meta information about the dataset; always sits in memory.
Definition: data.h:46
xgboost::common::WQSummary::Entry::RMinNext
XGBOOST_DEVICE RType RMinNext() const
Definition: quantile.h:55
xgboost::common::SketchContainerImpl::sketches_
std::vector< WQSketch > sketches_
Definition: quantile.h:708
xgboost::common::WQSummary::SetPrune
void SetPrune(const WQSummary &src, size_t maxsize)
set current summary to be the pruned summary of src; assumes the data field is already allocated to be at least...
Definition: quantile.h:216
xgboost::common::WXQSummary::Entry
typename WQSummary< DType, RType >::Entry Entry
Definition: quantile.h:370
xgboost::common::WQSummary::Queue::QEntry
Definition: quantile.h:74
xgboost::common::QuantileSketchTemplate::SummaryContainer::SummaryContainer
SummaryContainer()
Definition: quantile.h:494
xgboost::common::SortedQuantile::Init
void Init(unsigned max_size)
Definition: quantile.h:795
xgboost::common::QuantileSketchTemplate::SummaryContainer::Load
void Load(TStream &fi)
load data structure from input stream
Definition: quantile.h:530
xgboost::common::SketchContainerImpl::max_bins_
int32_t max_bins_
Definition: quantile.h:713
xgboost::common::WQSummary::Queue::QEntry::QEntry
QEntry()=default
XGBOOST_DEVICE
#define XGBOOST_DEVICE
Tag function as usable by device.
Definition: base.h:84
xgboost::common::WQSummary::Queue::QEntry::value
DType value
Definition: quantile.h:76
xgboost::common::QuantileSketchTemplate::nlevel
size_t nlevel
Definition: quantile.h:669
xgboost::common::SketchContainerImpl::LoadBalance
static std::vector< bst_feature_t > LoadBalance(SparsePage const &page, bst_feature_t n_columns, size_t const nthreads)
xgboost::common::QuantileSketchTemplate::Init
void Init(size_t maxn, double eps)
initialize the quantile sketch, given the performance specification
Definition: quantile.h:544
xgboost::common::WQSummary::Queue::queue
std::vector< QEntry > queue
Definition: quantile.h:90
xgboost::common::SketchContainerImpl::has_categorical_
bool has_categorical_
Definition: quantile.h:716
xgboost
namespace of xgboost
Definition: base.h:110
xgboost::bst_float
float bst_float
float type, used for storing statistics
Definition: base.h:119