xgboost
socket.h
Go to the documentation of this file.
1 
4 #pragma once
5 
6 #if !defined(NOMINMAX) && defined(_WIN32)
7 #define NOMINMAX
8 #endif // !defined(NOMINMAX)
9 
10 #include <cerrno> // errno, EINTR, EBADF
11 #include <climits> // HOST_NAME_MAX
12 #include <cstddef> // std::size_t
13 #include <cstdint> // std::int32_t, std::uint16_t
14 #include <cstring> // memset
15 #include <limits> // std::numeric_limits
16 #include <string> // std::string
17 #include <system_error> // std::error_code, std::system_category
18 #include <utility> // std::swap
19 
20 #if !defined(xgboost_IS_MINGW)
21 
22 #if defined(__MINGW32__)
23 #define xgboost_IS_MINGW 1
24 #endif // defined(__MINGW32__)
25 
26 #endif // xgboost_IS_MINGW
27 
28 #if defined(_WIN32)
29 
30 #include <winsock2.h>
31 #include <ws2tcpip.h>
32 
33 using in_port_t = std::uint16_t;
34 
35 #ifdef _MSC_VER
36 #pragma comment(lib, "Ws2_32.lib")
37 #endif // _MSC_VER
38 
39 #if !defined(xgboost_IS_MINGW)
40 using ssize_t = int;
41 #endif // !xgboost_IS_MINGW()
42 
43 #else // UNIX
44 
45 #include <arpa/inet.h> // inet_ntop
46 #include <fcntl.h> // fcntl, F_GETFL, O_NONBLOCK
47 #include <netinet/in.h> // sockaddr_in6, sockaddr_in, in_port_t, INET6_ADDRSTRLEN, INET_ADDRSTRLEN
48 #include <netinet/in.h> // IPPROTO_TCP
49 #include <netinet/tcp.h> // TCP_NODELAY
50 #include <sys/socket.h> // socket, SOL_SOCKET, SO_ERROR, MSG_WAITALL, recv, send, AF_INET6, AF_INET
51 #include <unistd.h> // close
52 
53 #if defined(__sun) || defined(sun)
54 #include <sys/sockio.h>
55 #endif // defined(__sun) || defined(sun)
56 
57 #endif // defined(_WIN32)
58 
59 #include "xgboost/base.h" // XGBOOST_EXPECT
60 #include "xgboost/logging.h" // LOG
61 #include "xgboost/string_view.h" // StringView
62 
63 #if !defined(HOST_NAME_MAX)
64 #define HOST_NAME_MAX 256 // macos
65 #endif
66 
67 namespace xgboost {
68 
69 #if defined(xgboost_IS_MINGW)
70 // see the dummy implementation of `poll` in rabit for more info.
71 inline void MingWError() { LOG(FATAL) << "Distributed training on mingw is not supported."; }
72 #endif // defined(xgboost_IS_MINGW)
73 
74 namespace system {
75 inline std::int32_t LastError() {
76 #if defined(_WIN32)
77  return WSAGetLastError();
78 #else
79  int errsv = errno;
80  return errsv;
81 #endif
82 }
83 
84 #if defined(__GLIBC__)
85 inline auto ThrowAtError(StringView fn_name, std::int32_t errsv = LastError(),
86  std::int32_t line = __builtin_LINE(),
87  char const *file = __builtin_FILE()) {
88  auto err = std::error_code{errsv, std::system_category()};
89  LOG(FATAL) << "\n"
90  << file << "(" << line << "): Failed to call `" << fn_name << "`: " << err.message()
91  << std::endl;
92 }
93 #else
94 inline auto ThrowAtError(StringView fn_name, std::int32_t errsv = LastError()) {
95  auto err = std::error_code{errsv, std::system_category()};
96  LOG(FATAL) << "Failed to call `" << fn_name << "`: " << err.message() << std::endl;
97 }
98 #endif // defined(__GLIBC__)
99 
100 #if defined(_WIN32)
101 using SocketT = SOCKET;
102 #else
103 using SocketT = int;
104 #endif // defined(_WIN32)
105 
106 #if !defined(xgboost_CHECK_SYS_CALL)
107 #define xgboost_CHECK_SYS_CALL(exp, expected) \
108  do { \
109  if (XGBOOST_EXPECT((exp) != (expected), false)) { \
110  ::xgboost::system::ThrowAtError(#exp); \
111  } \
112  } while (false)
113 #endif // !defined(xgboost_CHECK_SYS_CALL)
114 
115 inline std::int32_t CloseSocket(SocketT fd) {
116 #if defined(_WIN32)
117  return closesocket(fd);
118 #else
119  return close(fd);
120 #endif
121 }
122 
123 inline bool LastErrorWouldBlock() {
124  int errsv = LastError();
125 #ifdef _WIN32
126  return errsv == WSAEWOULDBLOCK;
127 #else
128  return errsv == EAGAIN || errsv == EWOULDBLOCK;
129 #endif // _WIN32
130 }
131 
132 inline void SocketStartup() {
133 #if defined(_WIN32)
134  WSADATA wsa_data;
135  if (WSAStartup(MAKEWORD(2, 2), &wsa_data) == -1) {
136  ThrowAtError("WSAStartup");
137  }
138  if (LOBYTE(wsa_data.wVersion) != 2 || HIBYTE(wsa_data.wVersion) != 2) {
139  WSACleanup();
140  LOG(FATAL) << "Could not find a usable version of Winsock.dll";
141  }
142 #endif // defined(_WIN32)
143 }
144 
145 inline void SocketFinalize() {
146 #if defined(_WIN32)
147  WSACleanup();
148 #endif // defined(_WIN32)
149 }
150 
151 #if defined(_WIN32) && defined(xgboost_IS_MINGW)
152 // dummy definition for old mysys32.
153 inline const char *inet_ntop(int, const void *, char *, socklen_t) { // NOLINT
154  MingWError();
155  return nullptr;
156 }
157 #else
158 using ::inet_ntop;
159 #endif // defined(_WIN32) && defined(xgboost_IS_MINGW)
160 
161 } // namespace system
162 
163 namespace collective {
164 class SockAddress;
165 
166 enum class SockDomain : std::int32_t { kV4 = AF_INET, kV6 = AF_INET6 };
167 
172 SockAddress MakeSockAddress(StringView host, in_port_t port);
173 
174 class SockAddrV6 {
175  sockaddr_in6 addr_;
176 
177  public:
178  explicit SockAddrV6(sockaddr_in6 addr) : addr_{addr} {}
179  SockAddrV6() { std::memset(&addr_, '\0', sizeof(addr_)); }
180 
183 
184  in_port_t Port() const { return ntohs(addr_.sin6_port); }
185 
186  std::string Addr() const {
187  char buf[INET6_ADDRSTRLEN];
188  auto const *s = system::inet_ntop(static_cast<std::int32_t>(SockDomain::kV6), &addr_.sin6_addr,
189  buf, INET6_ADDRSTRLEN);
190  if (s == nullptr) {
191  system::ThrowAtError("inet_ntop");
192  }
193  return {buf};
194  }
195  sockaddr_in6 const &Handle() const { return addr_; }
196 };
197 
198 class SockAddrV4 {
199  private:
200  sockaddr_in addr_;
201 
202  public:
203  explicit SockAddrV4(sockaddr_in addr) : addr_{addr} {}
204  SockAddrV4() { std::memset(&addr_, '\0', sizeof(addr_)); }
205 
208 
209  in_port_t Port() const { return ntohs(addr_.sin_port); }
210 
211  std::string Addr() const {
212  char buf[INET_ADDRSTRLEN];
213  auto const *s = system::inet_ntop(static_cast<std::int32_t>(SockDomain::kV4), &addr_.sin_addr,
214  buf, INET_ADDRSTRLEN);
215  if (s == nullptr) {
216  system::ThrowAtError("inet_ntop");
217  }
218  return {buf};
219  }
220  sockaddr_in const &Handle() const { return addr_; }
221 };
222 
226 class SockAddress {
227  private:
228  SockAddrV6 v6_;
229  SockAddrV4 v4_;
230  SockDomain domain_{SockDomain::kV4};
231 
232  public:
233  SockAddress() = default;
234  explicit SockAddress(SockAddrV6 const &addr) : v6_{addr}, domain_{SockDomain::kV6} {}
235  explicit SockAddress(SockAddrV4 const &addr) : v4_{addr} {}
236 
237  auto Domain() const { return domain_; }
238 
239  bool IsV4() const { return Domain() == SockDomain::kV4; }
240  bool IsV6() const { return !IsV4(); }
241 
242  auto const &V4() const { return v4_; }
243  auto const &V6() const { return v6_; }
244 };
245 
249 class TCPSocket {
250  public:
252 
253  private:
// Native socket handle; equals InvalidSocket() when this object holds no socket.
254 HandleT handle_{InvalidSocket()};
255 // There's no reliable way to extract the domain from a socket on macOS without first
256 // binding that socket, so the domain passed to Create() is cached here instead.
257 #if defined(__APPLE__)
258 SockDomain domain_{SockDomain::kV4};
259 #endif
260 
261  constexpr static HandleT InvalidSocket() { return -1; }
262 
263  explicit TCPSocket(HandleT newfd) : handle_{newfd} {}
264 
265  public:
266  TCPSocket() = default;
// Return the socket domain (IPv4 or IPv6).
// - Windows: getsockopt(SO_PROTOCOL_INFO) into WSAPROTOCOL_INFOA, mapping iAddressFamily.
// - macOS: the domain cached by Create(), since it cannot be queried before binding.
// - Other __unix__: getsockopt(SO_DOMAIN); when __PASE__ is defined, sa_family from
//   getsockname() instead, falling back to AF_INET if the returned address is too short.
// - Any other platform: LOG(FATAL).
270 auto Domain() const -> SockDomain {
// Map a native AF_* value to SockDomain; any other family triggers LOG(FATAL).
271 auto ret_iafamily = [](std::int32_t domain) {
272 switch (domain) {
273 case AF_INET:
274 return SockDomain::kV4;
275 case AF_INET6:
276 return SockDomain::kV6;
277 default: {
278 LOG(FATAL) << "Unknown IA family.";
279 }
280 }
281 return SockDomain::kV4;
282 };
283 
284 #if defined(_WIN32)
285 WSAPROTOCOL_INFOA info;
286 socklen_t len = sizeof(info);
// NOTE(review): the line opening the call closed by `0);` below (source line 287) is missing
// from this listing; presumably `xgboost_CHECK_SYS_CALL(` -- confirm against socket.h.
288 getsockopt(handle_, SOL_SOCKET, SO_PROTOCOL_INFO, reinterpret_cast<char *>(&info), &len),
289 0);
290 return ret_iafamily(info.iAddressFamily);
291 #elif defined(__APPLE__)
292 return domain_;
293 #elif defined(__unix__)
294 #ifndef __PASE__
295 std::int32_t domain;
296 socklen_t len = sizeof(domain);
// NOTE(review): source line 297 is missing here too (same pattern as the Windows branch).
298 getsockopt(handle_, SOL_SOCKET, SO_DOMAIN, reinterpret_cast<char *>(&domain), &len), 0);
299 return ret_iafamily(domain);
300 #else
301 struct sockaddr sa;
302 socklen_t sizeofsa = sizeof(sa);
303 xgboost_CHECK_SYS_CALL(getsockname(handle_, &sa, &sizeofsa), 0);
// Too short to hold a family field: assume IPv4.
304 if (sizeofsa < sizeof(uchar_t) * 2) {
305 return ret_iafamily(AF_INET);
306 }
307 return ret_iafamily(sa.sa_family);
308 #endif // __PASE__
309 #else
310 LOG(FATAL) << "Unknown platform.";
311 return ret_iafamily(AF_INET);
312 #endif // platforms
313 }
314 
315  bool IsClosed() const { return handle_ == InvalidSocket(); }
316 
// Get the last error code, if any: the value read via getsockopt(SOL_SOCKET, SO_ERROR).
318 std::int32_t GetSockError() const {
319 std::int32_t error = 0;
320 socklen_t len = sizeof(error);
// NOTE(review): the line opening the call closed by `, 0);` below (source line 321) is missing
// from this listing; presumably `xgboost_CHECK_SYS_CALL(` -- confirm against socket.h.
322 getsockopt(handle_, SOL_SOCKET, SO_ERROR, reinterpret_cast<char *>(&error), &len), 0);
323 return error;
324 }
326  bool BadSocket() const {
327  if (IsClosed()) return true;
328  std::int32_t err = GetSockError();
329  if (err == EBADF || err == EINTR) return true;
330  return false;
331  }
332 
333  void SetNonBlock() {
334  bool non_block{true};
335 #if defined(_WIN32)
336  u_long mode = non_block ? 1 : 0;
337  xgboost_CHECK_SYS_CALL(ioctlsocket(handle_, FIONBIO, &mode), NO_ERROR);
338 #else
339  std::int32_t flag = fcntl(handle_, F_GETFL, 0);
340  if (flag == -1) {
341  system::ThrowAtError("fcntl");
342  }
343  if (non_block) {
344  flag |= O_NONBLOCK;
345  } else {
346  flag &= ~O_NONBLOCK;
347  }
348  if (fcntl(handle_, F_SETFL, flag) == -1) {
349  system::ThrowAtError("fcntl");
350  }
351 #endif // _WIN32
352  }
353 
354  void SetKeepAlive() {
355  std::int32_t keepalive = 1;
356  xgboost_CHECK_SYS_CALL(setsockopt(handle_, SOL_SOCKET, SO_KEEPALIVE,
357  reinterpret_cast<char *>(&keepalive), sizeof(keepalive)),
358  0);
359  }
360 
// Set TCP_NODELAY (value 1) at the IPPROTO_TCP level on the socket.
361 void SetNoDelay() {
362 std::int32_t tcp_no_delay = 1;
// NOTE(review): the line opening the call closed by `0);` below (source line 363) is missing
// from this listing; presumably `xgboost_CHECK_SYS_CALL(` -- confirm against socket.h.
364 setsockopt(handle_, IPPROTO_TCP, TCP_NODELAY, reinterpret_cast<char *>(&tcp_no_delay),
365 sizeof(tcp_no_delay)),
366 0);
367 }
368 
373  HandleT newfd = accept(handle_, nullptr, nullptr);
374  if (newfd == InvalidSocket()) {
375  system::ThrowAtError("accept");
376  }
377  TCPSocket newsock{newfd};
378  return newsock;
379  }
380 
382  if (!IsClosed()) {
383  Close();
384  }
385  }
386 
387  TCPSocket(TCPSocket const &that) = delete;
388  TCPSocket(TCPSocket &&that) noexcept(true) { std::swap(this->handle_, that.handle_); }
389  TCPSocket &operator=(TCPSocket const &that) = delete;
391  std::swap(this->handle_, that.handle_);
392  return *this;
393  }
397  HandleT const &Handle() const { return handle_; }
401  void Listen(std::int32_t backlog = 16) { xgboost_CHECK_SYS_CALL(listen(handle_, backlog), 0); }
// Bind socket to INADDR_ANY of its own domain (SockAddrV6/SockAddrV4::InaddrAny()), then read
// the bound address back with getsockname() and return its port in host byte order.
// NOTE(review): the lines opening the calls closed by `, 0);` below (source lines 409, 414,
// 420, 425) are missing from this listing; presumably `xgboost_CHECK_SYS_CALL(` -- confirm.
// NOTE(review): std::remove_reference_t comes from <type_traits>, which this header does not
// include directly.
405 in_port_t BindHost() {
406 if (Domain() == SockDomain::kV6) {
407 auto addr = SockAddrV6::InaddrAny();
408 auto handle = reinterpret_cast<sockaddr const *>(&addr.Handle());
410 bind(handle_, handle, sizeof(std::remove_reference_t<decltype(addr.Handle())>)), 0);
411 
// Query the address actually bound to learn the selected port.
412 sockaddr_in6 res_addr;
413 socklen_t addrlen = sizeof(res_addr);
415 getsockname(handle_, reinterpret_cast<sockaddr *>(&res_addr), &addrlen), 0);
416 return ntohs(res_addr.sin6_port);
417 } else {
418 auto addr = SockAddrV4::InaddrAny();
419 auto handle = reinterpret_cast<sockaddr const *>(&addr.Handle());
421 bind(handle_, handle, sizeof(std::remove_reference_t<decltype(addr.Handle())>)), 0);
422 
423 sockaddr_in res_addr;
424 socklen_t addrlen = sizeof(res_addr);
426 getsockname(handle_, reinterpret_cast<sockaddr *>(&res_addr), &addrlen), 0);
427 return ntohs(res_addr.sin_port);
428 }
429 }
// Send data; without error, all `len` bytes should be sent. Loops over send() and returns the
// number of bytes sent. When send() returns -1 it either returns the partial count early or
// reports the failure via system::ThrowAtError("send").
// NOTE(review): the condition choosing between those two outcomes (source line 439) is missing
// from this listing; presumably `if (system::LastErrorWouldBlock()) {` -- confirm.
433 auto SendAll(void const *buf, std::size_t len) {
434 char const *_buf = reinterpret_cast<const char *>(buf);
435 std::size_t ndone = 0;
436 while (ndone < len) {
437 ssize_t ret = send(handle_, _buf, len - ndone, 0);
438 if (ret == -1) {
440 return ndone;
441 }
442 system::ThrowAtError("send");
443 }
// Advance past the bytes accepted by this send() call.
444 _buf += ret;
445 ndone += ret;
446 }
447 return ndone;
448 }
// Receive data; without error, all `len` bytes should be received. Loops over recv() with
// MSG_WAITALL and returns the number of bytes received, returning early with the partial count
// when recv() returns 0. When recv() returns -1 it either returns the partial count or reports
// the failure via system::ThrowAtError("recv").
// NOTE(review): the condition choosing between those two outcomes (source line 458) is missing
// from this listing; presumably `if (system::LastErrorWouldBlock()) {` -- confirm.
452 auto RecvAll(void *buf, std::size_t len) {
453 char *_buf = reinterpret_cast<char *>(buf);
454 std::size_t ndone = 0;
455 while (ndone < len) {
456 ssize_t ret = recv(handle_, _buf, len - ndone, MSG_WAITALL);
457 if (ret == -1) {
459 return ndone;
460 }
461 system::ThrowAtError("recv");
462 }
463 if (ret == 0) {
464 return ndone;
465 }
466 _buf += ret;
467 ndone += ret;
468 }
469 return ndone;
470 }
478  auto Send(const void *buf_, std::size_t len, std::int32_t flags = 0) {
479  const char *buf = reinterpret_cast<const char *>(buf_);
480  return send(handle_, buf, len, flags);
481  }
489  auto Recv(void *buf, std::size_t len, std::int32_t flags = 0) {
490  char *_buf = reinterpret_cast<char *>(buf);
491  return recv(handle_, _buf, len, flags);
492  }
// Send string, format is matched with the Python socket wrapper in RABIT.
496 std::size_t Send(StringView str);
// Receive string, format is matched with the Python socket wrapper in RABIT.
500 std::size_t Recv(std::string *p_str);
504  void Close() {
505  if (InvalidSocket() != handle_) {
507  handle_ = InvalidSocket();
508  }
509  }
513  static TCPSocket Create(SockDomain domain) {
514 #if defined(xgboost_IS_MINGW)
515  MingWError();
516  return {};
517 #else
518  auto fd = socket(static_cast<std::int32_t>(domain), SOCK_STREAM, 0);
519  if (fd == InvalidSocket()) {
520  system::ThrowAtError("socket");
521  }
522 
523  TCPSocket socket{fd};
524 #if defined(__APPLE__)
525  socket.domain_ = domain;
526 #endif // defined(__APPLE__)
527  return socket;
528 #endif // defined(xgboost_IS_MINGW)
529  }
530 };
531 
// Connect to remote address; returns the error code if it fails (no exception is raised).
536 std::error_code Connect(SockAddress const &addr, TCPSocket *out);
537 
541 inline std::string GetHostName() {
542  char buf[HOST_NAME_MAX];
543  xgboost_CHECK_SYS_CALL(gethostname(&buf[0], HOST_NAME_MAX), 0);
544  return buf;
545 }
546 } // namespace collective
547 } // namespace xgboost
548 
549 #undef xgboost_CHECK_SYS_CALL
550 
551 #if defined(xgboost_IS_MINGW)
552 #undef xgboost_IS_MINGW
553 #endif
Defines configuration macros and basic types for xgboost.
Definition: socket.h:198
SockAddrV4(sockaddr_in addr)
Definition: socket.h:203
static SockAddrV4 InaddrAny()
in_port_t Port() const
Definition: socket.h:209
sockaddr_in const & Handle() const
Definition: socket.h:220
std::string Addr() const
Definition: socket.h:211
static SockAddrV4 Loopback()
SockAddrV4()
Definition: socket.h:204
Definition: socket.h:174
static SockAddrV6 InaddrAny()
SockAddrV6()
Definition: socket.h:179
sockaddr_in6 const & Handle() const
Definition: socket.h:195
in_port_t Port() const
Definition: socket.h:184
SockAddrV6(sockaddr_in6 addr)
Definition: socket.h:178
std::string Addr() const
Definition: socket.h:186
static SockAddrV6 Loopback()
Address for TCP socket, can be either IPv4 or IPv6.
Definition: socket.h:226
bool IsV6() const
Definition: socket.h:240
auto const & V6() const
Definition: socket.h:243
bool IsV4() const
Definition: socket.h:239
auto Domain() const
Definition: socket.h:237
SockAddress(SockAddrV4 const &addr)
Definition: socket.h:235
auto const & V4() const
Definition: socket.h:242
SockAddress(SockAddrV6 const &addr)
Definition: socket.h:234
TCP socket for simple communication.
Definition: socket.h:249
HandleT const & Handle() const
Return the native socket file descriptor.
Definition: socket.h:397
void SetNoDelay()
Definition: socket.h:361
std::int32_t GetSockError() const
get last error code if any
Definition: socket.h:318
TCPSocket & operator=(TCPSocket const &that)=delete
auto SendAll(void const *buf, std::size_t len)
Send data, without error then all data should be sent.
Definition: socket.h:433
system::SocketT HandleT
Definition: socket.h:251
void SetNonBlock()
Definition: socket.h:333
void Close()
Close the socket, called automatically in destructor if the socket is not closed.
Definition: socket.h:504
void Listen(std::int32_t backlog=16)
Listen to incoming requests. Should be called after bind.
Definition: socket.h:401
TCPSocket(TCPSocket const &that)=delete
std::size_t Send(StringView str)
Send string, format is matched with the Python socket wrapper in RABIT.
TCPSocket & operator=(TCPSocket &&that)
Definition: socket.h:390
auto Domain() const -> SockDomain
Return the socket domain.
Definition: socket.h:270
static TCPSocket Create(SockDomain domain)
Create a TCP socket on specified domain.
Definition: socket.h:513
auto Recv(void *buf, std::size_t len, std::int32_t flags=0)
receive data using the socket
Definition: socket.h:489
auto RecvAll(void *buf, std::size_t len)
Receive data, without error then all data should be received.
Definition: socket.h:452
bool IsClosed() const
Definition: socket.h:315
TCPSocket Accept()
Accept new connection, returns a new TCP socket for the new connection.
Definition: socket.h:372
~TCPSocket()
Definition: socket.h:381
auto Send(const void *buf_, std::size_t len, std::int32_t flags=0)
Send data using the socket.
Definition: socket.h:478
in_port_t BindHost()
Bind socket to INADDR_ANY, return the port selected by the OS.
Definition: socket.h:405
void SetKeepAlive()
Definition: socket.h:354
std::size_t Recv(std::string *p_str)
Receive string, format is matched with the Python socket wrapper in RABIT.
bool BadSocket() const
check if anything bad happens
Definition: socket.h:326
TCPSocket(TCPSocket &&that) noexcept(true)
Definition: socket.h:388
void swap(xgboost::IntrusivePtr< T > &x, xgboost::IntrusivePtr< T > &y) noexcept
Definition: intrusive_ptr.h:209
std::error_code Connect(SockAddress const &addr, TCPSocket *out)
Connect to remote address, returns the error code if failed (no exception is raised so that we can re...
std::string GetHostName()
Get the local host name.
Definition: socket.h:541
SockAddress MakeSockAddress(StringView host, in_port_t port)
Parse host address and return a SockAddress instance. Supports IPv4 and IPv6 host.
SockDomain
Definition: socket.h:166
auto ThrowAtError(StringView fn_name, std::int32_t errsv=LastError())
Definition: socket.h:94
void SocketStartup()
Definition: socket.h:132
std::int32_t CloseSocket(SocketT fd)
Definition: socket.h:115
bool LastErrorWouldBlock()
Definition: socket.h:123
std::int32_t LastError()
Definition: socket.h:75
void SocketFinalize()
Definition: socket.h:145
int SocketT
Definition: socket.h:103
namespace of xgboost
Definition: base.h:90
#define xgboost_CHECK_SYS_CALL(exp, expected)
Definition: socket.h:107
#define HOST_NAME_MAX
Definition: socket.h:64
Definition: string_view.h:15