xgboost
span.h
Go to the documentation of this file.
1 
29 #ifndef XGBOOST_SPAN_H_
30 #define XGBOOST_SPAN_H_
31 
#include <xgboost/logging.h> // CHECK

#include <cinttypes> // size_t
#include <cstddef>   // std::size_t, std::ptrdiff_t
#include <cstdint>   // std::int64_t
#include <limits>    // std::numeric_limits
#include <numeric> // numeric_limits
#include <type_traits>
37 
// Compatibility shims. This header treats MSVC releases before 1910 as lacking
// usable `noexcept`/`constexpr` support, so on those compilers:
//   * __span_noexcept expands to nothing, and
//   * `constexpr` is temporarily #defined away; its prior macro state is saved
//     with push_macro here and restored by pop_macro at the end of the header.
// Every other compiler simply gets `noexcept`.
#if !defined(_MSC_VER) || _MSC_VER >= 1910

#define __span_noexcept noexcept

#else

#define __span_noexcept

#pragma push_macro("constexpr")
#define constexpr /*constexpr*/

#endif  // !defined(_MSC_VER) || _MSC_VER >= 1910
67 
68 namespace xgboost {
69 namespace common {
70 
// Usual logging facility is not available inside device code.
// TODO(trivialfis): Make dmlc check more generic.
// assert is not supported in mac as of CUDA 10.0
//
// KERNEL_CHECK(cond): device-side precondition check. When `cond` is false it
// prints the file, line, enclosing function, the failed expression and the
// CUDA block/thread indices, then executes the PTX `trap` instruction. It uses
// the CUDA built-ins blockIdx/threadIdx and inline PTX, so it only makes sense
// in device code (SPAN_CHECK maps to it when __CUDA_ARCH__ is defined).
//
// The expansion intentionally has NO trailing semicolon: the caller writes it,
// which is what makes the do { ... } while (0) wrapper behave like a single
// statement (e.g. `if (x) KERNEL_CHECK(y); else ...` parses correctly).
#define KERNEL_CHECK(cond)                                                   \
  do {                                                                       \
    if (!(cond)) {                                                           \
      printf("\nKernel error:\n"                                             \
             "In: %s: %d\n"                                                  \
             "\t%s\n\tExpecting: %s\n"                                       \
             "\tBlock: [%d, %d, %d], Thread: [%d, %d, %d]\n\n",              \
             __FILE__, __LINE__, __PRETTY_FUNCTION__, #cond,                 \
             blockIdx.x, blockIdx.y, blockIdx.z,                             \
             threadIdx.x, threadIdx.y, threadIdx.z);                         \
      asm("trap;");                                                          \
    }                                                                        \
  } while (0)
87 
// SPAN_CHECK(cond): the precondition check used throughout this header.
// Host code uses dmlc's CHECK; device compilation passes (__CUDA_ARCH__
// defined) cannot use the usual logging facility and fall back to
// KERNEL_CHECK instead.
#if !defined(__CUDA_ARCH__)
#define SPAN_CHECK CHECK  // check from dmlc
#else
#define SPAN_CHECK KERNEL_CHECK
#endif  // !defined(__CUDA_ARCH__)
93 
namespace detail {
/*!
 * \brief Signed difference type used by Span and SpanIterator.
 *
 * Choosing "std::ptrdiff_t when it is std::int64_t, otherwise std::int64_t"
 * yields std::int64_t in both cases, so the type is named directly: a fixed
 * 64-bit signed type regardless of what the platform's ptrdiff_t is.
 */
using ptrdiff_t = std::int64_t;
}  // namespace detail
104 
/*!
 * \brief Sentinel extent meaning "number of elements known only at run time".
 *
 * Spelled `constexpr const` so that one declaration serves every compiler: on
 * MSVC < 1910 this header #defines `constexpr` to nothing, and the variable
 * then still degrades to a `const` (rather than mutable) object.
 */
constexpr const std::size_t dynamic_extent = std::numeric_limits<std::size_t>::max();  // NOLINT
111 
/*!
 * \brief Opaque one-byte type (same shape as C++17 std::byte), used as the
 *        element type of the views returned by as_bytes/as_writable_bytes.
 */
enum class byte : unsigned char {};  // NOLINT

// Forward declaration; the definition (with Extent defaulting to
// dynamic_extent) appears later in this header.
template <typename T, std::size_t N>
class Span;
116 
117 namespace detail {
118 
// Random-access iterator over the elements of a Span. IsConst selects whether
// dereferencing yields const references (const_iterator) or mutable ones.
// The iterator stores a pointer to the Span object itself (span_) plus an
// index, so it must not outlive that Span object.
// Dereference and movement operations are validated with SPAN_CHECK; the
// ordering comparisons are not.
// NOTE(review): this listing omits several source lines of this class (its
// head and most member signatures); the code below is reproduced as shown.
119 template <typename SpanType, bool IsConst>
121  using ElementType = typename SpanType::element_type;
122 
123  public:
  // Standard iterator traits.
124  using iterator_category = std::random_access_iterator_tag; // NOLINT
125  using value_type = typename SpanType::value_type; // NOLINT
127 
  // const ElementType& when IsConst, ElementType& otherwise; `pointer` is the
  // matching pointer type.
128  using reference = typename std::conditional< // NOLINT
129  IsConst, const ElementType, ElementType>::type&;
130  using pointer = typename std::add_pointer<reference>::type; // NOLINT
131 
  // Singular iterator (no span, index 0); must not be dereferenced or moved.
132  XGBOOST_DEVICE constexpr SpanIterator() : span_{nullptr}, index_{0} {}
133 
  // Iterator at position _idx of the Span pointed to by _span.
135  const SpanType* _span,
136  typename SpanType::index_type _idx) __span_noexcept :
137  span_(_span), index_(_idx) {}
138 
  // Enabled only when converting a non-const iterator (B == false) into a
  // const one (IsConst), i.e. iterator -> const_iterator.
140  template <bool B, typename std::enable_if<!B && IsConst>::type* = nullptr>
141  XGBOOST_DEVICE constexpr SpanIterator( // NOLINT
143  : SpanIterator(other_.span_, other_.index_) {}
144 
  // operator*: element at the current position; requires index_ < size().
146  SPAN_CHECK(index_ < span_->size());
147  return *(span_->data() + index_);
148  }
  // operator[]: element n positions away, via operator+ and operator*.
150  return *(*this + n);
151  }
152 
  // operator->: pointer to the current element; rejects the end position.
154  SPAN_CHECK(index_ != span_->size());
155  return span_->data() + index_;
156  }
157 
  // Pre-increment; incrementing the end position fails the check.
