4 #ifndef XGBOOST_COMMON_TRANSFORM_H_
5 #define XGBOOST_COMMON_TRANSFORM_H_
8 #include <dmlc/common.h>
13 #include <type_traits>
20 #if defined (__CUDACC__)
21 #include "device_helpers.cuh"
22 #endif // defined (__CUDACC__)
31 #if defined(__CUDACC__)
32 template <
typename Functor,
typename... SpanType>
33 __global__
void LaunchCUDAKernel(Functor _func,
Range _range,
35 for (
auto i : dh::GridStrideRange(*_range.
begin(), *_range.
end())) {
39 #endif // defined(__CUDACC__)
57 template <
bool CompiledWithCuda = WITH_CUDA()>
60 template <
typename Functor>
63 Evaluator(Functor func,
Range range,
int device,
bool shard) :
64 func_(func), range_{std::move(range)},
74 template <
typename... HDV>
75 void Eval(HDV... vectors)
const {
76 bool on_device = device_ >= 0;
79 LaunchCUDA(func_, vectors...);
81 LaunchCPU(func_, vectors...);
103 template <
typename T>
109 template <
typename T>
113 template <
typename Head,
typename... Rest>
117 SyncHost(_vectors...);
120 template <
typename T>
124 template <
typename Head,
typename... Rest>
125 void UnpackShard(
int device,
129 UnpackShard(device, _vectors...);
132 #if defined(__CUDACC__)
133 template <typename std::enable_if<CompiledWithCuda>::type* =
nullptr,
135 void LaunchCUDA(Functor _func, HDV*... _vectors)
const {
137 UnpackShard(device_, _vectors...);
140 size_t range_size = *range_.end() - *range_.begin();
145 size_t shard_size = range_size;
147 dh::safe_cuda(cudaSetDevice(device_));
153 detail::LaunchCUDAKernel<<<kGrids, kBlockThreads>>>(
154 _func, shard_range, UnpackHDVOnDevice(_vectors)...);
158 template <typename std::enable_if<!CompiledWithCuda>::type* =
nullptr,
160 void LaunchCUDA(Functor _func, HDV*...)
const {
164 LOG(FATAL) <<
"Not part of device code. WITH_CUDA: " <<
WITH_CUDA();
166 #endif // defined(__CUDACC__)
168 template <
typename... HDV>
169 void LaunchCPU(Functor func, HDV*... vectors)
const {
171 dmlc::OMPException omp_exc;
172 SyncHost(vectors...);
173 #pragma omp parallel for schedule(static)
174 for (
omp_ulong idx = 0; idx < end; ++idx) {
175 omp_exc.Run(func, idx, UnpackHDV(vectors)...);
203 template <
typename Functor>
204 static Evaluator<Functor>
Init(Functor func,
Range const range,
206 bool const shard =
true) {
207 return Evaluator<Functor> {func, std::move(range), device, shard};
214 #endif // XGBOOST_COMMON_TRANSFORM_H_