xgboost
poll_utils.h
Go to the documentation of this file.
1 
6 #pragma once
9 
10 #if defined(_WIN32)
11 #include <xgboost/windefs.h>
12 // Socket API
13 #include <winsock2.h>
14 #include <ws2tcpip.h>
15 #else
16 
17 #include <arpa/inet.h>
18 #include <fcntl.h>
19 #include <netdb.h>
20 #include <netinet/in.h>
21 #include <sys/ioctl.h>
22 #include <sys/socket.h>
23 #include <unistd.h>
24 
25 #include <cerrno>
26 
27 #endif // defined(_WIN32)
28 
#include <cerrno>         // errno
#include <chrono>
#include <cstdint>        // std::int32_t
#include <cstring>
#include <limits>         // std::numeric_limits
#include <string>
#include <system_error>   // make_error_code, errc
#include <type_traits>    // enable_if_t, is_integral_v
#include <unordered_map>
#include <vector>
35 
#ifndef _WIN32
// POSIX: poll(2) and struct pollfd are declared in <poll.h>.
#include <poll.h>

// On POSIX a socket handle is a plain file descriptor. `SOCKET` reuses the name of the
// Winsock handle type (from <winsock2.h> on Windows) so the rest of the code can spell the
// handle type the same way on every platform.
using sock_size_t = size_t;  // NOLINT
using SOCKET = int;
#endif  // _WIN32
43 
// Whether we are building with MinGW, exposed as a plain 0/1 so it can be used both in `#if`
// and in ordinary expressions. Defining it as `defined(__MINGW32__)` would make `#if IS_MINGW()`
// depend on the `defined` operator being produced by macro expansion, which the standard leaves
// undefined ([cpp.cond]) and GCC/Clang warn about (-Wexpansion-to-defined).
#if !defined(IS_MINGW)
#if defined(__MINGW32__)
#define IS_MINGW() 1
#else
#define IS_MINGW() 0
#endif  // defined(__MINGW32__)
#endif  // !defined(IS_MINGW)
45 
46 #if IS_MINGW() && !defined(POLLRDNORM) && !defined(POLLRDBAND)
47 /*
48  * On later mingw versions poll should be supported (with bugs). See:
49  * https://stackoverflow.com/a/60623080
50  *
51  * But right now the mingw distributed with R 3.6 doesn't support it.
52  * So we just give a warning and provide dummy implementation to get
53  * compilation passed. Otherwise we will have to provide a stub for
54  * RABIT.
55  *
56  * Even on mingw version that has these structures and flags defined,
57  * functions like `send` and `listen` might have unresolved linkage to
58  * their implementation. So supporting mingw is quite difficult at
59  * the time of writing.
60  */
61 #pragma message("Distributed training on mingw is not supported.")
62 typedef struct pollfd {
63  SOCKET fd;
64  short events; // NOLINT
65  short revents; // NOLINT
66 } WSAPOLLFD, *PWSAPOLLFD, *LPWSAPOLLFD;
67 
68 // POLLRDNORM | POLLRDBAND
69 #define POLLIN (0x0100 | 0x0200)
70 #define POLLPRI 0x0400
71 // POLLWRNORM
72 #define POLLOUT 0x0010
73 
74 #endif // IS_MINGW() && !defined(POLLRDNORM) && !defined(POLLRDBAND)
75 
76 namespace rabit {
77 namespace utils {
78 
/**
 * @brief Convert a poll timeout into the millisecond `int` expected by poll()/WSAPoll().
 *
 * A negative timeout maps to -1, meaning "wait indefinitely". A timeout whose millisecond
 * count does not fit in an `int` is clamped to INT_MAX. Truncating it instead could turn a
 * long timeout into a negative (infinite) or much shorter one.
 *
 * @param timeout Requested timeout.
 * @return Timeout in milliseconds, in [-1, INT_MAX].
 */
inline int PollTimeoutMs(std::chrono::seconds timeout) noexcept(true) {
  if (timeout.count() < 0) {
    return -1;
  }
  constexpr int kMaxMs = std::numeric_limits<int>::max();
  // Compare in seconds: converting a huge `seconds` value to milliseconds could itself
  // overflow the 64-bit representation before we get a chance to clamp it.
  constexpr std::chrono::seconds kMaxSeconds{kMaxMs / 1000};
  if (timeout > kMaxSeconds) {
    return kMaxMs;
  }
  return static_cast<int>(std::chrono::milliseconds{timeout}.count());
}

/**
 * @brief Portability wrapper around poll() (POSIX) and WSAPoll() (Windows).
 *
 * @param pfd     Array of `nfds` poll descriptors; the system call fills in `revents`.
 * @param nfds    Number of entries in `pfd`.
 * @param timeout How long to wait; a negative value waits indefinitely.
 * @return The system call's result: the number of ready descriptors, 0 on timeout, or a
 *         negative value on failure. On MinGW it returns -1 after calling
 *         xgboost::MingWError().
 */
template <typename PollFD>
int PollImpl(PollFD* pfd, int nfds, std::chrono::seconds timeout) noexcept(true) {
  // For Windows and Linux, a negative timeout means an infinite timeout. For FreeBSD,
  // INFTIM (-1) should be used instead, which is why PollTimeoutMs() normalizes to -1.
#if defined(_WIN32)

#if IS_MINGW()
  xgboost::MingWError();
  return -1;
#else
  return WSAPoll(pfd, nfds, PollTimeoutMs(timeout));
#endif  // IS_MINGW()

#else
  return poll(pfd, nfds, PollTimeoutMs(timeout));
#endif  // defined(_WIN32)
}
96 
97 template <typename E>
98 std::enable_if_t<std::is_integral_v<E>, xgboost::collective::Result> PollError(E const& revents) {
99  if ((revents & POLLERR) != 0) {
100  auto err = errno;
101  auto str = strerror(err);
102  return xgboost::system::FailWithCode(std::string{"Poll error condition:"} + // NOLINT
103  std::string{str} + // NOLINT
104  " code:" + std::to_string(err));
105  }
106  if ((revents & POLLNVAL) != 0) {
107  return xgboost::system::FailWithCode("Invalid polling request.");
108  }
109  if ((revents & POLLHUP) != 0) {
110  // Excerpt from the Linux manual:
111  //
112  // Note that when reading from a channel such as a pipe or a stream socket, this event
113  // merely indicates that the peer closed its end of the channel.Subsequent reads from
114  // the channel will return 0 (end of file) only after all outstanding data in the
115  // channel has been consumed.
116  //
117  // We don't usually have a barrier for exiting workers, it's normal to have one end
118  // exit while the other still reading data.
120  }
121 #if defined(POLLRDHUP)
122  // Linux only flag
123  if ((revents & POLLRDHUP) != 0) {
124  return xgboost::system::FailWithCode("Poll hung up on the other end.");
125  }
126 #endif // defined(POLLRDHUP)
128 }
129 
131 struct PollHelper {
132  public:
137  inline void WatchRead(SOCKET fd) {
138  auto& pfd = fds[fd];
139  pfd.fd = fd;
140  pfd.events |= POLLIN;
141  }
142  void WatchRead(xgboost::collective::TCPSocket const &socket) { this->WatchRead(socket.Handle()); }
143 
148  inline void WatchWrite(SOCKET fd) {
149  auto& pfd = fds[fd];
150  pfd.fd = fd;
151  pfd.events |= POLLOUT;
152  }
154  this->WatchWrite(socket.Handle());
155  }
156 
161  inline void WatchException(SOCKET fd) {
162  auto& pfd = fds[fd];
163  pfd.fd = fd;
164  pfd.events |= POLLPRI;
165  }
167  this->WatchException(socket.Handle());
168  }
173  [[nodiscard]] bool CheckRead(SOCKET fd) const {
174  const auto& pfd = fds.find(fd);
175  return pfd != fds.end() && ((pfd->second.events & POLLIN) != 0);
176  }
177  [[nodiscard]] bool CheckRead(xgboost::collective::TCPSocket const& socket) const {
178  return this->CheckRead(socket.Handle());
179  }
180 
185  [[nodiscard]] bool CheckWrite(SOCKET fd) const {
186  const auto& pfd = fds.find(fd);
187  return pfd != fds.end() && ((pfd->second.events & POLLOUT) != 0);
188  }
189  [[nodiscard]] bool CheckWrite(xgboost::collective::TCPSocket const& socket) const {
190  return this->CheckWrite(socket.Handle());
191  }
197  [[nodiscard]] xgboost::collective::Result Poll(std::chrono::seconds timeout,
198  bool check_error = true) {
199  std::vector<pollfd> fdset;
200  fdset.reserve(fds.size());
201  for (auto kv : fds) {
202  fdset.push_back(kv.second);
203  }
204  std::int32_t ret = PollImpl(fdset.data(), fdset.size(), timeout);
205  if (ret == 0) {
207  "Poll timeout:" + std::to_string(timeout.count()) + " seconds.",
208  std::make_error_code(std::errc::timed_out));
209  } else if (ret < 0) {
210  return xgboost::system::FailWithCode("Poll failed, nfds:" + std::to_string(fdset.size()));
211  }
212 
213  for (auto& pfd : fdset) {
214  auto result = PollError(pfd.revents);
215  if (check_error && !result.OK()) {
216  return result;
217  }
218 
219  auto revents = pfd.revents & pfd.events;
220  fds[pfd.fd].events = revents;
221  }
223  }
224 
225  std::unordered_map<SOCKET, pollfd> fds;
226 };
227 } // namespace utils
228 } // namespace rabit
229 
230 #if IS_MINGW() && !defined(POLLRDNORM) && !defined(POLLRDBAND)
231 #undef POLLIN
232 #undef POLLPRI
233 #undef POLLOUT
234 #endif // IS_MINGW()
TCP socket for simple communication.
Definition: socket.h:267
HandleT const & Handle() const
Return the native socket file descriptor.
Definition: socket.h:539
int PollImpl(PollFD *pfd, int nfds, std::chrono::seconds timeout) noexcept(true)
Definition: poll_utils.h:80
std::enable_if_t< std::is_integral_v< E >, xgboost::collective::Result > PollError(E const &revents)
Definition: poll_utils.h:98
Definition: poll_utils.h:76
auto Fail(std::string msg, char const *file=__builtin_FILE(), std::int32_t line=__builtin_LINE())
Return failure.
Definition: result.h:124
auto Success() noexcept(true)
Return success.
Definition: result.h:120
collective::Result FailWithCode(std::string msg)
Definition: socket.h:78
int SOCKET
Definition: poll_utils.h:40
size_t sock_size_t
Definition: poll_utils.h:41
helper data structure to perform poll
Definition: poll_utils.h:131
void WatchException(SOCKET fd)
add file descriptor to watch for exception
Definition: poll_utils.h:161
bool CheckWrite(xgboost::collective::TCPSocket const &socket) const
Definition: poll_utils.h:189
void WatchRead(xgboost::collective::TCPSocket const &socket)
Definition: poll_utils.h:142
xgboost::collective::Result Poll(std::chrono::seconds timeout, bool check_error=true)
Perform poll on the watched set of read, write, and exception descriptors.
Definition: poll_utils.h:197
void WatchWrite(xgboost::collective::TCPSocket const &socket)
Definition: poll_utils.h:153
bool CheckRead(SOCKET fd) const
Check if the descriptor is ready for read.
Definition: poll_utils.h:173
void WatchException(xgboost::collective::TCPSocket const &socket)
Definition: poll_utils.h:166
bool CheckWrite(SOCKET fd) const
Check if the descriptor is ready for write.
Definition: poll_utils.h:185
void WatchWrite(SOCKET fd)
add file descriptor to watch for write
Definition: poll_utils.h:148
bool CheckRead(xgboost::collective::TCPSocket const &socket) const
Definition: poll_utils.h:177
std::unordered_map< SOCKET, pollfd > fds
Definition: poll_utils.h:225
void WatchRead(SOCKET fd)
add file descriptor to watch for read
Definition: poll_utils.h:137
An error type that's easier to handle than throwing a dmlc exception. We can record and propagate the s...
Definition: result.h:67