159  SPAN_CHECK(index_ != span_->size());
160  index_++;
161  return *this;
162  }
163 
  // Post-increment: advances and returns the previous position.
165  auto ret = *this;
166  ++(*this);
167  return ret;
168  }
169 
  // Pre-decrement; decrementing the first position (index 0) fails the check.
171  SPAN_CHECK(index_ != 0 && index_ <= span_->size());
172  index_--;
173  return *this;
174  }
175 
  // Post-decrement: steps back and returns the previous position.
177  auto ret = *this;
178  --(*this);
179  return ret;
180  }
181 
  // operator+(n): copy of this iterator advanced by n.
183  auto ret = *this;
184  return ret += n;
185  }
186 
  // operator+=(n): moves by n (n may be negative). The addition is done in
  // the unsigned index type, so a move before the first element wraps to a
  // huge index and is rejected by the same check as a move past end().
188  SPAN_CHECK((index_ + n) <= span_->size());
189  index_ += n;
190  return *this;
191  }
192 
  // Distance between two iterators; both must refer to the same Span object.
194  SPAN_CHECK(span_ == rhs.span_);
195  return index_ - rhs.index_;
196  }
197 
  // operator-(n): copy of this iterator moved back by n.
199  auto ret = *this;
200  return ret -= n;
201  }
202 
  // operator-=(n): implemented as += -n.
204  return *this += -n;
205  }
206 
207  // friends
  // Equality compares both the referenced Span object and the index.
208  XGBOOST_DEVICE constexpr friend bool operator==(
210  return _lhs.span_ == _rhs.span_ && _lhs.index_ == _rhs.index_;
211  }
212 
213  XGBOOST_DEVICE constexpr friend bool operator!=(
215  return !(_lhs == _rhs);
216  }
217 
  // Ordering compares indices only; it does not check that both iterators
  // refer to the same Span. <=, > and >= are all derived from <.
