xgboost
span.h
Go to the documentation of this file.
1 
29 #ifndef XGBOOST_COMMON_SPAN_H_
30 #define XGBOOST_COMMON_SPAN_H_
31 
32 #include <xgboost/logging.h> // CHECK
33 
34 #include <cinttypes> // int64_t
35 #include <type_traits>
36 
// Compatibility shims. Every compiler except MSVC releases older than 1910
// gets a real `noexcept` in Span signatures. On those old MSVC releases,
// `__span_noexcept` expands to nothing and `constexpr` is defined away; the
// prior definition of `constexpr` is saved with push_macro and restored by the
// matching pop_macro at the end of this header.
#if !defined(_MSC_VER) || _MSC_VER >= 1910

#define __span_noexcept noexcept

#else

#define __span_noexcept

#pragma push_macro("constexpr")
#define constexpr /*constexpr*/

#endif  // !defined(_MSC_VER) || _MSC_VER >= 1910
66 
67 namespace xgboost {
68 namespace common {
69 
// Usual logging facility is not available inside device code.
// TODO(trivialfis): Make dmlc check more generic.
// assert is not supported in mac as of CUDA 10.0
//
// KERNEL_CHECK(cond): if `cond` is false, print the file, line, enclosing
// function and the failed expression, then execute the `trap;` instruction
// (this macro is selected for device code by SPAN_CHECK).
//
// The expansion is a single `do { ... } while (0)` statement WITHOUT a
// trailing semicolon. The caller's `;` completes it, so
// `if (a) KERNEL_CHECK(b); else ...` parses correctly. A semicolon inside
// the macro would leave an extra empty statement that orphans the `else`.
#define KERNEL_CHECK(cond)                                      \
  do {                                                          \
    if (!(cond)) {                                              \
      printf("\nKernel error:\n"                                \
             "In: %s, \tline: %d\n"                             \
             "\t%s\n\tExpecting: %s\n",                         \
             __FILE__, __LINE__, __PRETTY_FUNCTION__, #cond);   \
      asm("trap;");                                             \
    }                                                           \
  } while (0)
83 
// Bounds checks used throughout Span. Host builds use dmlc's CHECK. Device
// code, where the usual logging facility is unavailable, uses KERNEL_CHECK.
#if !defined(__CUDA_ARCH__)
#define SPAN_CHECK CHECK  // check from dmlc
#else
#define SPAN_CHECK KERNEL_CHECK
#endif  // !defined(__CUDA_ARCH__)
89 
90 namespace detail {
97 using ptrdiff_t = int64_t; // NOLINT
98 } // namespace detail
99 
100 #if defined(_MSC_VER) && _MSC_VER < 1910
101 constexpr const detail::ptrdiff_t dynamic_extent = -1; // NOLINT
102 #else
103 constexpr detail::ptrdiff_t dynamic_extent = -1; // NOLINT
104 #endif // defined(_MSC_VER) && _MSC_VER < 1910
105 
// Distinct byte type: a scoped enum over unsigned char, modelled on C++17's
// std::byte. It is the element type of the spans returned by as_bytes /
// as_writable_bytes.
enum class byte : unsigned char {}; // NOLINT

// Forward declaration so that the detail helpers that follow (SpanIterator,
// IsSpanOracle) can refer to Span before its definition.
template <class ElementType, detail::ptrdiff_t Extent>
class Span;
110 
111 namespace detail {
112 
113 template <typename SpanType, bool IsConst>
115  using ElementType = typename SpanType::element_type;
116 
117  public:
118  using iterator_category = std::random_access_iterator_tag; // NOLINT
119  using value_type = typename std::remove_cv<ElementType>::type; // NOLINT
120  using difference_type = typename SpanType::index_type; // NOLINT
121 
122  using reference = typename std::conditional< // NOLINT
123  IsConst, const ElementType, ElementType>::type&;
124  using pointer = typename std::add_pointer<reference>::type; // NOLINT
125 
126  XGBOOST_DEVICE constexpr SpanIterator() : span_{nullptr}, index_{0} {}
127 
129  const SpanType* _span,
130  typename SpanType::index_type _idx) __span_noexcept :
131  span_(_span), index_(_idx) {}
132 
134  template <bool B, typename std::enable_if<!B && IsConst>::type* = nullptr>
135  XGBOOST_DEVICE constexpr SpanIterator( // NOLINT
137  : SpanIterator(other_.span_, other_.index_) {}
138 
140  SPAN_CHECK(index_ < span_->size());
141  return *(span_->data() + index_);
142  }
143 
145  SPAN_CHECK(index_ != span_->size());
146  return span_->data() + index_;
147  }
148 
150  SPAN_CHECK(0 <= index_ && index_ != span_->size());
151  index_++;
152  return *this;
153  }
154 
156  auto ret = *this;
157  ++(*this);
158  return ret;
159  }
160 
162  SPAN_CHECK(index_ != 0 && index_ <= span_->size());
163  index_--;
164  return *this;
165  }
166 
168  auto ret = *this;
169  --(*this);
170  return ret;
171  }
172 
174  auto ret = *this;
175  return ret += n;
176  }
177 
179  SPAN_CHECK((index_ + n) >= 0 && (index_ + n) <= span_->size());
180  index_ += n;
181  return *this;
182  }
183 
185  SPAN_CHECK(span_ == rhs.span_);
186  return index_ - rhs.index_;
187  }
188 
190  auto ret = *this;
191  return ret -= n;
192  }
193 
195  return *this += -n;
196  }
197 
198  // friends
199  XGBOOST_DEVICE constexpr friend bool operator==(
201  return _lhs.span_ == _rhs.span_ && _lhs.index_ == _rhs.index_;
202  }
203 
204  XGBOOST_DEVICE constexpr friend bool operator!=(
206  return !(_lhs == _rhs);
207  }
208 
209  XGBOOST_DEVICE constexpr friend bool operator<(
211  return _lhs.index_ < _rhs.index_;
212  }
213 
214  XGBOOST_DEVICE constexpr friend bool operator<=(
216  return !(_rhs < _lhs);
217  }
218 
219  XGBOOST_DEVICE constexpr friend bool operator>(
221  return _rhs < _lhs;
222  }
223 
224  XGBOOST_DEVICE constexpr friend bool operator>=(
226  return !(_rhs > _lhs);
227  }
228 
229  protected:
230  const SpanType *span_;
232 };
233 
234 
235 // It's tempting to use constexpr instead of structs to do the following meta
236 // programming. But remember that we are supporting MSVC 2013 here.
237 
245 template <detail::ptrdiff_t Extent,
246  detail::ptrdiff_t Offset,
247  detail::ptrdiff_t Count>
248 struct ExtentValue : public std::integral_constant<
249  detail::ptrdiff_t, Count != dynamic_extent ?
250  Count : (Extent != dynamic_extent ? Extent - Offset : Extent)> {};
251 
256 template <typename T, detail::ptrdiff_t Extent>
257 struct ExtentAsBytesValue : public std::integral_constant<
258  detail::ptrdiff_t,
259  Extent == dynamic_extent ?
260  Extent : static_cast<detail::ptrdiff_t>(sizeof(T) * Extent)> {};
261 
262 template <detail::ptrdiff_t From, detail::ptrdiff_t To>
263 struct IsAllowedExtentConversion : public std::integral_constant<
264  bool, From == To || From == dynamic_extent || To == dynamic_extent> {};
265 
// True when a pointer to an array of From converts to a pointer to an array
// of To. Testing through array pointers admits qualification conversions
// (e.g. T -> const T) but rejects derived-to-base conversions, which would be
// unsafe for contiguous element access because element sizes may differ.
template <class From, class To>
struct IsAllowedElementTypeConversion : public std::integral_constant<
  bool, std::is_convertible<From(*)[], To(*)[]>::value> {};
