xgboost
socket.h
Go to the documentation of this file.
1 
4 #pragma once
5 
6 #if !defined(NOMINMAX) && defined(_WIN32)
7 #define NOMINMAX
8 #endif // !defined(NOMINMAX)
9 
10 #include <cerrno> // errno, EINTR, EBADF
11 #include <climits> // HOST_NAME_MAX
12 #include <cstddef> // std::size_t
13 #include <cstdint> // std::int32_t, std::uint16_t
14 #include <cstring> // memset
15 #include <string> // std::string
16 #include <system_error> // std::error_code, std::system_category
17 #include <utility> // std::swap
18 
19 #if !defined(xgboost_IS_MINGW)
20 
21 #if defined(__MINGW32__)
22 #define xgboost_IS_MINGW 1
23 #endif // defined(__MINGW32__)
24 
25 #endif // xgboost_IS_MINGW
26 
27 #if defined(_WIN32)
28 
29 #include <winsock2.h>
30 #include <ws2tcpip.h>
31 
32 using in_port_t = std::uint16_t;
33 
34 #ifdef _MSC_VER
35 #pragma comment(lib, "Ws2_32.lib")
36 #endif // _MSC_VER
37 
38 #if !defined(xgboost_IS_MINGW)
39 using ssize_t = int;
40 #endif // !xgboost_IS_MINGW()
41 
42 #else // UNIX
43 
44 #include <arpa/inet.h> // inet_ntop
45 #include <fcntl.h> // fcntl, F_GETFL, O_NONBLOCK
46 #include <netinet/in.h> // sockaddr_in6, sockaddr_in, in_port_t, INET6_ADDRSTRLEN, INET_ADDRSTRLEN
47 #include <netinet/in.h> // IPPROTO_TCP
48 #include <netinet/tcp.h> // TCP_NODELAY
49 #include <sys/socket.h> // socket, SOL_SOCKET, SO_ERROR, MSG_WAITALL, recv, send, AF_INET6, AF_INET
50 #include <unistd.h> // close
51 
52 #if defined(__sun) || defined(sun)
53 #include <sys/sockio.h>
54 #endif // defined(__sun) || defined(sun)
55 
56 #endif // defined(_WIN32)
57 
58 #include "xgboost/base.h" // XGBOOST_EXPECT
59 #include "xgboost/collective/result.h" // for Result
60 #include "xgboost/logging.h" // LOG
61 #include "xgboost/string_view.h" // StringView
62 
63 #if !defined(HOST_NAME_MAX)
64 #define HOST_NAME_MAX 256 // macos
65 #endif
66 
67 namespace xgboost {
68 
#if defined(xgboost_IS_MINGW)
/**
 * @brief Report that networking is unavailable on MinGW builds.
 *
 * Called by the MinGW stubs in this header (socket creation, `inet_ntop`). See the dummy
 * implementation of `poll` in rabit for more info.
 */
inline void MingWError() {
  LOG(FATAL) << "Distributed training on mingw is not supported.";
}
#endif  // defined(xgboost_IS_MINGW)
73 
74 namespace system {
/**
 * @brief Return the error code left by the most recent failed socket/system call.
 *
 * Reads `errno` on POSIX platforms and `WSAGetLastError()` on Windows.
 */
inline std::int32_t LastError() {
#if !defined(_WIN32)
  return errno;
#else
  return WSAGetLastError();
#endif  // !defined(_WIN32)
}
83 
84 [[nodiscard]] inline collective::Result FailWithCode(std::string msg) {
85  return collective::Fail(std::move(msg), std::error_code{LastError(), std::system_category()});
86 }
87 
88 #if defined(__GLIBC__)
89 inline auto ThrowAtError(StringView fn_name, std::int32_t errsv = LastError(),
90  std::int32_t line = __builtin_LINE(),
91  char const *file = __builtin_FILE()) {
92  auto err = std::error_code{errsv, std::system_category()};
93  LOG(FATAL) << "\n"
94  << file << "(" << line << "): Failed to call `" << fn_name << "`: " << err.message()
95  << std::endl;
96 }
97 #else
98 inline auto ThrowAtError(StringView fn_name, std::int32_t errsv = LastError()) {
99  auto err = std::error_code{errsv, std::system_category()};
100  LOG(FATAL) << "Failed to call `" << fn_name << "`: " << err.message() << std::endl;
101 }
102 #endif // defined(__GLIBC__)
103 
/** @brief Native socket handle: `SOCKET` under Winsock, a plain file descriptor elsewhere. */
#if !defined(_WIN32)
using SocketT = int;
#else
using SocketT = SOCKET;
#endif  // !defined(_WIN32)
109 
// Evaluate a system-call expression once; if its value differs from `ok_value`, report the
// failure through ThrowAtError using the stringified expression as the function name.
#if !defined(xgboost_CHECK_SYS_CALL)
#define xgboost_CHECK_SYS_CALL(call_expr, ok_value)             \
  do {                                                          \
    if (XGBOOST_EXPECT((call_expr) != (ok_value), false)) {    \
      ::xgboost::system::ThrowAtError(#call_expr);              \
    }                                                           \
  } while (false)
#endif  // !defined(xgboost_CHECK_SYS_CALL)
118 
119 inline std::int32_t CloseSocket(SocketT fd) {
120 #if defined(_WIN32)
121  return closesocket(fd);
122 #else
123  return close(fd);
124 #endif
125 }
126 
127 inline std::int32_t ShutdownSocket(SocketT fd) {
128 #if defined(_WIN32)
129  auto rc = shutdown(fd, SD_BOTH);
130  if (rc != 0 && LastError() == WSANOTINITIALISED) {
131  return 0;
132  }
133 #else
134  auto rc = shutdown(fd, SHUT_RDWR);
135  if (rc != 0 && LastError() == ENOTCONN) {
136  return 0;
137  }
138 #endif
139  return rc;
140 }
141 
/**
 * @brief Whether `errsv` is a "would block / in progress" code from a non-blocking call.
 *
 * Windows: `WSAEWOULDBLOCK`. Elsewhere: `EAGAIN`, `EWOULDBLOCK` or `EINPROGRESS`.
 */
inline bool ErrorWouldBlock(std::int32_t errsv) noexcept(true) {
#if defined(_WIN32)
  constexpr std::int32_t kWouldBlock[] = {WSAEWOULDBLOCK};
#else
  constexpr std::int32_t kWouldBlock[] = {EAGAIN, EWOULDBLOCK, EINPROGRESS};
#endif  // defined(_WIN32)
  for (auto const code : kWouldBlock) {
    if (code == errsv) {
      return true;
    }
  }
  return false;
}
149 
150 inline bool LastErrorWouldBlock() {
151  int errsv = LastError();
152  return ErrorWouldBlock(errsv);
153 }
154 
/**
 * @brief Initialize the platform socket library; a no-op outside Windows.
 *
 * On Windows this requests Winsock 2.2. If startup fails, the failure is reported through
 * ThrowAtError with the code returned by `WSAStartup`. If the loaded DLL does not provide
 * version 2.2, Winsock is cleaned up again and a fatal error is logged.
 */
inline void SocketStartup() {
#if defined(_WIN32)
  WSADATA wsa_data;
  // WSAStartup returns 0 on success and the error code itself on failure. It never returns
  // -1, and it does not set WSAGetLastError, so the returned code is the one to report.
  auto const rc = WSAStartup(MAKEWORD(2, 2), &wsa_data);
  if (rc != 0) {
    ThrowAtError("WSAStartup", rc);
  }
  if (LOBYTE(wsa_data.wVersion) != 2 || HIBYTE(wsa_data.wVersion) != 2) {
    WSACleanup();
    LOG(FATAL) << "Could not find a usable version of Winsock.dll";
  }
#endif  // defined(_WIN32)
}
167 
/** @brief Release the platform socket library: `WSACleanup` on Windows, nothing elsewhere. */
inline void SocketFinalize() {
#if !defined(_WIN32)
  // No library-level teardown is performed on non-Windows platforms.
#else
  WSACleanup();
#endif  // !defined(_WIN32)
}
173 
174 #if defined(_WIN32) && defined(xgboost_IS_MINGW)
175 // dummy definition for old mysys32.
176 inline const char *inet_ntop(int, const void *, char *, socklen_t) { // NOLINT
177  MingWError();
178  return nullptr;
179 }
180 #else
181 using ::inet_ntop;
182 #endif // defined(_WIN32) && defined(xgboost_IS_MINGW)
183 
184 } // namespace system
185 
186 namespace collective {
class SockAddress;

/** @brief Address family of a socket; the values are the native `AF_*` constants. */
enum class SockDomain : std::int32_t {
  kV4 = AF_INET,
  kV6 = AF_INET6,
};
190 
195 SockAddress MakeSockAddress(StringView host, in_port_t port);
196 
197 class SockAddrV6 {
198  sockaddr_in6 addr_;
199 
200  public:
201  explicit SockAddrV6(sockaddr_in6 addr) : addr_{addr} {}
202  SockAddrV6() { std::memset(&addr_, '\0', sizeof(addr_)); }
203 
206 
207  in_port_t Port() const { return ntohs(addr_.sin6_port); }
208 
209  std::string Addr() const {
210  char buf[INET6_ADDRSTRLEN];
211  auto const *s = system::inet_ntop(static_cast<std::int32_t>(SockDomain::kV6), &addr_.sin6_addr,
212  buf, INET6_ADDRSTRLEN);
213  if (s == nullptr) {
214  system::ThrowAtError("inet_ntop");
215  }
216  return {buf};
217  }
218  sockaddr_in6 const &Handle() const { return addr_; }
219 };
220 
221 class SockAddrV4 {
222  private:
223  sockaddr_in addr_;
224 
225  public:
226  explicit SockAddrV4(sockaddr_in addr) : addr_{addr} {}
227  SockAddrV4() { std::memset(&addr_, '\0', sizeof(addr_)); }
228 
231 
232  [[nodiscard]] in_port_t Port() const { return ntohs(addr_.sin_port); }
233 
234  [[nodiscard]] std::string Addr() const {
235  char buf[INET_ADDRSTRLEN];
236  auto const *s = system::inet_ntop(static_cast<std::int32_t>(SockDomain::kV4), &addr_.sin_addr,
237  buf, INET_ADDRSTRLEN);
238  if (s == nullptr) {
239  system::ThrowAtError("inet_ntop");
240  }
241  return {buf};
242  }
243  [[nodiscard]] sockaddr_in const &Handle() const { return addr_; }
244 };
245 
249 class SockAddress {
250  private:
251  SockAddrV6 v6_;
252  SockAddrV4 v4_;
253  SockDomain domain_{SockDomain::kV4};
254 
255  public:
256  SockAddress() = default;
257  explicit SockAddress(SockAddrV6 const &addr) : v6_{addr}, domain_{SockDomain::kV6} {}
258  explicit SockAddress(SockAddrV4 const &addr) : v4_{addr} {}
259 
260  [[nodiscard]] auto Domain() const { return domain_; }
261 
262  [[nodiscard]] bool IsV4() const { return Domain() == SockDomain::kV4; }
263  [[nodiscard]] bool IsV6() const { return !IsV4(); }
264 
265  [[nodiscard]] auto const &V4() const { return v4_; }
266  [[nodiscard]] auto const &V6() const { return v6_; }
267 };
268 
272 class TCPSocket {
273  public:
275 
276  private:
277  HandleT handle_{InvalidSocket()};
278  bool non_blocking_{false};
279  // There's reliable no way to extract domain from a socket without first binding that
280  // socket on macos.
281 #if defined(__APPLE__)
282  SockDomain domain_{SockDomain::kV4};
283 #endif
284 
285  constexpr static HandleT InvalidSocket() { return -1; }
286 
287  explicit TCPSocket(HandleT newfd) : handle_{newfd} {}
288 
289  public:
290  TCPSocket() = default;
294  [[nodiscard]] auto Domain() const -> SockDomain {
295  auto ret_iafamily = [](std::int32_t domain) {
296  switch (domain) {
297  case AF_INET:
298  return SockDomain::kV4;
299  case AF_INET6:
300  return SockDomain::kV6;
301  default: {
302  LOG(FATAL) << "Unknown IA family.";
303  }
304  }
305  return SockDomain::kV4;
306  };
307 
308 #if defined(_WIN32)
309  WSAPROTOCOL_INFOA info;
310  socklen_t len = sizeof(info);
312  getsockopt(handle_, SOL_SOCKET, SO_PROTOCOL_INFO, reinterpret_cast<char *>(&info), &len),
313  0);
314  return ret_iafamily(info.iAddressFamily);
315 #elif defined(__APPLE__)
316  return domain_;
317 #elif defined(__unix__)
318 #ifndef __PASE__
319  std::int32_t domain;
320  socklen_t len = sizeof(domain);
322  getsockopt(handle_, SOL_SOCKET, SO_DOMAIN, reinterpret_cast<char *>(&domain), &len), 0);
323  return ret_iafamily(domain);
324 #else
325  struct sockaddr sa;
326  socklen_t sizeofsa = sizeof(sa);
327  xgboost_CHECK_SYS_CALL(getsockname(handle_, &sa, &sizeofsa), 0);
328  if (sizeofsa < sizeof(uchar_t) * 2) {
329  return ret_iafamily(AF_INET);
330  }
331  return ret_iafamily(sa.sa_family);
332 #endif // __PASE__
333 #else
334  LOG(FATAL) << "Unknown platform.";
335  return ret_iafamily(AF_INET);
336 #endif // platforms
337  }
338 
339  [[nodiscard]] bool IsClosed() const { return handle_ == InvalidSocket(); }
340 
342  [[nodiscard]] Result GetSockError() const {
343  std::int32_t optval = 0;
344  socklen_t len = sizeof(optval);
345  auto ret = getsockopt(handle_, SOL_SOCKET, SO_ERROR, reinterpret_cast<char *>(&optval), &len);
346  if (ret != 0) {
347  auto errc = std::error_code{system::LastError(), std::system_category()};
348  return Fail("Failed to retrieve socket error.", std::move(errc));
349  }
350  if (optval != 0) {
351  auto errc = std::error_code{optval, std::system_category()};
352  return Fail("Socket error.", std::move(errc));
353  }
354  return Success();
355  }
356 
358  [[nodiscard]] bool BadSocket() const {
359  if (IsClosed()) {
360  return true;
361  }
362  auto err = GetSockError();
363  if (err.Code() == std::error_code{EBADF, std::system_category()} || // NOLINT
364  err.Code() == std::error_code{EINTR, std::system_category()}) { // NOLINT
365  return true;
366  }
367  return false;
368  }
369 
370  [[nodiscard]] Result NonBlocking(bool non_block) {
371 #if defined(_WIN32)
372  u_long mode = non_block ? 1 : 0;
373  if (ioctlsocket(handle_, FIONBIO, &mode) != NO_ERROR) {
374  return system::FailWithCode("Failed to set socket to non-blocking.");
375  }
376 #else
377  std::int32_t flag = fcntl(handle_, F_GETFL, 0);
378  auto rc = flag;
379  if (rc == -1) {
380  return system::FailWithCode("Failed to get socket flag.");
381  }
382  if (non_block) {
383  flag |= O_NONBLOCK;
384  } else {
385  flag &= ~O_NONBLOCK;
386  }
387  rc = fcntl(handle_, F_SETFL, flag);
388  if (rc == -1) {
389  return system::FailWithCode("Failed to set socket to non-blocking.");
390  }
391 #endif // _WIN32
392  non_blocking_ = non_block;
393  return Success();
394  }
395  [[nodiscard]] bool NonBlocking() const { return non_blocking_; }
396  [[nodiscard]] Result RecvTimeout(std::chrono::seconds timeout) {
397  // https://stackoverflow.com/questions/2876024/linux-is-there-a-read-or-recv-from-socket-with-timeout
398 #if defined(_WIN32)
399  DWORD tv = timeout.count() * 1000;
400  auto rc =
401  setsockopt(Handle(), SOL_SOCKET, SO_RCVTIMEO, reinterpret_cast<char *>(&tv), sizeof(tv));
402 #else
403  struct timeval tv;
404  tv.tv_sec = timeout.count();
405  tv.tv_usec = 0;
406  auto rc = setsockopt(Handle(), SOL_SOCKET, SO_RCVTIMEO, reinterpret_cast<char const *>(&tv),
407  sizeof(tv));
408 #endif
409  if (rc != 0) {
410  return system::FailWithCode("Failed to set timeout on recv.");
411  }
412  return Success();
413  }
414 
415  [[nodiscard]] Result SetBufSize(std::int32_t n_bytes) {
416  auto rc = setsockopt(this->Handle(), SOL_SOCKET, SO_SNDBUF, reinterpret_cast<char *>(&n_bytes),
417  sizeof(n_bytes));
418  if (rc != 0) {
419  return system::FailWithCode("Failed to set send buffer size.");
420  }
421  rc = setsockopt(this->Handle(), SOL_SOCKET, SO_RCVBUF, reinterpret_cast<char *>(&n_bytes),
422  sizeof(n_bytes));
423  if (rc != 0) {
424  return system::FailWithCode("Failed to set recv buffer size.");
425  }
426  return Success();
427  }
428 
429  [[nodiscard]] Result SetKeepAlive() {
430  std::int32_t keepalive = 1;
431  auto rc = setsockopt(handle_, SOL_SOCKET, SO_KEEPALIVE, reinterpret_cast<char *>(&keepalive),
432  sizeof(keepalive));
433  if (rc != 0) {
434  return system::FailWithCode("Failed to set TCP keeaplive.");
435  }
436  return Success();
437  }
438 
439  [[nodiscard]] Result SetNoDelay() {
440  std::int32_t tcp_no_delay = 1;
441  auto rc = setsockopt(handle_, IPPROTO_TCP, TCP_NODELAY, reinterpret_cast<char *>(&tcp_no_delay),
442  sizeof(tcp_no_delay));
443  if (rc != 0) {
444  return system::FailWithCode("Failed to set TCP no delay.");
445  }
446  return Success();
447  }
448 
453  SockAddress addr;
454  TCPSocket newsock;
455  auto rc = this->Accept(&newsock, &addr);
456  SafeColl(rc);
457  return newsock;
458  }
459 
460  [[nodiscard]] Result Accept(TCPSocket *out, SockAddress *addr) {
461 #if defined(_WIN32)
462  auto interrupt = WSAEINTR;
463 #else
464  auto interrupt = EINTR;
465 #endif
466  if (this->Domain() == SockDomain::kV4) {
467  struct sockaddr_in caddr;
468  socklen_t caddr_len = sizeof(caddr);
469  HandleT newfd = accept(Handle(), reinterpret_cast<sockaddr *>(&caddr), &caddr_len);
470  if (newfd == InvalidSocket() && system::LastError() != interrupt) {
471  return system::FailWithCode("Failed to accept.");
472  }
473  *addr = SockAddress{SockAddrV4{caddr}};
474  *out = TCPSocket{newfd};
475  } else {
476  struct sockaddr_in6 caddr;
477  socklen_t caddr_len = sizeof(caddr);
478  HandleT newfd = accept(Handle(), reinterpret_cast<sockaddr *>(&caddr), &caddr_len);
479  if (newfd == InvalidSocket() && system::LastError() != interrupt) {
480  return system::FailWithCode("Failed to accept.");
481  }
482  *addr = SockAddress{SockAddrV6{caddr}};
483  *out = TCPSocket{newfd};
484  }
485  // On MacOS, this is automatically set to async socket if the parent socket is async
486  // We make sure all socket are blocking by default.
487  //
488  // On Windows, a closed socket is returned during shutdown. We guard against it when
489  // setting non-blocking.
490  if (!out->IsClosed()) {
491  return out->NonBlocking(false);
492  }
493  return Success();
494  }
495 
497  if (!IsClosed()) {
498  auto rc = this->Close();
499  if (!rc.OK()) {
500  LOG(WARNING) << rc.Report();
501  }
502  }
503  }
504 
505  TCPSocket(TCPSocket const &that) = delete;
506  TCPSocket(TCPSocket &&that) noexcept(true) { std::swap(this->handle_, that.handle_); }
507  TCPSocket &operator=(TCPSocket const &that) = delete;
508  TCPSocket &operator=(TCPSocket &&that) noexcept(true) {
509  std::swap(this->handle_, that.handle_);
510  return *this;
511  }
515  [[nodiscard]] HandleT const &Handle() const { return handle_; }
519  [[nodiscard]] Result Listen(std::int32_t backlog = 16) {
520  if (listen(handle_, backlog) != 0) {
521  return system::FailWithCode("Failed to listen.");
522  }
523  return Success();
524  }
528  [[nodiscard]] Result BindHost(std::int32_t* p_out) {
529  // Use int32 instead of in_port_t for consistency. We take port as parameter from
530  // users using other languages, the port is usually stored and passed around as int.
531  if (Domain() == SockDomain::kV6) {
532  auto addr = SockAddrV6::InaddrAny();
533  auto handle = reinterpret_cast<sockaddr const *>(&addr.Handle());
534  if (bind(handle_, handle, sizeof(std::remove_reference_t<decltype(addr.Handle())>)) != 0) {
535  return system::FailWithCode("bind failed.");
536  }
537 
538  sockaddr_in6 res_addr;
539  socklen_t addrlen = sizeof(res_addr);
540  if (getsockname(handle_, reinterpret_cast<sockaddr *>(&res_addr), &addrlen) != 0) {
541  return system::FailWithCode("getsockname failed.");
542  }
543  *p_out = ntohs(res_addr.sin6_port);
544  } else {
545  auto addr = SockAddrV4::InaddrAny();
546  auto handle = reinterpret_cast<sockaddr const *>(&addr.Handle());
547  if (bind(handle_, handle, sizeof(std::remove_reference_t<decltype(addr.Handle())>)) != 0) {
548  return system::FailWithCode("bind failed.");
549  }
550 
551  sockaddr_in res_addr;
552  socklen_t addrlen = sizeof(res_addr);
553  if (getsockname(handle_, reinterpret_cast<sockaddr *>(&res_addr), &addrlen) != 0) {
554  return system::FailWithCode("getsockname failed.");
555  }
556  *p_out = ntohs(res_addr.sin_port);
557  }
558 
559  return Success();
560  }
561 
562  [[nodiscard]] auto Port() const {
563  if (this->Domain() == SockDomain::kV4) {
564  sockaddr_in res_addr;
565  socklen_t addrlen = sizeof(res_addr);
566  auto code = getsockname(handle_, reinterpret_cast<sockaddr *>(&res_addr), &addrlen);
567  if (code != 0) {
568  return std::make_pair(system::FailWithCode("getsockname"), std::int32_t{0});
569  }
570  return std::make_pair(Success(), std::int32_t{ntohs(res_addr.sin_port)});
571  } else {
572  sockaddr_in6 res_addr;
573  socklen_t addrlen = sizeof(res_addr);
574  auto code = getsockname(handle_, reinterpret_cast<sockaddr *>(&res_addr), &addrlen);
575  if (code != 0) {
576  return std::make_pair(system::FailWithCode("getsockname"), std::int32_t{0});
577  }
578  return std::make_pair(Success(), std::int32_t{ntohs(res_addr.sin6_port)});
579  }
580  }
581 
582  [[nodiscard]] Result Bind(StringView ip, std::int32_t *port) {
583  // bind socket handle_ to ip
584  auto addr = MakeSockAddress(ip, 0);
585  std::int32_t errc{0};
586  if (addr.IsV4()) {
587  auto handle = reinterpret_cast<sockaddr const *>(&addr.V4().Handle());
588  errc = bind(handle_, handle, sizeof(std::remove_reference_t<decltype(addr.V4().Handle())>));
589  } else {
590  auto handle = reinterpret_cast<sockaddr const *>(&addr.V6().Handle());
591  errc = bind(handle_, handle, sizeof(std::remove_reference_t<decltype(addr.V6().Handle())>));
592  }
593  if (errc != 0) {
594  return system::FailWithCode("Failed to bind socket.");
595  }
596  auto [rc, new_port] = this->Port();
597  if (!rc.OK()) {
598  return std::move(rc);
599  }
600  *port = new_port;
601  return Success();
602  }
603 
607  [[nodiscard]] auto SendAll(void const *buf, std::size_t len) {
608  char const *_buf = reinterpret_cast<const char *>(buf);
609  std::size_t ndone = 0;
610  while (ndone < len) {
611  ssize_t ret = send(handle_, _buf, len - ndone, 0);
612  if (ret == -1) {
614  return ndone;
615  }
616  system::ThrowAtError("send");
617  }
618  _buf += ret;
619  ndone += ret;
620  }
621  return ndone;
622  }
626  [[nodiscard]] auto RecvAll(void *buf, std::size_t len) {
627  char *_buf = reinterpret_cast<char *>(buf);
628  std::size_t ndone = 0;
629  while (ndone < len) {
630  ssize_t ret = recv(handle_, _buf, len - ndone, MSG_WAITALL);
631  if (ret == -1) {
633  return ndone;
634  }
635  system::ThrowAtError("recv");
636  }
637  if (ret == 0) {
638  return ndone;
639  }
640  _buf += ret;
641  ndone += ret;
642  }
643  return ndone;
644  }
652  auto Send(const void *buf_, std::size_t len, std::int32_t flags = 0) {
653  const char *buf = reinterpret_cast<const char *>(buf_);
654  return send(handle_, buf, len, flags);
655  }
663  auto Recv(void *buf, std::size_t len, std::int32_t flags = 0) {
664  char *_buf = reinterpret_cast<char *>(buf);
665  return recv(handle_, _buf, len, flags);
666  }
670  std::size_t Send(StringView str);
674  [[nodiscard]] Result Recv(std::string *p_str);
678  [[nodiscard]] Result Close() {
679  if (InvalidSocket() != handle_) {
680  auto rc = system::CloseSocket(handle_);
681 #if defined(_WIN32)
682  // it's possible that we close TCP sockets after finalizing WSA due to detached thread.
683  if (rc != 0 && system::LastError() != WSANOTINITIALISED) {
684  return system::FailWithCode("Failed to close the socket.");
685  }
686 #else
687  if (rc != 0) {
688  return system::FailWithCode("Failed to close the socket.");
689  }
690 #endif
691  handle_ = InvalidSocket();
692  }
693  return Success();
694  }
698  [[nodiscard]] Result Shutdown() {
699  if (this->IsClosed()) {
700  return Success();
701  }
702  auto rc = system::ShutdownSocket(this->Handle());
703 #if defined(_WIN32)
704  // Windows cannot shutdown a socket if it's not connected.
705  if (rc == -1 && system::LastError() == WSAENOTCONN) {
706  return Success();
707  }
708 #endif
709  if (rc != 0) {
710  return system::FailWithCode("Failed to shutdown socket.");
711  }
712  return Success();
713  }
714 
718  static TCPSocket Create(SockDomain domain) {
719 #if defined(xgboost_IS_MINGW)
720  MingWError();
721  return {};
722 #else
723  auto fd = socket(static_cast<std::int32_t>(domain), SOCK_STREAM, 0);
724  if (fd == InvalidSocket()) {
725  system::ThrowAtError("socket");
726  }
727 
728  TCPSocket socket{fd};
729 #if defined(__APPLE__)
730  socket.domain_ = domain;
731 #endif // defined(__APPLE__)
732  return socket;
733 #endif // defined(xgboost_IS_MINGW)
734  }
735 
736  static TCPSocket *CreatePtr(SockDomain domain) {
737 #if defined(xgboost_IS_MINGW)
738  MingWError();
739  return nullptr;
740 #else
741  auto fd = socket(static_cast<std::int32_t>(domain), SOCK_STREAM, 0);
742  if (fd == InvalidSocket()) {
743  system::ThrowAtError("socket");
744  }
745  auto socket = new TCPSocket{fd};
746 
747 #if defined(__APPLE__)
748  socket->domain_ = domain;
749 #endif // defined(__APPLE__)
750  return socket;
751 #endif // defined(xgboost_IS_MINGW)
752  }
753 };
754 
767 [[nodiscard]] Result Connect(xgboost::StringView host, std::int32_t port, std::int32_t retry,
768  std::chrono::seconds timeout,
770 
774 [[nodiscard]] Result GetHostName(std::string *p_out);
775 
779 template <typename H>
780 Result INetNToP(H const &host, std::string *p_out) {
781  std::string &ip = *p_out;
782  switch (host->h_addrtype) {
783  case AF_INET: {
784  auto addr = reinterpret_cast<struct in_addr *>(host->h_addr_list[0]);
785  char str[INET_ADDRSTRLEN];
786  inet_ntop(AF_INET, addr, str, INET_ADDRSTRLEN);
787  ip = str;
788  break;
789  }
790  case AF_INET6: {
791  auto addr = reinterpret_cast<struct in6_addr *>(host->h_addr_list[0]);
792  char str[INET6_ADDRSTRLEN];
793  inet_ntop(AF_INET6, addr, str, INET6_ADDRSTRLEN);
794  ip = str;
795  break;
796  }
797  default: {
798  return Fail("Invalid address type.");
799  }
800  }
801  return Success();
802 }
803 } // namespace collective
804 } // namespace xgboost
805 
806 #undef xgboost_CHECK_SYS_CALL
807 
808 #if defined(xgboost_IS_MINGW)
809 #undef xgboost_IS_MINGW
810 #endif
Defines configuration macros and basic types for xgboost.
Definition: socket.h:221
SockAddrV4(sockaddr_in addr)
Definition: socket.h:226
static SockAddrV4 InaddrAny()
in_port_t Port() const
Definition: socket.h:232
sockaddr_in const & Handle() const
Definition: socket.h:243
std::string Addr() const
Definition: socket.h:234
static SockAddrV4 Loopback()
SockAddrV4()
Definition: socket.h:227
Definition: socket.h:197
static SockAddrV6 InaddrAny()
SockAddrV6()
Definition: socket.h:202
sockaddr_in6 const & Handle() const
Definition: socket.h:218
in_port_t Port() const
Definition: socket.h:207
SockAddrV6(sockaddr_in6 addr)
Definition: socket.h:201
std::string Addr() const
Definition: socket.h:209
static SockAddrV6 Loopback()
Address for TCP socket, can be either IPv4 or IPv6.
Definition: socket.h:249
bool IsV6() const
Definition: socket.h:263
auto const & V6() const
Definition: socket.h:266
bool IsV4() const
Definition: socket.h:262
auto Domain() const
Definition: socket.h:260
SockAddress(SockAddrV4 const &addr)
Definition: socket.h:258
auto const & V4() const
Definition: socket.h:265
SockAddress(SockAddrV6 const &addr)
Definition: socket.h:257
TCP socket for simple communication.
Definition: socket.h:272
Result GetSockError() const
Get the last socket error code, if any.
Definition: socket.h:342
Result Recv(std::string *p_str)
Receive string, format is matched with the Python socket wrapper in RABIT.
Result RecvTimeout(std::chrono::seconds timeout)
Definition: socket.h:396
HandleT const & Handle() const
Return the native socket file descriptor.
Definition: socket.h:515
TCPSocket & operator=(TCPSocket const &that)=delete
auto SendAll(void const *buf, std::size_t len)
Send data; barring errors, all of the data should be sent.
Definition: socket.h:607
Result BindHost(std::int32_t *p_out)
Bind socket to INADDR_ANY, return the port selected by the OS.
Definition: socket.h:528
static TCPSocket * CreatePtr(SockDomain domain)
Definition: socket.h:736
Result SetNoDelay()
Definition: socket.h:439
Result Shutdown()
Call shutdown on the socket.
Definition: socket.h:698
system::SocketT HandleT
Definition: socket.h:274
Result Bind(StringView ip, std::int32_t *port)
Definition: socket.h:582
TCPSocket & operator=(TCPSocket &&that) noexcept(true)
Definition: socket.h:508
auto Port() const
Definition: socket.h:562
Result SetKeepAlive()
Definition: socket.h:429
Result Accept(TCPSocket *out, SockAddress *addr)
Definition: socket.h:460
TCPSocket(TCPSocket const &that)=delete
std::size_t Send(StringView str)
Send string, format is matched with the Python socket wrapper in RABIT.
auto Domain() const -> SockDomain
Return the socket domain.
Definition: socket.h:294
Result Listen(std::int32_t backlog=16)
Listen to incoming requests. Should be called after bind.
Definition: socket.h:519
static TCPSocket Create(SockDomain domain)
Create a TCP socket on specified domain.
Definition: socket.h:718
auto Recv(void *buf, std::size_t len, std::int32_t flags=0)
receive data using the socket
Definition: socket.h:663
auto RecvAll(void *buf, std::size_t len)
Receive data; barring errors, all of the data should be received.
Definition: socket.h:626
bool NonBlocking() const
Definition: socket.h:395
bool IsClosed() const
Definition: socket.h:339
Result NonBlocking(bool non_block)
Definition: socket.h:370
TCPSocket Accept()
Accept new connection, returns a new TCP socket for the new connection.
Definition: socket.h:452
~TCPSocket()
Definition: socket.h:496
auto Send(const void *buf_, std::size_t len, std::int32_t flags=0)
Send data using the socket.
Definition: socket.h:652
bool BadSocket() const
Check whether anything bad has happened to the socket.
Definition: socket.h:358
TCPSocket(TCPSocket &&that) noexcept(true)
Definition: socket.h:506
Result SetBufSize(std::int32_t n_bytes)
Definition: socket.h:415
Result Close()
Close the socket, called automatically in destructor if the socket is not closed.
Definition: socket.h:678
void swap(xgboost::IntrusivePtr< T > &x, xgboost::IntrusivePtr< T > &y) noexcept
Definition: intrusive_ptr.h:209
Result Connect(xgboost::StringView host, std::int32_t port, std::int32_t retry, std::chrono::seconds timeout, xgboost::collective::TCPSocket *out_conn)
Connect to remote address, returns the error code if failed.
Result GetHostName(std::string *p_out)
Get the local host name.
SockAddress MakeSockAddress(StringView host, in_port_t port)
Parse host address and return a SockAddress instance. Supports IPv4 and IPv6 host.
SockDomain
Definition: socket.h:189
auto Fail(std::string msg, char const *file=__builtin_FILE(), std::int32_t line=__builtin_LINE())
Return failure.
Definition: result.h:125
void SafeColl(Result const &rc)
Result INetNToP(H const &host, std::string *p_out)
inet_ntop
Definition: socket.h:780
auto Success() noexcept(true)
Return success.
Definition: result.h:121
bool ErrorWouldBlock(std::int32_t errsv) noexcept(true)
Definition: socket.h:142
auto ThrowAtError(StringView fn_name, std::int32_t errsv=LastError())
Definition: socket.h:98
void SocketStartup()
Definition: socket.h:155
std::int32_t CloseSocket(SocketT fd)
Definition: socket.h:119
bool LastErrorWouldBlock()
Definition: socket.h:150
std::int32_t LastError()
Definition: socket.h:75
void SocketFinalize()
Definition: socket.h:168
std::int32_t ShutdownSocket(SocketT fd)
Definition: socket.h:127
collective::Result FailWithCode(std::string msg)
Definition: socket.h:84
int SocketT
Definition: socket.h:107
Core data structure for multi-target trees.
Definition: base.h:87
#define __builtin_LINE()
Definition: result.h:57
#define __builtin_FILE()
Definition: result.h:56
#define xgboost_CHECK_SYS_CALL(exp, expected)
Definition: socket.h:111
Definition: string_view.h:16
An error type that's easier to handle than throwing dmlc exception. We can record and propagate the s...
Definition: result.h:68