218  XGBOOST_DEVICE constexpr friend bool operator<(
220  return _lhs.index_ < _rhs.index_;
221  }
222 
223  XGBOOST_DEVICE constexpr friend bool operator<=(
225  return !(_rhs < _lhs);
226  }
227 
228  XGBOOST_DEVICE constexpr friend bool operator>(
230  return _rhs < _lhs;
231  }
232 
233  XGBOOST_DEVICE constexpr friend bool operator>=(
235  return !(_rhs > _lhs);
236  }
237 
238  protected:
  // Non-owning pointer to the Span being iterated, and the current position
  // (end() corresponds to index_ == span_->size()).
239  const SpanType *span_;
240  typename SpanType::index_type index_;
241 };
242 
243 
244 // It's tempting to use constexpr instead of structs to do the following meta
245 // programming. But remember that we are supporting MSVC 2013 here.
246 
254 template <std::size_t Extent, std::size_t Offset, std::size_t Count>
255 struct ExtentValue : public std::integral_constant<
256  std::size_t, Count != dynamic_extent ?
257  Count : (Extent != dynamic_extent ? Extent - Offset : Extent)> {};
258 
263 template <typename T, std::size_t Extent>
264 struct ExtentAsBytesValue : public std::integral_constant<
265  std::size_t,
266  Extent == dynamic_extent ?
267  Extent : sizeof(T) * Extent> {};
268 
269 template <std::size_t From, std::size_t To>
270 struct IsAllowedExtentConversion : public std::integral_constant<
271  bool, From == To || From == dynamic_extent || To == dynamic_extent> {};
272 
/*!
 * \brief True when a Span<From> may be converted to a Span<To>.
 *
 * The test goes through pointers to arrays of unknown bound. Only
 * qualification conversions (e.g. T -> const T) pass. Derived-to-base
 * conversions, which are unsafe for contiguous sequences, are rejected.
 */
template <class From, class To>
struct IsAllowedElementTypeConversion
    : std::integral_constant<bool,
                             std::is_convertible<From (*)[], To (*)[]>::value> {};
