transform.h

#ifndef XGBOOST_COMMON_TRANSFORM_H_
#define XGBOOST_COMMON_TRANSFORM_H_

#include <dmlc/omp.h>
#include <xgboost/data.h>
#include <utility>
#include <vector>
#include <type_traits>  // enable_if

#include "host_device_vector.h"
#include "common.h"
#include "span.h"

#if defined(__CUDACC__)
#include "device_helpers.cuh"
#endif  // defined(__CUDACC__)

namespace xgboost {
namespace common {

constexpr size_t kBlockThreads = 256;

namespace detail {

#if defined(__CUDACC__)
// Grid-stride kernel: applies _func to every index in _range, forwarding the
// unpacked device spans.
template <typename Functor, typename... SpanType>
__global__ void LaunchCUDAKernel(Functor _func, Range _range,
                                 SpanType... _spans) {
  for (auto i : dh::GridStrideRange(*_range.begin(), *_range.end())) {
    _func(i, _spans...);
  }
}
#endif  // defined(__CUDACC__)

}  // namespace detail

/*! \brief Do Transformation on HostDeviceVectors. */
template <bool CompiledWithCuda = WITH_CUDA()>
class Transform {
 private:
  template <typename Functor>
  struct Evaluator {
   public:
    Evaluator(Functor func, Range range, GPUSet devices, bool shard) :
        func_(func), range_{std::move(range)},
        shard_{shard},
        distribution_{std::move(GPUDistribution::Block(devices))} {}
    Evaluator(Functor func, Range range, GPUDistribution dist,
              bool shard) :
        func_(func), range_{std::move(range)}, shard_{shard},
        distribution_{std::move(dist)} {}

    /*!
     * \brief Evaluate the functor with the input HostDeviceVectors.
     *
     * Dispatches to the CUDA path when the distribution covers at least one
     * device, and to the OpenMP CPU path otherwise.
     */
    template <typename... HDV>
    void Eval(HDV... vectors) const {
      bool on_device = !distribution_.IsEmpty();

      if (on_device) {
        LaunchCUDA(func_, vectors...);
      } else {
        LaunchCPU(func_, vectors...);
      }
    }

   private:
    // CUDA UnpackHDV
    template <typename T>
    Span<T> UnpackHDV(HostDeviceVector<T>* _vec, int _device) const {
      auto span = _vec->DeviceSpan(_device);
      return span;
    }
    template <typename T>
    Span<T const> UnpackHDV(const HostDeviceVector<T>* _vec, int _device) const {
      auto span = _vec->ConstDeviceSpan(_device);
      return span;
    }
    // CPU UnpackHDV
    template <typename T>
    Span<T> UnpackHDV(HostDeviceVector<T>* _vec) const {
      return Span<T> {_vec->HostPointer(),
            static_cast<typename Span<T>::index_type>(_vec->Size())};
    }
    template <typename T>
    Span<T const> UnpackHDV(const HostDeviceVector<T>* _vec) const {
      return Span<T const> {_vec->ConstHostPointer(),
            static_cast<typename Span<T>::index_type>(_vec->Size())};
    }
    // Recursive unpack for Shard.
    template <typename T>
    void UnpackShard(GPUDistribution dist, const HostDeviceVector<T>* vector) const {
      vector->Shard(dist);
    }
    template <typename Head, typename... Rest>
    void UnpackShard(GPUDistribution dist,
                     const HostDeviceVector<Head>* _vector,
                     const HostDeviceVector<Rest>*... _vectors) const {
      _vector->Shard(dist);
      UnpackShard(dist, _vectors...);
    }

#if defined(__CUDACC__)
    template <typename std::enable_if<CompiledWithCuda>::type* = nullptr,
              typename... HDV>
    void LaunchCUDA(Functor _func, HDV*... _vectors) const {
      if (shard_)
        UnpackShard(distribution_, _vectors...);

      GPUSet devices = distribution_.Devices();
      size_t range_size = *range_.end() - *range_.begin();

      // Extract indices up front to cope with possible old OpenMP versions.
      size_t device_beg = *(devices.begin());
      size_t device_end = *(devices.end());
#pragma omp parallel for schedule(static, 1) if (devices.Size() > 1)
      for (omp_ulong device = device_beg; device < device_end; ++device) {  // NOLINT
        // Ignore other attributes of GPUDistribution when splitting the index.
        // This deals with situations like the multi-class setting, where
        // granularity is used in the data vector.
        size_t shard_size = GPUDistribution::Block(devices).ShardSize(
            range_size, devices.Index(device));
        Range shard_range {0, static_cast<Range::DifferenceType>(shard_size)};
        dh::safe_cuda(cudaSetDevice(device));
        const int GRID_SIZE =
            static_cast<int>(dh::DivRoundUp(*(range_.end()), kBlockThreads));
        detail::LaunchCUDAKernel<<<GRID_SIZE, kBlockThreads>>>(
            _func, shard_range, UnpackHDV(_vectors, device)...);
      }
    }
#else

    template <typename std::enable_if<!CompiledWithCuda>::type* = nullptr,
              typename... HDV>
    void LaunchCUDA(Functor _func, HDV*... _vectors) const {
      LOG(FATAL) << "Not part of device code. WITH_CUDA: " << WITH_CUDA();
    }
#endif  // defined(__CUDACC__)

    template <typename... HDV>
    void LaunchCPU(Functor func, HDV*... vectors) const {
      omp_ulong end = static_cast<omp_ulong>(*(range_.end()));
#pragma omp parallel for schedule(static)
      for (omp_ulong idx = 0; idx < end; ++idx) {
        func(idx, UnpackHDV(vectors)...);
      }
    }

   private:
    /*! \brief Functor applied to every index in the range. */
    Functor func_;
    /*! \brief Index range the functor iterates over. */
    Range range_;
    /*! \brief Whether the input vectors should be sharded across devices. */
    bool shard_;
    /*! \brief How data is distributed across devices. */
    GPUDistribution distribution_;
  };

 public:
  /*!
   * \brief Initialize a Transform object.
   *
   * \param func    Functor applied to every index in the range; with CUDA
   *                enabled it must be callable on both host and device.
   * \param range   Index range the functor iterates over.
   * \param devices Devices the computation is distributed across (block
   *                distribution); an empty set runs the functor on the CPU.
   * \param shard   Whether the input vectors should be sharded first.
   */
  template <typename Functor>
  static Evaluator<Functor> Init(Functor func, Range const range,
                                 GPUSet const devices,
                                 bool const shard = true) {
    return Evaluator<Functor> {func, std::move(range), std::move(devices), shard};
  }
  /*! \brief Overload taking an explicit GPUDistribution instead of a GPUSet. */
  template <typename Functor>
  static Evaluator<Functor> Init(Functor func, Range const range,
                                 GPUDistribution const dist,
                                 bool const shard = true) {
    return Evaluator<Functor> {func, std::move(range), std::move(dist), shard};
  }
};
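
// A minimal usage sketch (illustrative only, not part of this header).  The
// functor handed to Init() is typically a lambda marked XGBOOST_DEVICE so it
// can run on either the host or the device; the names `preds`, `labels` and
// `n` below are hypothetical placeholders.
//
//   HostDeviceVector<bst_float> preds, labels;  // both of size n, filled elsewhere
//   Transform<>::Init(
//       [=] XGBOOST_DEVICE(size_t _idx,
//                          Span<bst_float> _preds,
//                          Span<bst_float> _labels) {
//         _preds[_idx] -= _labels[_idx];  // e.g. turn predictions into residuals
//       },
//       Range{0, static_cast<Range::DifferenceType>(n)},
//       GPUSet::Empty() /* empty device set: run on the CPU via OpenMP */)
//       .Eval(&preds, &labels);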

}  // namespace common
}  // namespace xgboost

#endif  // XGBOOST_COMMON_TRANSFORM_H_