4 #ifndef XGBOOST_COMMON_TRANSFORM_H_
5 #define XGBOOST_COMMON_TRANSFORM_H_
7 #include <dmlc/common.h>
11 #include <type_traits>
20 #if defined (__CUDACC__)
21 #include "device_helpers.cuh"
22 #endif // defined (__CUDACC__)
31 #if defined(__CUDACC__)
32 template <
typename Functor,
typename... SpanType>
33 __global__
void LaunchCUDAKernel(Functor _func,
Range _range,
35 for (
auto i : dh::GridStrideRange(*_range.
begin(), *_range.
end())) {
39 #endif // defined(__CUDACC__)
57 template <
bool CompiledWithCuda = WITH_CUDA()>
60 template <
typename Functor>
// Captures everything needed to run `func` later via Eval():
//   func       - the functor to apply; copied into func_.
//   range      - the index range to iterate over; moved into range_.
//   n_threads  - thread count stored in n_threads_ (presumably used by the
//                CPU path in LaunchCPU — TODO confirm, its body is not shown).
//   device_idx - device ordinal stored in device_. Eval() computes
//                `on_device = device_ >= 0`, and LaunchCUDA passes it to
//                cudaSetDevice; a negative value presumably selects the CPU
//                path (LaunchCPU) — verify against Eval's full branch.
63 Evaluator(Functor func,
Range range, int32_t n_threads, int32_t device_idx)
64 : func_(func), range_{std::move(range)}, n_threads_{n_threads}, device_{device_idx} {}
72 template <
typename... HDV>
73 void Eval(HDV... vectors)
const {
74 bool on_device = device_ >= 0;
77 LaunchCUDA(func_, vectors...);
79 LaunchCPU(func_, vectors...);
101 template <
typename T>
107 template <
typename T>
111 template <
typename Head,
typename... Rest>
115 SyncHost(_vectors...);
118 template <
typename T>
122 template <
typename Head,
typename... Rest>
123 void UnpackShard(
int device,
127 UnpackShard(device, _vectors...);
130 #if defined(__CUDACC__)
131 template <typename std::enable_if<CompiledWithCuda>::type* =
nullptr,
133 void LaunchCUDA(Functor _func, HDV*... _vectors)
const {
134 UnpackShard(device_, _vectors...);
136 size_t range_size = *range_.end() - *range_.begin();
141 size_t shard_size = range_size;
143 dh::safe_cuda(cudaSetDevice(device_));
149 detail::LaunchCUDAKernel<<<kGrids, kBlockThreads>>>(
150 _func, shard_range, UnpackHDVOnDevice(_vectors)...);
154 template <typename std::enable_if<!CompiledWithCuda>::type* =
nullptr,
156 void LaunchCUDA(Functor _func, HDV*...)
const {
160 LOG(FATAL) <<
"Not part of device code. WITH_CUDA: " <<
WITH_CUDA();
162 #endif // defined(__CUDACC__)
164 template <
typename... HDV>
165 void LaunchCPU(Functor func, HDV *...vectors)
const {
167 SyncHost(vectors...);
193 template <
typename Functor>
194 static Evaluator<Functor>
Init(Functor func,
Range const range, int32_t n_threads,
195 int32_t device_idx) {
196 return Evaluator<Functor>{func, std::move(range), n_threads, device_idx};
203 #endif // XGBOOST_COMMON_TRANSFORM_H_