276 
277 template <class T>
278 struct IsSpanOracle : std::false_type {};
279 
280 template <class T, std::size_t Extent>
281 struct IsSpanOracle<Span<T, Extent>> : std::true_type {};
282 
283 template <class T>
284 struct IsSpan : public IsSpanOracle<typename std::remove_cv<T>::type> {};
285 
286 // Re-implement std algorithms here to adopt CUDA.
287 template <typename T>
288 struct Less {
289  XGBOOST_DEVICE constexpr bool operator()(const T& _x, const T& _y) const {
290  return _x < _y;
291  }
292 };
293 
294 template <typename T>
295 struct Greater {
296  XGBOOST_DEVICE constexpr bool operator()(const T& _x, const T& _y) const {
297  return _x > _y;
298  }
299 };
300 
// Device-usable counterpart of std::lexicographical_compare. Returns true when
// [first1, last1) orders strictly before [first2, last2) under Compare. The
// comparator is default-constructed, so it must be stateless.
// The scan stops at the first pair where one element orders before the other.
// If the first range is a proper prefix of the second, the result is true.
// Equal ranges, or a second range that is a prefix of the first, give false.
// NOTE(review): the line supplying the default Compare argument is not
// present in this listing.
301 template <class InputIt1, class InputIt2,
302  class Compare =
304 XGBOOST_DEVICE bool LexicographicalCompare(InputIt1 first1, InputIt1 last1,
305  InputIt2 first2, InputIt2 last2) {
306  Compare comp;
307  for (; first1 != last1 && first2 != last2; ++first1, ++first2) {
  // First differing pair decides the result.
308  if (comp(*first1, *first2)) {
309  return true;
310  }
311  if (comp(*first2, *first1)) {
312  return false;
313  }
314  }
  // All compared elements were equivalent: only a strictly shorter first
  // range orders before the second.
315  return first1 == last1 && first2 != last2;
316 }
317 
318 } // namespace detail
319 
320 
// Non-owning view of `size_` contiguous elements starting at `data_`, modelled
// on C++20 std::span. Extent is either a compile-time element count or
// dynamic_extent. Element access and sub-view creation are validated with
// SPAN_CHECK.
// NOTE(review): this listing omits several source lines of this class (some
// typedefs, constructor heads and member signatures); the code below is
// reproduced as shown.
388 template <typename T,
389  std::size_t Extent = dynamic_extent>
390 class Span {
391  public:
392  using element_type = T; // NOLINT
393  using value_type = typename std::remove_cv<T>::type; // NOLINT
394  using index_type = std::size_t; // NOLINT
396  using pointer = T*; // NOLINT
397  using reference = T&; // NOLINT
398 
400  using const_iterator = const detail::SpanIterator<Span<T, Extent>, true>; // NOLINT
403 
404  // constructors
405 
  // Empty span: no data, size 0.
406  XGBOOST_DEVICE constexpr Span() __span_noexcept : size_(0), data_(nullptr) {}
407 
  // (pointer, count): views _count elements starting at _ptr. With a static
  // Extent the count must equal Extent; a null _ptr is only valid with
  // _count == 0.
409  size_(_count), data_(_ptr) {
410  SPAN_CHECK(!(Extent != dynamic_extent && _count != Extent));
411  SPAN_CHECK(_ptr || _count == 0);
412  }
413 
  // (first, last): views [_first, _last); the size is the pointer difference.
  // NOTE(review): unlike (pointer, count), no check against a static Extent
  // is performed here.
415  size_(_last - _first), data_(_first) {
416  SPAN_CHECK(data_ || size_ == 0);
417  }
418 
  // Views a whole C array of N elements.
419  template <std::size_t N>
420  XGBOOST_DEVICE constexpr Span(element_type (&arr)[N]) // NOLINT
421  __span_noexcept : size_(N), data_(&arr[0]) {}
422 
  // Views a mutable container through its data()/size(); enabled only for a
  // non-const element_type whose pointer type is compatible. Passing another
  // Span here is rejected by the static_assert.
  // NOTE(review): size_ is taken from the container with no check against a
  // static Extent; part of the enable_if condition is not in this listing.
423  template <class Container,
424  class = typename std::enable_if<
425  !std::is_const<element_type>::value &&
427  std::is_convertible<typename Container::pointer, pointer>::value &&
428  std::is_convertible<typename Container::pointer,
429  decltype(std::declval<Container>().data())>::value>::type>
430  Span(Container& _cont) : // NOLINT
431  size_(_cont.size()), data_(_cont.data()) {
432  static_assert(!detail::IsSpan<Container>::value, "Wrong constructor of Span is called.");
433  }
434 
  // Same as above for a const container, enabled only for a const
  // element_type.
435  template <class Container,
436  class = typename std::enable_if<
437  std::is_const<element_type>::value &&
439  std::is_convertible<typename Container::pointer, pointer>::value &&
440  std::is_convertible<typename Container::pointer,
441  decltype(std::declval<Container>().data())>::value>::type>
442  Span(const Container& _cont) : size_(_cont.size()), // NOLINT
443  data_(_cont.data()) {
444  static_assert(!detail::IsSpan<Container>::value, "Wrong constructor of Span is called.");
445  }
446 
  // Converting constructor from another Span specialization. The enable_if
  // condition is not in this listing (presumably IsAllowedExtentConversion /
  // IsAllowedElementTypeConversion -- TODO confirm).
447  template <class U, std::size_t OtherExtent,
448  class = typename std::enable_if<
451  XGBOOST_DEVICE constexpr Span(const Span<U, OtherExtent>& _other) // NOLINT
452  __span_noexcept : size_(_other.size()), data_(_other.data()) {}
453 
  // Copy construction/assignment copy the view (pointer and size), not the
  // elements.
454  XGBOOST_DEVICE constexpr Span(const Span& _other)
455  __span_noexcept : size_(_other.size()), data_(_other.data()) {}
456 
458  size_ = _other.size();
459  data_ = _other.data();
460  return *this;
461  }
462 
464 
  // Iterators are built from `this`, i.e. they point at this Span object
  // rather than at the data, so they must not outlive the Span they came from
  // (e.g. do not keep begin() of a temporary Span).
