4 #ifndef XGBOOST_COMMON_TRANSFORM_H_
5 #define XGBOOST_COMMON_TRANSFORM_H_
8 #include <dmlc/common.h>
13 #include <type_traits>
21 #if defined (__CUDACC__)
22 #include "device_helpers.cuh"
23 #endif // defined (__CUDACC__)
32 #if defined(__CUDACC__)
// This region is compiled only when __CUDACC__ is defined.
//
// LaunchCUDAKernel is a `__global__` function template parameterised on the
// functor type and a pack of span types. It receives the functor and the
// index range by value.
// NOTE(review): the rest of the parameter list (presumably the SpanType...
// pack) and the loop body are not visible in this excerpt -- confirm against
// the full file before relying on this description.
33 template <
typename Functor,
typename... SpanType>
34 __global__
void LaunchCUDAKernel(Functor _func,
Range _range,
// The loop variable i is drawn from dh::GridStrideRange, constructed from the
// dereferenced begin() and end() of _range.
// NOTE(review): the iteration pattern itself is defined in device_helpers.cuh,
// not here; verify there how indices are distributed across threads.
36 for (
auto i : dh::GridStrideRange(*_range.
begin(), *_range.
end())) {
40 #endif // defined(__CUDACC__)
// Non-type template parameter CompiledWithCuda defaults to the value of
// WITH_CUDA(). Further down in this file it is used with std::enable_if to
// select between the two LaunchCUDA overloads.
// NOTE(review): the declaration this template head belongs to is not visible
// in this excerpt.
58 template <
bool CompiledWithCuda = WITH_CUDA()>
// Template head parameterised on the functor type; presumably introduces the
// Evaluator type whose constructor follows -- confirm against the full file.
61 template <
typename Functor>
// Constructor taking the functor, the index range, an int `device` and a
// bool `shard`.
64 Evaluator(Functor func,
Range range,
int device,
bool shard) :
// func is copied into func_ and range is moved into range_.
// NOTE(review): the initialisers for the remaining members (the trailing
// comma shows there are more) are not visible in this excerpt.
65 func_(func), range_{std::move(range)},
// Eval takes a pack of arguments by value and passes them, together with the
// stored functor func_, to LaunchCUDA or LaunchCPU. It is a const member.
// NOTE(review): the control flow that chooses between the two calls below is
// not visible in this excerpt.
75 template <
typename... HDV>
76 void Eval(HDV... vectors)
const {
// on_device is true exactly when device_ is non-negative.
77 bool on_device = device_ >= 0;
// Presumably executed when on_device is true -- confirm against the full file.
80 LaunchCUDA(func_, vectors...);
// Presumably executed when on_device is false -- confirm against the full file.
82 LaunchCPU(func_, vectors...);
104 template <
typename T>
110 template <
typename T>
// Variadic template head: one leading type (Head) followed by a pack (Rest).
// NOTE(review): the function signature and the handling of the leading
// argument are missing from this excerpt. The only visible statement forwards
// the pack `_vectors...` to SyncHost, which looks like the recursive step of a
// helper that processes one argument at a time -- confirm.
114 template <
typename Head,
typename... Rest>
118 SyncHost(_vectors...);
121 template <
typename T>
// Variadic overload of UnpackShard: takes an int `device` followed by further
// parameters that are not visible in this excerpt.
125 template <
typename Head,
typename... Rest>
126 void UnpackShard(
int device,
// Calls UnpackShard again with the same `device` and the pack `_vectors...`.
// NOTE(review): what is done with the leading (Head) argument before this
// call is not visible here -- confirm against the full file.
130 UnpackShard(device, _vectors...);
133 #if defined(__CUDACC__)
// LaunchCUDA overload that participates in overload resolution only when
// CompiledWithCuda is true (std::enable_if on a defaulted pointer template
// parameter). Takes the functor by value and a pack of pointers; const member.
// NOTE(review): the rest of the template parameter list (presumably the HDV
// pack) is not visible in this excerpt.
134 template <typename std::enable_if<CompiledWithCuda>::type* =
nullptr,
136 void LaunchCUDA(Functor _func, HDV*... _vectors)
const {
// NOTE(review): any condition guarding this call is not visible here.
138 UnpackShard(device_, _vectors...);
// Number of indices covered by range_: dereferenced end minus dereferenced begin.
141 size_t range_size = *range_.end() - *range_.begin();
// shard_size is initialised to the full range size.
146 size_t shard_size = range_size;
// Select device_ as the current CUDA device; the returned status is handed to
// dh::safe_cuda (presumably an error check -- see device_helpers.cuh).
148 dh::safe_cuda(cudaSetDevice(device_));
// Launch detail::LaunchCUDAKernel with kGrids blocks of kBlockThreads threads,
// passing the functor, shard_range, and UnpackHDVOnDevice applied to each
// vector pointer.
// NOTE(review): the definitions of shard_range, kGrids and kBlockThreads are
// not visible in this excerpt. No check of the launch status (e.g.
// cudaGetLastError) is visible after the launch either -- confirm whether one
// exists in the full file.
154 detail::LaunchCUDAKernel<<<kGrids, kBlockThreads>>>(
155 _func, shard_range, UnpackHDVOnDevice(_vectors)...);
// LaunchCUDA overload that participates in overload resolution only when
// CompiledWithCuda is false. The pointer pack is unnamed, so it is not used.
// NOTE(review): in this excerpt both LaunchCUDA overloads sit inside the same
// `#if defined(__CUDACC__)` region; presumably an `#else` separating them is
// missing from the excerpt -- confirm against the full file.
159 template <typename std::enable_if<!CompiledWithCuda>::type* =
nullptr,
161 void LaunchCUDA(Functor _func, HDV*...)
const {
// Emits a FATAL-level log message that includes the value of WITH_CUDA().
165 LOG(FATAL) <<
"Not part of device code. WITH_CUDA: " <<
WITH_CUDA();
167 #endif // defined(__CUDACC__)
// LaunchCPU takes the functor by value and a pack of pointers; const member.
169 template <
typename... HDV>
170 void LaunchCPU(Functor func, HDV*... vectors)
const {
// SyncHost is applied to all vectors before the functor is invoked.
172 SyncHost(vectors...);
// Invoke the functor with index idx and UnpackHDV applied to each vector.
// NOTE(review): the loop that defines idx is not visible in this excerpt.
174 func(idx, UnpackHDV(vectors)...);
// Static factory: returns an Evaluator<Functor> brace-initialised from func,
// range, device and shard. `shard` defaults to true.
// NOTE(review): the declaration of the `device` parameter is not visible in
// this excerpt.
201 template <
typename Functor>
202 static Evaluator<Functor>
Init(Functor func,
Range const range,
204 bool const shard =
true) {
// NOTE(review): `range` is declared const, so std::move(range) produces a
// const rvalue; a move constructor taking Range&& cannot bind to it, so this
// copies rather than moves.
205 return Evaluator<Functor> {func, std::move(range), device, shard};
212 #endif // XGBOOST_COMMON_TRANSFORM_H_