269 
270 template <class T>
271 struct IsSpanOracle : std::false_type {};
272 
273 template <class T, detail::ptrdiff_t Extent>
274 struct IsSpanOracle<Span<T, Extent>> : std::true_type {};
275 
276 template <class T>
277 struct IsSpan : public IsSpanOracle<typename std::remove_cv<T>::type> {};
278 
279 // Re-implement std algorithms here to adopt CUDA.
280 template <typename T>
281 struct Less {
282  XGBOOST_DEVICE constexpr bool operator()(const T& _x, const T& _y) const {
283  return _x < _y;
284  }
285 };
286 
287 template <typename T>
288 struct Greater {
289  XGBOOST_DEVICE constexpr bool operator()(const T& _x, const T& _y) const {
290  return _x > _y;
291  }
292 };
293 
294 template <class InputIt1, class InputIt2,
295  class Compare =
297 XGBOOST_DEVICE bool LexicographicalCompare(InputIt1 first1, InputIt1 last1,
298  InputIt2 first2, InputIt2 last2) {
299  Compare comp;
300  for (; first1 != last1 && first2 != last2; ++first1, ++first2) {
301  if (comp(*first1, *first2)) {
302  return true;
303  }
304  if (comp(*first2, *first1)) {
305  return false;
306  }
307  }
308  return first1 == last1 && first2 != last2;
309 }
310 
311 } // namespace detail
312 
313 
381 template <typename T,
382  detail::ptrdiff_t Extent = dynamic_extent>
383 class Span {
384  public:
385  using element_type = T; // NOLINT
386  using value_type = typename std::remove_cv<T>::type; // NOLINT
387  using index_type = detail::ptrdiff_t; // NOLINT
389  using pointer = T*; // NOLINT
390  using reference = T&; // NOLINT
391 
393  using const_iterator = const detail::SpanIterator<Span<T, Extent>, true>; // NOLINT
396 
397  // constructors
398 
399  XGBOOST_DEVICE constexpr Span() __span_noexcept : size_(0), data_(nullptr) {}
400 
402  size_(_count), data_(_ptr) {
403  SPAN_CHECK(_count >= 0);
404  SPAN_CHECK(_ptr || _count == 0);
405  }
406 
408  size_(_last - _first), data_(_first) {
409  SPAN_CHECK(size_ >= 0);
410  SPAN_CHECK(data_ || size_ == 0);
411  }
412 
413  template <std::size_t N>
414  XGBOOST_DEVICE constexpr Span(element_type (&arr)[N]) // NOLINT
415  __span_noexcept : size_(N), data_(&arr[0]) {}
416 
417  template <class Container,
418  class = typename std::enable_if<
419  !std::is_const<element_type>::value && !detail::IsSpan<Container>::value &&
420  std::is_convertible<typename Container::pointer,
421  pointer>::value &&
422  std::is_convertible<
423  typename Container::pointer,
424  decltype(std::declval<Container>().data())>::value>>
425  XGBOOST_DEVICE Span(Container& _cont) : // NOLINT
426  size_(_cont.size()), data_(_cont.data()) {}
427 
428  template <class Container,
429  class = typename std::enable_if<
430  std::is_const<element_type>::value && !detail::IsSpan<Container>::value &&
431  std::is_convertible<typename Container::pointer, pointer>::value &&
432  std::is_convertible<
433  typename Container::pointer,
434  decltype(std::declval<Container>().data())>::value>>
435  XGBOOST_DEVICE Span(const Container& _cont) : size_(_cont.size()), // NOLINT
436  data_(_cont.data()) {}
437 
438  template <class U, detail::ptrdiff_t OtherExtent,
439  class = typename std::enable_if<
442  XGBOOST_DEVICE constexpr Span(const Span<U, OtherExtent>& _other) // NOLINT
443  __span_noexcept : size_(_other.size()), data_(_other.data()) {}
444 
445  XGBOOST_DEVICE constexpr Span(const Span& _other)
446  __span_noexcept : size_(_other.size()), data_(_other.data()) {}
447 
449  size_ = _other.size();
450  data_ = _other.data();
451  return *this;
452  }
453 
455 
456  XGBOOST_DEVICE constexpr iterator begin() const __span_noexcept { // NOLINT
457  return {this, 0};
458  }
459 
460  XGBOOST_DEVICE constexpr iterator end() const __span_noexcept { // NOLINT
461  return {this, size()};
462  }
463 
464  XGBOOST_DEVICE constexpr const_iterator cbegin() const __span_noexcept { // NOLINT
465  return {this, 0};
466  }
467 
468  XGBOOST_DEVICE constexpr const_iterator cend() const __span_noexcept { // NOLINT
469  return {this, size()};
470  }
471 
473  return reverse_iterator{end()};
474  }
475 
476  XGBOOST_DEVICE constexpr reverse_iterator rend() const __span_noexcept { // NOLINT
477  return reverse_iterator{begin()};
478  }
479 
481  return const_reverse_iterator{cend()};
482  }
483 
485  return const_reverse_iterator{cbegin()};
486  }
487 
489  SPAN_CHECK(_idx >= 0 && _idx < size());
490  return data()[_idx];
491  }
492 
494  return this->operator[](_idx);
495  }
496 
497  XGBOOST_DEVICE constexpr pointer data() const __span_noexcept { // NOLINT
498  return data_;
499  }
500 
501  // Observers
502  XGBOOST_DEVICE constexpr index_type size() const __span_noexcept { // NOLINT
503  return size_;
504  }
505  XGBOOST_DEVICE constexpr index_type size_bytes() const __span_noexcept { // NOLINT
506  return size() * sizeof(T);
507  }
508 
509  XGBOOST_DEVICE constexpr bool empty() const __span_noexcept { // NOLINT
510  return size() == 0;
511  }
512 
513  // Subviews
514  template <detail::ptrdiff_t Count >
516  SPAN_CHECK(Count >= 0 && Count <= size());
517  return {data(), Count};
518  }
519 
521  detail::ptrdiff_t _count) const {
522  SPAN_CHECK(_count >= 0 && _count <= size());
523  return {data(), _count};
524  }
525 
526  template <detail::ptrdiff_t Count >
528  SPAN_CHECK(Count >=0 && size() - Count >= 0);
529  return {data() + size() - Count, Count};
530  }
531 
533  detail::ptrdiff_t _count) const {
534  SPAN_CHECK(_count >= 0 && _count <= size());
535  return subspan(size() - _count, _count);
536  }
537 
542  template <detail::ptrdiff_t Offset,
543  detail::ptrdiff_t Count = dynamic_extent>
544  XGBOOST_DEVICE auto subspan() const -> // NOLINT
546  detail::ExtentValue<Extent, Offset, Count>::value> {
547  SPAN_CHECK(Offset >= 0 && (Offset < size() || size() == 0));
548  SPAN_CHECK(Count == dynamic_extent ||
549  Count >= 0 && Offset + Count <= size());
550 
551  return {data() + Offset, Count == dynamic_extent ? size() - Offset : Count};
552  }
553 
555  detail::ptrdiff_t _offset,
556  detail::ptrdiff_t _count = dynamic_extent) const {
557  SPAN_CHECK(_offset >= 0 && (_offset < size() || size() == 0));
558  SPAN_CHECK((_count == dynamic_extent) ||
559  (_count >= 0 && _offset + _count <= size()));
560 
561  return {data() + _offset, _count ==
562  dynamic_extent ? size() - _offset : _count};
563  }
564 
565  private:
566  index_type size_;
567  pointer data_;
568 };
569 
570 template <class T, detail::ptrdiff_t X, class U, detail::ptrdiff_t Y>
572  if (l.size() != r.size()) {
573  return false;
574  }
575  for (auto l_beg = l.cbegin(), r_beg = r.cbegin(); l_beg != l.cend();
576  ++l_beg, ++r_beg) {
577  if (*l_beg != *r_beg) {
578  return false;
579  }
580  }
581  return true;
582 }
583 
584 template <class T, detail::ptrdiff_t X, class U, detail::ptrdiff_t Y>
586  return !(l == r);
587 }
588 
589 template <class T, detail::ptrdiff_t X, class U, detail::ptrdiff_t Y>
590 XGBOOST_DEVICE constexpr bool operator<(Span<T, X> l, Span<U, Y> r) {
591  return detail::LexicographicalCompare(l.begin(), l.end(),
592  r.begin(), r.end());
593 }
594 
595 template <class T, detail::ptrdiff_t X, class U, detail::ptrdiff_t Y>
596 XGBOOST_DEVICE constexpr bool operator<=(Span<T, X> l, Span<U, Y> r) {
597  return !(l > r);
598 }
599 
600 template <class T, detail::ptrdiff_t X, class U, detail::ptrdiff_t Y>
603  typename Span<T, X>::iterator, typename Span<U, Y>::iterator,
605  r.begin(), r.end());
606 }
607 
608 template <class T, detail::ptrdiff_t X, class U, detail::ptrdiff_t Y>
610  return !(l < r);
611 }
612 
613 template <class T, detail::ptrdiff_t E>
616  return {reinterpret_cast<const byte*>(s.data()), s.size_bytes()};
617 }
618 
619 template <class T, detail::ptrdiff_t E>
622  return {reinterpret_cast<byte*>(s.data()), s.size_bytes()};
623 }
624 
625 } // namespace common
626 } // namespace xgboost
627 
// Undo the MSVC < 1910 compatibility shims set up near the top of this header.
// `constexpr` must be #undef'd BEFORE pop_macro restores whatever definition
// was saved by push_macro. Reversing the order would erase the restored state.
// NOTE(review): only this branch undefines __span_noexcept; on other compilers
// it stays defined (as `noexcept`) for code that includes this header.
#if defined(_MSC_VER) &&_MSC_VER < 1910
#undef constexpr
#pragma pop_macro("constexpr")
#undef __span_noexcept
#endif // _MSC_VER < 1910
633 
634 #endif // XGBOOST_COMMON_SPAN_H_
byte
Definition: span.h:106
XGBOOST_DEVICE constexpr bool operator>=(Span< T, X > l, Span< U, Y > r)
Definition: span.h:609
detail::ptrdiff_t index_
Definition: span.h:231
XGBOOST_DEVICE constexpr friend bool operator>=(SpanIterator _lhs, SpanIterator _rhs) __span_noexcept
Definition: span.h:224
typename std::remove_cv< ElementType >::type value_type
Definition: span.h:119
XGBOOST_DEVICE Span< element_type, Count > first() const
Definition: span.h:515
XGBOOST_DEVICE constexpr index_type size() const __span_noexcept
Definition: span.h:502
XGBOOST_DEVICE Span(pointer _first, pointer _last)
Definition: span.h:407
XGBOOST_DEVICE constexpr bool empty() const __span_noexcept
Definition: span.h:509
XGBOOST_DEVICE SpanIterator & operator--()
Definition: span.h:161
XGBOOST_DEVICE SpanIterator operator--(int)
Definition: span.h:167
XGBOOST_DEVICE constexpr SpanIterator()
Definition: span.h:126
XGBOOST_DEVICE constexpr SpanIterator(const SpanIterator< SpanType, B > &other_) __span_noexcept
Definition: span.h:135
XGBOOST_DEVICE constexpr bool operator>(Span< T, X > l, Span< U, Y > r)
Definition: span.h:601
XGBOOST_DEVICE constexpr iterator end() const __span_noexcept
Definition: span.h:460
detail::ptrdiff_t index_type
Definition: span.h:387
XGBOOST_DEVICE constexpr Span() __span_noexcept
Definition: span.h:399
XGBOOST_DEVICE SpanIterator operator-(difference_type n) const
Definition: span.h:189
int64_t ptrdiff_t
Definition: span.h:97
XGBOOST_DEVICE SpanIterator & operator+=(difference_type n)
Definition: span.h:178
XGBOOST_DEVICE constexpr friend bool operator<=(SpanIterator _lhs, SpanIterator _rhs) __span_noexcept
Definition: span.h:214
Definition: span.h:281
XGBOOST_DEVICE SpanIterator operator+(difference_type n) const
Definition: span.h:173
T element_type
Definition: span.h:385
Definition: span.h:288
XGBOOST_DEVICE auto as_bytes(Span< T, E > s) __span_noexcept -> Span< const byte, detail::ExtentAsBytesValue< T, E >::value >
Definition: span.h:614
detail::ptrdiff_t difference_type
Definition: span.h:388
XGBOOST_DEVICE constexpr Span(const Span &_other) __span_noexcept
Definition: span.h:445
T & reference
Definition: span.h:390
XGBOOST_DEVICE auto as_writable_bytes(Span< T, E > s) __span_noexcept -> Span< byte, detail::ExtentAsBytesValue< T, E >::value >
Definition: span.h:620
XGBOOST_DEVICE constexpr friend bool operator!=(SpanIterator _lhs, SpanIterator _rhs) __span_noexcept
Definition: span.h:204
XGBOOST_DEVICE constexpr bool operator()(const T &_x, const T &_y) const
Definition: span.h:289
Definition: span.h:277
XGBOOST_DEVICE Span(Container &_cont)
Definition: span.h:425
std::random_access_iterator_tag iterator_category
Definition: span.h:118
Span class implementation, based on ISO C++20 std::span<T>. The interface should be the same.
Definition: span.h:109
XGBOOST_DEVICE SpanIterator & operator++()
Definition: span.h:149
typename std::conditional< IsConst, const ElementType, ElementType >::type & reference
Definition: span.h:123
XGBOOST_DEVICE SpanIterator & operator-=(difference_type n)
Definition: span.h:194
XGBOOST_DEVICE bool LexicographicalCompare(InputIt1 first1, InputIt1 last1, InputIt2 first2, InputIt2 last2)
Definition: span.h:297
XGBOOST_DEVICE constexpr friend bool operator>(SpanIterator _lhs, SpanIterator _rhs) __span_noexcept
Definition: span.h:219
typename std::remove_cv< T >::type value_type
Definition: span.h:386
constexpr detail::ptrdiff_t dynamic_extent
Definition: span.h:103
XGBOOST_DEVICE Span< element_type, Count > last() const
Definition: span.h:527
XGBOOST_DEVICE constexpr pointer data() const __span_noexcept
Definition: span.h:497
XGBOOST_DEVICE constexpr const_reverse_iterator crend() const __span_noexcept
Definition: span.h:484
XGBOOST_DEVICE constexpr reference operator()(index_type _idx) const
Definition: span.h:493
XGBOOST_DEVICE Span(pointer _ptr, index_type _count)
Definition: span.h:401
XGBOOST_DEVICE constexpr SpanIterator(const SpanType *_span, typename SpanType::index_type _idx) __span_noexcept
Definition: span.h:128
XGBOOST_DEVICE constexpr reverse_iterator rbegin() const __span_noexcept
Definition: span.h:472
XGBOOST_DEVICE constexpr bool operator()(const T &_x, const T &_y) const
Definition: span.h:282
XGBOOST_DEVICE Span< element_type, dynamic_extent > subspan(detail::ptrdiff_t _offset, detail::ptrdiff_t _count=dynamic_extent) const
Definition: span.h:554
XGBOOST_DEVICE constexpr bool operator!=(Span< T, X > l, Span< U, Y > r)
Definition: span.h:585
#define XGBOOST_DEVICE
Tag function as usable by device.
Definition: base.h:75
XGBOOST_DEVICE Span< element_type, dynamic_extent > last(detail::ptrdiff_t _count) const
Definition: span.h:532
typename SpanType::index_type difference_type
Definition: span.h:120
namespace of xgboost
Definition: base.h:79
XGBOOST_DEVICE constexpr reverse_iterator rend() const __span_noexcept
Definition: span.h:476
XGBOOST_DEVICE Span< element_type, dynamic_extent > first(detail::ptrdiff_t _count) const
Definition: span.h:520
XGBOOST_DEVICE constexpr const_reverse_iterator crbegin() const __span_noexcept
Definition: span.h:480
const SpanType * span_
Definition: span.h:230
XGBOOST_DEVICE auto subspan() const -> Span< element_type, detail::ExtentValue< Extent, Offset, Count >::value >
Definition: span.h:544
XGBOOST_DEVICE constexpr Span(element_type(&arr)[N]) __span_noexcept
Definition: span.h:414
XGBOOST_DEVICE Span(const Container &_cont)
Definition: span.h:435
XGBOOST_DEVICE constexpr const_iterator cbegin() const __span_noexcept
Definition: span.h:464
XGBOOST_DEVICE SpanIterator operator++(int)
Definition: span.h:155
T * pointer
Definition: span.h:389
#define SPAN_CHECK
Definition: span.h:87
XGBOOST_DEVICE reference operator[](index_type _idx) const
Definition: span.h:488
XGBOOST_DEVICE constexpr friend bool operator<(SpanIterator _lhs, SpanIterator _rhs) __span_noexcept
Definition: span.h:209
XGBOOST_DEVICE constexpr index_type size_bytes() const __span_noexcept
Definition: span.h:505
XGBOOST_DEVICE constexpr const_iterator cend() const __span_noexcept
Definition: span.h:468
#define __span_noexcept
Span class based on ISO C++20 std::span.
Definition: span.h:63
XGBOOST_DEVICE Span & operator=(const Span &_other) __span_noexcept
Definition: span.h:448
XGBOOST_DEVICE constexpr friend bool operator==(SpanIterator _lhs, SpanIterator _rhs) __span_noexcept
Definition: span.h:199
typename std::add_pointer< reference >::type pointer
Definition: span.h:124
XGBOOST_DEVICE constexpr Span(const Span< U, OtherExtent > &_other) __span_noexcept
Definition: span.h:442
XGBOOST_DEVICE constexpr iterator begin() const __span_noexcept
Definition: span.h:456
XGBOOST_DEVICE difference_type operator-(SpanIterator rhs) const
Definition: span.h:184
XGBOOST_DEVICE ~Span() __span_noexcept
Definition: span.h:454
XGBOOST_DEVICE reference operator*() const
Definition: span.h:139
XGBOOST_DEVICE pointer operator->() const
Definition: span.h:144
XGBOOST_DEVICE bool operator==(Span< T, X > l, Span< U, Y > r)
Definition: span.h:571