465  XGBOOST_DEVICE constexpr iterator begin() const __span_noexcept { // NOLINT
466  return {this, 0};
467  }
468 
469  XGBOOST_DEVICE constexpr iterator end() const __span_noexcept { // NOLINT
470  return {this, size()};
471  }
472 
473  XGBOOST_DEVICE constexpr const_iterator cbegin() const __span_noexcept { // NOLINT
474  return {this, 0};
475  }
476 
477  XGBOOST_DEVICE constexpr const_iterator cend() const __span_noexcept { // NOLINT
478  return {this, size()};
479  }
480 
  // Reverse iterators wrap end()/begin() (resp. cend()/cbegin()); the
  // reverse_iterator typedefs are not in this listing.
482  return reverse_iterator{end()};
483  }
484 
485  XGBOOST_DEVICE constexpr reverse_iterator rend() const __span_noexcept { // NOLINT
486  return reverse_iterator{begin()};
487  }
488 
490  return const_reverse_iterator{cend()};
491  }
492 
494  return const_reverse_iterator{cbegin()};
495  }
496 
497  // element access
498 
  // front()/back() go through operator[], so on an empty span they fail the
  // bounds check (back() computes size() - 1, which wraps and is rejected).
500  return (*this)[0];
501  }
502 
504  return (*this)[size() - 1];
505  }
506 
  // operator[]: bounds-checked element access.
508  SPAN_CHECK(_idx < size());
509  return data()[_idx];
510  }
511 
  // operator(): same as operator[].
513  return this->operator[](_idx);
514  }
515 
516  XGBOOST_DEVICE constexpr pointer data() const __span_noexcept { // NOLINT
517  return data_;
518  }
519 
520  // Observers
521  XGBOOST_DEVICE constexpr index_type size() const __span_noexcept { // NOLINT
522  return size_;
523  }
  // Size of the viewed elements in bytes.
524  XGBOOST_DEVICE constexpr index_type size_bytes() const __span_noexcept { // NOLINT
525  return size() * sizeof(T);
526  }
527 
528  XGBOOST_DEVICE constexpr bool empty() const __span_noexcept { // NOLINT
529  return size() == 0;
530  }
531 
532  // Subviews
  // first<Count>() / first(_count): the leading Count / _count elements.
533  template <std::size_t Count>
535  SPAN_CHECK(Count <= size());
536  return {data(), Count};
537  }
538 
540  std::size_t _count) const {
541  SPAN_CHECK(_count <= size());
542  return {data(), _count};
543  }
544 
  // last<Count>() / last(_count): the trailing Count / _count elements.
545  template <std::size_t Count>
547  SPAN_CHECK(Count <= size());
548  return {data() + size() - Count, Count};
549  }
550 
  // NOTE(review): delegates to subspan(size() - _count, _count). For
  // _count == 0 on a non-empty span the offset equals size(), which
  // subspan's `_offset < size()` check rejects, so last(0) fails instead of
  // returning an empty span.
552  std::size_t _count) const {
553  SPAN_CHECK(_count <= size());
554  return subspan(size() - _count, _count);
555  }
556 
  // subspan<Offset, Count>(): static-extent sub-view. Count == dynamic_extent
  // means "to the end".
  // NOTE(review): same precondition issue as the runtime subspan() below.
561  template <std::size_t Offset,
562  std::size_t Count = dynamic_extent>
563  XGBOOST_DEVICE auto subspan() const -> // NOLINT
565  detail::ExtentValue<Extent, Offset, Count>::value> {
566  SPAN_CHECK(Offset < size() || size() == 0);
567  SPAN_CHECK(Count == dynamic_extent || (Offset + Count <= size()));
568 
569  return {data() + Offset, Count == dynamic_extent ? size() - Offset : Count};
570  }
571 
  // subspan(_offset, _count): _count elements starting at _offset, or
  // everything from _offset when _count == dynamic_extent.
  // NOTE(review): the first check accepts any _offset when size() == 0, after
  // which size() - _offset wraps around. It also rejects _offset == size() on
  // a non-empty span. std::span requires _offset <= size() (plus
  // _offset + _count <= size() for an explicit count), e.g.
  // SPAN_CHECK(_count == dynamic_extent ? _offset <= size()
  //                                     : _offset + _count <= size());
573  index_type _offset,
574  index_type _count = dynamic_extent) const {
575  SPAN_CHECK(_offset < size() || size() == 0);
576  SPAN_CHECK((_count == dynamic_extent) || (_offset + _count <= size()));
577 
578  return {data() + _offset, _count ==
579  dynamic_extent ? size() - _offset : _count};
580  }
581 
582  private:
  // Number of viewed elements and pointer to the first one (not owned).
583  index_type size_;
584  pointer data_;
585 };
586 
// Element-wise equality: two spans are equal when they have the same size and
// no pair of corresponding elements compares unequal (tested with !=).
// NOTE(review): `auto l_beg = l.cbegin(), r_beg = r.cbegin()` requires both
// initializers to deduce the same type. SpanIterator<Span<T, X>, true> and
// SpanIterator<Span<U, Y>, true> only coincide when T == U and X == Y, so
// comparing spans with different element types or extents fails to compile.
// Declaring the two iterators separately, or looping over an index, would
// lift that restriction.
587 template <class T, std::size_t X, class U, std::size_t Y>
589  if (l.size() != r.size()) {
590  return false;
591  }
592  for (auto l_beg = l.cbegin(), r_beg = r.cbegin(); l_beg != l.cend();
593  ++l_beg, ++r_beg) {
594  if (*l_beg != *r_beg) {
595  return false;
596  }
597  }
598  return true;
599 }
600 
// Inequality: the negation of operator==.
601 template <class T, std::size_t X, class U, std::size_t Y>
603  return !(l == r);
604 }
605 
606 template <class T, std::size_t X, class U, std::size_t Y>
607 XGBOOST_DEVICE constexpr bool operator<(Span<T, X> l, Span<U, Y> r) {
608  return detail::LexicographicalCompare(l.begin(), l.end(),
609  r.begin(), r.end());
610 }
611 
612 template <class T, std::size_t X, class U, std::size_t Y>
613 XGBOOST_DEVICE constexpr bool operator<=(Span<T, X> l, Span<U, Y> r) {
614  return !(l > r);
615 }
616 
// operator>: lexicographic greater-than. The visible lines pass both spans'
// iterator types and ranges to a lexicographic comparison. The line naming
// the callee and its comparator is not in this listing (presumably
// detail::LexicographicalCompare with detail::Greater -- TODO confirm).
617 template <class T, std::size_t X, class U, std::size_t Y>
620  typename Span<T, X>::iterator, typename Span<U, Y>::iterator,
622  r.begin(), r.end());
623 }
624 
// operator>=: the negation of operator<.
625 template <class T, std::size_t X, class U, std::size_t Y>
627  return !(l < r);
628 }
629 
// as_bytes: read-only view of the storage behind `s` as s.size_bytes()
// elements of type `const byte`. No data is copied.
630 template <class T, std::size_t E>
633  return {reinterpret_cast<const byte*>(s.data()), s.size_bytes()};
634 }
635 
// as_writable_bytes: mutable byte view of the same storage (byte* rather than
// const byte*).
636 template <class T, std::size_t E>
639  return {reinterpret_cast<byte*>(s.data()), s.size_bytes()};
640 }
641 
642 } // namespace common
643 } // namespace xgboost
644 
// Undo the MSVC < 1910 shims installed at the top of this header. The empty
// `constexpr` macro is removed, the previously saved `constexpr` macro state
// is restored, and the empty __span_noexcept is dropped.
// NOTE(review): on every other compiler __span_noexcept remains defined after
// this header; confirm nothing relies on that before making it symmetric.
#if defined(_MSC_VER)
#if _MSC_VER < 1910
#undef constexpr
#pragma pop_macro("constexpr")
#undef __span_noexcept
#endif  // _MSC_VER < 1910
#endif  // defined(_MSC_VER)
650 
651 #endif // XGBOOST_SPAN_H_
byte
Definition: span.h:112
XGBOOST_DEVICE constexpr bool operator>=(Span< T, X > l, Span< U, Y > r)
Definition: span.h:626
XGBOOST_DEVICE constexpr friend bool operator>=(SpanIterator _lhs, SpanIterator _rhs) __span_noexcept
Definition: span.h:233
XGBOOST_DEVICE Span< element_type, Count > first() const
Definition: span.h:534
XGBOOST_DEVICE constexpr index_type size() const __span_noexcept
Definition: span.h:521
XGBOOST_DEVICE Span(pointer _first, pointer _last)
Definition: span.h:414
XGBOOST_DEVICE constexpr bool empty() const __span_noexcept
Definition: span.h:528
XGBOOST_DEVICE reference operator()(index_type _idx) const
Definition: span.h:512
XGBOOST_DEVICE SpanIterator & operator--()
Definition: span.h:170
XGBOOST_DEVICE SpanIterator operator--(int)
Definition: span.h:176
XGBOOST_DEVICE constexpr SpanIterator()
Definition: span.h:132
std::size_t index_type
Definition: span.h:394
XGBOOST_DEVICE Span< element_type, dynamic_extent > subspan(index_type _offset, index_type _count=dynamic_extent) const
Definition: span.h:572
XGBOOST_DEVICE constexpr SpanIterator(const SpanIterator< SpanType, B > &other_) __span_noexcept
Definition: span.h:141
XGBOOST_DEVICE constexpr bool operator>(Span< T, X > l, Span< U, Y > r)
Definition: span.h:618
XGBOOST_DEVICE constexpr iterator end() const __span_noexcept
Definition: span.h:469
XGBOOST_DEVICE constexpr Span() __span_noexcept
Definition: span.h:406
XGBOOST_DEVICE SpanIterator operator-(difference_type n) const
Definition: span.h:198
XGBOOST_DEVICE SpanIterator & operator+=(difference_type n)
Definition: span.h:187
Span(const Container &_cont)
Definition: span.h:442
XGBOOST_DEVICE constexpr friend bool operator<=(SpanIterator _lhs, SpanIterator _rhs) __span_noexcept
Definition: span.h:223
Definition: span.h:288
XGBOOST_DEVICE SpanIterator operator+(difference_type n) const
Definition: span.h:182
value_type element_type
Definition: span.h:392
Definition: span.h:295
XGBOOST_DEVICE auto as_bytes(Span< T, E > s) __span_noexcept -> Span< const byte, detail::ExtentAsBytesValue< T, E >::value >
Definition: span.h:631
Span(Container &_cont)
Definition: span.h:430
XGBOOST_DEVICE reference operator[](difference_type n) const
Definition: span.h:149
detail::ptrdiff_t difference_type
Definition: span.h:395
typename SpanType::value_type value_type
Definition: span.h:125
XGBOOST_DEVICE reference back() const
Definition: span.h:503
XGBOOST_DEVICE constexpr Span(const Span &_other) __span_noexcept
Definition: span.h:454
value_type & reference
Definition: span.h:397
XGBOOST_DEVICE auto as_writable_bytes(Span< T, E > s) __span_noexcept -> Span< byte, detail::ExtentAsBytesValue< T, E >::value >
Definition: span.h:637
XGBOOST_DEVICE constexpr friend bool operator!=(SpanIterator _lhs, SpanIterator _rhs) __span_noexcept
Definition: span.h:213
XGBOOST_DEVICE constexpr bool operator()(const T &_x, const T &_y) const
Definition: span.h:296
Definition: span.h:284
std::random_access_iterator_tag iterator_category
Definition: span.h:124
XGBOOST_DEVICE Span< element_type, dynamic_extent > first(std::size_t _count) const
Definition: span.h:539
span class implementation, based on ISO C++20 std::span<T>. The interface should be the same.
Definition: span.h:115
XGBOOST_DEVICE SpanIterator & operator++()
Definition: span.h:158
typename std::conditional< IsConst, const ElementType, ElementType >::type & reference
Definition: span.h:129
XGBOOST_DEVICE SpanIterator & operator-=(difference_type n)
Definition: span.h:203
XGBOOST_DEVICE bool LexicographicalCompare(InputIt1 first1, InputIt1 last1, InputIt2 first2, InputIt2 last2)
Definition: span.h:304
XGBOOST_DEVICE constexpr friend bool operator>(SpanIterator _lhs, SpanIterator _rhs) __span_noexcept
Definition: span.h:228
typename std::remove_cv< value_type >::type value_type
Definition: span.h:393
XGBOOST_DEVICE Span< element_type, Count > last() const
Definition: span.h:546
XGBOOST_DEVICE constexpr pointer data() const __span_noexcept
Definition: span.h:516
XGBOOST_DEVICE constexpr const_reverse_iterator crend() const __span_noexcept
Definition: span.h:493
XGBOOST_DEVICE Span(pointer _ptr, index_type _count)
Definition: span.h:408
XGBOOST_DEVICE constexpr SpanIterator(const SpanType *_span, typename SpanType::index_type _idx) __span_noexcept
Definition: span.h:134
constexpr std::size_t dynamic_extent
Definition: span.h:109
XGBOOST_DEVICE constexpr reverse_iterator rbegin() const __span_noexcept
Definition: span.h:481
XGBOOST_DEVICE constexpr bool operator()(const T &_x, const T &_y) const
Definition: span.h:289
XGBOOST_DEVICE constexpr bool operator!=(Span< T, X > l, Span< U, Y > r)
Definition: span.h:602
#define XGBOOST_DEVICE
Tag function as usable by device.
Definition: base.h:84
SpanType::index_type index_
Definition: span.h:240
namespace of xgboost
Definition: base.h:102
XGBOOST_DEVICE constexpr reverse_iterator rend() const __span_noexcept
Definition: span.h:485
XGBOOST_DEVICE Span< element_type, dynamic_extent > last(std::size_t _count) const
Definition: span.h:551
typename std::conditional< std::is_same< std::ptrdiff_t, std::int64_t >::value, std::ptrdiff_t, std::int64_t >::type ptrdiff_t
Definition: span.h:102
XGBOOST_DEVICE reference front() const
Definition: span.h:499
XGBOOST_DEVICE constexpr const_reverse_iterator crbegin() const __span_noexcept
Definition: span.h:489
const SpanType * span_
Definition: span.h:239
XGBOOST_DEVICE auto subspan() const -> Span< element_type, detail::ExtentValue< Extent, Offset, Count >::value >
Definition: span.h:563
XGBOOST_DEVICE constexpr Span(element_type(&arr)[N]) __span_noexcept
Definition: span.h:420
XGBOOST_DEVICE constexpr const_iterator cbegin() const __span_noexcept
Definition: span.h:473
XGBOOST_DEVICE SpanIterator operator++(int)
Definition: span.h:164
value_type * pointer
Definition: span.h:396
#define SPAN_CHECK
Definition: span.h:91
XGBOOST_DEVICE reference operator[](index_type _idx) const
Definition: span.h:507
XGBOOST_DEVICE constexpr friend bool operator<(SpanIterator _lhs, SpanIterator _rhs) __span_noexcept
Definition: span.h:218
XGBOOST_DEVICE constexpr index_type size_bytes() const __span_noexcept
Definition: span.h:524
XGBOOST_DEVICE constexpr const_iterator cend() const __span_noexcept
Definition: span.h:477
#define __span_noexcept
span class based on ISO C++20 std::span
Definition: span.h:64
XGBOOST_DEVICE Span & operator=(const Span &_other) __span_noexcept
Definition: span.h:457
XGBOOST_DEVICE constexpr friend bool operator==(SpanIterator _lhs, SpanIterator _rhs) __span_noexcept
Definition: span.h:208
typename std::add_pointer< reference >::type pointer
Definition: span.h:130
XGBOOST_DEVICE constexpr Span(const Span< U, OtherExtent > &_other) __span_noexcept
Definition: span.h:451
detail::ptrdiff_t difference_type
Definition: span.h:126
XGBOOST_DEVICE constexpr iterator begin() const __span_noexcept
Definition: span.h:465
XGBOOST_DEVICE difference_type operator-(SpanIterator rhs) const
Definition: span.h:193
XGBOOST_DEVICE ~Span() __span_noexcept
Definition: span.h:463
XGBOOST_DEVICE reference operator*() const
Definition: span.h:145
XGBOOST_DEVICE pointer operator->() const
Definition: span.h:153
XGBOOST_DEVICE bool operator==(Span< T, X > l, Span< U, Y > r)
Definition: span.h:588