xgboost
socket.h
Go to the documentation of this file.
1 
4 #pragma once
5 
6 #include <cerrno> // errno, EINTR, EBADF
7 #include <climits> // HOST_NAME_MAX
8 #include <cstddef> // std::size_t
9 #include <cstdint> // std::int32_t, std::uint16_t
10 #include <cstring> // memset
11 #include <string> // std::string
12 #include <system_error> // std::error_code, std::system_category
13 #include <utility> // std::swap
14 
15 #if defined(__linux__)
16 #include <sys/ioctl.h> // for TIOCOUTQ, FIONREAD
17 #endif // defined(__linux__)
18 
19 #if defined(_WIN32)
20 // Guard the include.
21 #include <xgboost/windefs.h>
22 // Socket API
23 #include <winsock2.h>
24 #include <ws2tcpip.h>
25 
26 using in_port_t = std::uint16_t;
27 
28 #ifdef _MSC_VER
29 #pragma comment(lib, "Ws2_32.lib")
30 #endif // _MSC_VER
31 
32 #if !defined(xgboost_IS_MINGW)
33 using ssize_t = int;
34 #endif // !xgboost_IS_MINGW()
35 
36 #else // UNIX
37 
38 #include <arpa/inet.h> // inet_ntop
39 #include <fcntl.h> // fcntl, F_GETFL, O_NONBLOCK
40 #include <netinet/in.h> // sockaddr_in6, sockaddr_in, in_port_t, INET6_ADDRSTRLEN, INET_ADDRSTRLEN
41 #include <netinet/in.h> // IPPROTO_TCP
42 #include <netinet/tcp.h> // TCP_NODELAY
43 #include <sys/socket.h> // socket, SOL_SOCKET, SO_ERROR, MSG_WAITALL, recv, send, AF_INET6, AF_INET
44 #include <unistd.h> // close
45 
46 #if defined(__sun) || defined(sun)
47 #include <sys/sockio.h>
48 #endif // defined(__sun) || defined(sun)
49 
50 #endif // defined(_WIN32)
51 
52 #include "xgboost/base.h" // XGBOOST_EXPECT
53 #include "xgboost/collective/result.h" // for Result
54 #include "xgboost/logging.h" // LOG
55 #include "xgboost/string_view.h" // StringView
56 
57 #if !defined(HOST_NAME_MAX)
58 #define HOST_NAME_MAX 256 // macos
59 #endif
60 
61 namespace xgboost {
62 
#if defined(xgboost_IS_MINGW)
/**
 * @brief Fatal error used by MinGW builds, where distributed training is unavailable.
 *
 * See the dummy implementation of `poll` in rabit for more info.
 */
inline void MingWError() {
  LOG(FATAL) << "Distributed training on mingw is not supported.";
}
#endif  // defined(xgboost_IS_MINGW)
67 
68 namespace system {
/**
 * @brief Error code left by the most recent failed socket/system call.
 *
 * Windows: `WSAGetLastError()`. Elsewhere: the current value of `errno`.
 */
inline std::int32_t LastError() {
#if !defined(_WIN32)
  return static_cast<std::int32_t>(errno);
#else
  return WSAGetLastError();
#endif
}
77 
78 [[nodiscard]] inline collective::Result FailWithCode(std::string msg) {
79  return collective::Fail(std::move(msg), std::error_code{LastError(), std::system_category()});
80 }
81 
82 #if defined(__GLIBC__)
83 inline auto ThrowAtError(StringView fn_name, std::int32_t errsv = LastError(),
84  std::int32_t line = __builtin_LINE(),
85  char const *file = __builtin_FILE()) {
86  auto err = std::error_code{errsv, std::system_category()};
87  LOG(FATAL) << "\n"
88  << file << "(" << line << "): Failed to call `" << fn_name << "`: " << err.message()
89  << std::endl;
90 }
91 #else
92 inline auto ThrowAtError(StringView fn_name, std::int32_t errsv = LastError()) {
93  auto err = std::error_code{errsv, std::system_category()};
94  LOG(FATAL) << "Failed to call `" << fn_name << "`: " << err.message() << std::endl;
95 }
96 #endif // defined(__GLIBC__)
97 
// Native socket handle type; INVALID_SOCKET is the handle value meaning "no socket".
#if !defined(_WIN32)
using SocketT = int;
#define INVALID_SOCKET -1
#else
using SocketT = SOCKET;  // INVALID_SOCKET is provided by the Winsock headers.
#endif  // !defined(_WIN32)

// Evaluate `exp`; when it differs from `expected`, report the stringified expression
// through ThrowAtError. The do/while wrapper makes the macro a single statement.
#ifndef xgboost_CHECK_SYS_CALL
#define xgboost_CHECK_SYS_CALL(exp, expected)         \
  do {                                                \
    if (XGBOOST_EXPECT((exp) != (expected), false)) { \
      ::xgboost::system::ThrowAtError(#exp);          \
    }                                                 \
  } while (false)
#endif  // xgboost_CHECK_SYS_CALL
113 
114 inline std::int32_t CloseSocket(SocketT fd) {
115 #if defined(_WIN32)
116  return closesocket(fd);
117 #else
118  return close(fd);
119 #endif
120 }
121 
122 inline std::int32_t ShutdownSocket(SocketT fd) {
123 #if defined(_WIN32)
124  auto rc = shutdown(fd, SD_BOTH);
125  if (rc != 0 && LastError() == WSANOTINITIALISED) {
126  return 0;
127  }
128 #else
129  auto rc = shutdown(fd, SHUT_RDWR);
130  if (rc != 0 && LastError() == ENOTCONN) {
131  return 0;
132  }
133 #endif
134  return rc;
135 }
136 
/**
 * @brief Whether `errsv` signals "would block / still in progress" rather than a real error.
 *
 * POSIX: EAGAIN, EWOULDBLOCK or EINPROGRESS. Windows: WSAEWOULDBLOCK.
 */
inline bool ErrorWouldBlock(std::int32_t errsv) noexcept(true) {
#ifdef _WIN32
  bool const would_block = (errsv == WSAEWOULDBLOCK);
  return would_block;
#else
  bool const would_block = (errsv == EAGAIN) || (errsv == EWOULDBLOCK);
  bool const in_progress = (errsv == EINPROGRESS);
  return would_block || in_progress;
#endif  // _WIN32
}
144 
145 inline bool LastErrorWouldBlock() {
146  int errsv = LastError();
147  return ErrorWouldBlock(errsv);
148 }
149 
/**
 * @brief Initialise the socket library: Winsock 2.2 on Windows, a no-op elsewhere.
 *
 * `WSAStartup` returns 0 on success and the error code itself on failure; it never
 * returns -1. `WSAGetLastError` cannot be used when startup fails, so the returned
 * code is what gets reported.
 */
inline void SocketStartup() {
#if defined(_WIN32)
  WSADATA wsa_data;
  auto const rc = WSAStartup(MAKEWORD(2, 2), &wsa_data);
  if (rc != 0) {
    ThrowAtError("WSAStartup", rc);
  }
  if (LOBYTE(wsa_data.wVersion) != 2 || HIBYTE(wsa_data.wVersion) != 2) {
    WSACleanup();
    LOG(FATAL) << "Could not find a usable version of Winsock.dll";
  }
#endif  // defined(_WIN32)
}
162 
/** @brief Tear down the socket library (`WSACleanup` on Windows); a no-op elsewhere. */
inline void SocketFinalize() {
#if !defined(_WIN32)
  // Non-Windows platforms need no teardown here.
#else
  WSACleanup();
#endif  // !defined(_WIN32)
}
168 
169 #if defined(_WIN32) && defined(xgboost_IS_MINGW)
170 // dummy definition for old mysys32.
171 inline const char *inet_ntop(int, const void *, char *, socklen_t) { // NOLINT
172  MingWError();
173  return nullptr;
174 }
175 #else
176 using ::inet_ntop;
177 #endif // defined(_WIN32) && defined(xgboost_IS_MINGW)
178 
179 } // namespace system
180 
181 namespace collective {
182 class SockAddress;
183 
184 enum class SockDomain : std::int32_t { kV4 = AF_INET, kV6 = AF_INET6 };
185 
190 SockAddress MakeSockAddress(StringView host, in_port_t port);
191 
192 class SockAddrV6 {
193  sockaddr_in6 addr_;
194 
195  public:
196  explicit SockAddrV6(sockaddr_in6 addr) : addr_{addr} {}
197  SockAddrV6() { std::memset(&addr_, '\0', sizeof(addr_)); }
198 
201 
202  in_port_t Port() const { return ntohs(addr_.sin6_port); }
203 
204  std::string Addr() const {
205  char buf[INET6_ADDRSTRLEN];
206  auto const *s = system::inet_ntop(static_cast<std::int32_t>(SockDomain::kV6), &addr_.sin6_addr,
207  buf, INET6_ADDRSTRLEN);
208  if (s == nullptr) {
209  system::ThrowAtError("inet_ntop");
210  }
211  return {buf};
212  }
213  sockaddr_in6 const &Handle() const { return addr_; }
214 };
215 
216 class SockAddrV4 {
217  private:
218  sockaddr_in addr_;
219 
220  public:
221  explicit SockAddrV4(sockaddr_in addr) : addr_{addr} {}
222  SockAddrV4() { std::memset(&addr_, '\0', sizeof(addr_)); }
223 
226 
227  [[nodiscard]] in_port_t Port() const { return ntohs(addr_.sin_port); }
228 
229  [[nodiscard]] std::string Addr() const {
230  char buf[INET_ADDRSTRLEN];
231  auto const *s = system::inet_ntop(static_cast<std::int32_t>(SockDomain::kV4), &addr_.sin_addr,
232  buf, INET_ADDRSTRLEN);
233  if (s == nullptr) {
234  system::ThrowAtError("inet_ntop");
235  }
236  return {buf};
237  }
238  [[nodiscard]] sockaddr_in const &Handle() const { return addr_; }
239 };
240 
244 class SockAddress {
245  private:
246  SockAddrV6 v6_;
247  SockAddrV4 v4_;
248  SockDomain domain_{SockDomain::kV4};
249 
250  public:
251  SockAddress() = default;
252  explicit SockAddress(SockAddrV6 const &addr) : v6_{addr}, domain_{SockDomain::kV6} {}
253  explicit SockAddress(SockAddrV4 const &addr) : v4_{addr} {}
254 
255  [[nodiscard]] auto Domain() const { return domain_; }
256 
257  [[nodiscard]] bool IsV4() const { return Domain() == SockDomain::kV4; }
258  [[nodiscard]] bool IsV6() const { return !IsV4(); }
259 
260  [[nodiscard]] auto const &V4() const { return v4_; }
261  [[nodiscard]] auto const &V6() const { return v6_; }
262 };
263 
267 class TCPSocket {
268  public:
270 
271  private:
272  HandleT handle_{InvalidSocket()};
273  bool non_blocking_{false};
274  // There's reliable no way to extract domain from a socket without first binding that
275  // socket on macos.
276 #if defined(__APPLE__)
277  SockDomain domain_{SockDomain::kV4};
278 #endif
279 
280  constexpr static HandleT InvalidSocket() { return INVALID_SOCKET; }
281 
282  explicit TCPSocket(HandleT newfd) : handle_{newfd} {}
283 
284  public:
285  TCPSocket() = default;
289  [[nodiscard]] auto Domain() const -> SockDomain {
290  auto ret_iafamily = [](std::int32_t domain) {
291  switch (domain) {
292  case AF_INET:
293  return SockDomain::kV4;
294  case AF_INET6:
295  return SockDomain::kV6;
296  default: {
297  LOG(FATAL) << "Unknown IA family.";
298  }
299  }
300  return SockDomain::kV4;
301  };
302 
303 #if defined(_WIN32)
304  WSAPROTOCOL_INFOA info;
305  socklen_t len = sizeof(info);
307  getsockopt(handle_, SOL_SOCKET, SO_PROTOCOL_INFO, reinterpret_cast<char *>(&info), &len),
308  0);
309  return ret_iafamily(info.iAddressFamily);
310 #elif defined(__APPLE__)
311  return domain_;
312 #elif defined(__unix__)
313 #ifndef __PASE__
314  std::int32_t domain;
315  socklen_t len = sizeof(domain);
317  getsockopt(this->Handle(), SOL_SOCKET, SO_DOMAIN, reinterpret_cast<char *>(&domain), &len),
318  0);
319  return ret_iafamily(domain);
320 #else
321  struct sockaddr sa;
322  socklen_t sizeofsa = sizeof(sa);
323  xgboost_CHECK_SYS_CALL(getsockname(handle_, &sa, &sizeofsa), 0);
324  if (sizeofsa < sizeof(uchar_t) * 2) {
325  return ret_iafamily(AF_INET);
326  }
327  return ret_iafamily(sa.sa_family);
328 #endif // __PASE__
329 #else
330  LOG(FATAL) << "Unknown platform.";
331  return ret_iafamily(AF_INET);
332 #endif // platforms
333  }
334 
335  [[nodiscard]] bool IsClosed() const { return handle_ == InvalidSocket(); }
336 
338  [[nodiscard]] Result GetSockError() const {
339  std::int32_t optval = 0;
340  socklen_t len = sizeof(optval);
341  auto ret = getsockopt(handle_, SOL_SOCKET, SO_ERROR, reinterpret_cast<char *>(&optval), &len);
342  if (ret != 0) {
343  auto errc = std::error_code{system::LastError(), std::system_category()};
344  return Fail("Failed to retrieve socket error.", std::move(errc));
345  }
346  if (optval != 0) {
347  auto errc = std::error_code{optval, std::system_category()};
348  return Fail("Socket error.", std::move(errc));
349  }
350  return Success();
351  }
352 
354  [[nodiscard]] bool BadSocket() const {
355  if (IsClosed()) {
356  return true;
357  }
358  auto err = GetSockError();
359  if (err.Code() == std::error_code{EBADF, std::system_category()} || // NOLINT
360  err.Code() == std::error_code{EINTR, std::system_category()}) { // NOLINT
361  return true;
362  }
363  return false;
364  }
365 
366  [[nodiscard]] Result NonBlocking(bool non_block) {
367 #if defined(_WIN32)
368  u_long mode = non_block ? 1 : 0;
369  if (ioctlsocket(handle_, FIONBIO, &mode) != NO_ERROR) {
370  return system::FailWithCode("Failed to set socket to non-blocking.");
371  }
372 #else
373  std::int32_t flag = fcntl(handle_, F_GETFL, 0);
374  auto rc = flag;
375  if (rc == -1) {
376  return system::FailWithCode("Failed to get socket flag.");
377  }
378  if (non_block) {
379  flag |= O_NONBLOCK;
380  } else {
381  flag &= ~O_NONBLOCK;
382  }
383  rc = fcntl(handle_, F_SETFL, flag);
384  if (rc == -1) {
385  return system::FailWithCode("Failed to set socket to non-blocking.");
386  }
387 #endif // _WIN32
388  non_blocking_ = non_block;
389  return Success();
390  }
391  [[nodiscard]] bool NonBlocking() const { return non_blocking_; }
392  [[nodiscard]] Result RecvTimeout(std::chrono::seconds timeout) {
393  // https://stackoverflow.com/questions/2876024/linux-is-there-a-read-or-recv-from-socket-with-timeout
394 #if defined(_WIN32)
395  DWORD tv = timeout.count() * 1000;
396  auto rc =
397  setsockopt(Handle(), SOL_SOCKET, SO_RCVTIMEO, reinterpret_cast<char *>(&tv), sizeof(tv));
398 #else
399  struct timeval tv;
400  tv.tv_sec = timeout.count();
401  tv.tv_usec = 0;
402  auto rc = setsockopt(Handle(), SOL_SOCKET, SO_RCVTIMEO, reinterpret_cast<char const *>(&tv),
403  sizeof(tv));
404 #endif
405  if (rc != 0) {
406  return system::FailWithCode("Failed to set timeout on recv.");
407  }
408  return Success();
409  }
410 
411  [[nodiscard]] Result SetBufSize(std::int32_t n_bytes) {
412  auto rc = setsockopt(this->Handle(), SOL_SOCKET, SO_SNDBUF, reinterpret_cast<char *>(&n_bytes),
413  sizeof(n_bytes));
414  if (rc != 0) {
415  return system::FailWithCode("Failed to set send buffer size.");
416  }
417  rc = setsockopt(this->Handle(), SOL_SOCKET, SO_RCVBUF, reinterpret_cast<char *>(&n_bytes),
418  sizeof(n_bytes));
419  if (rc != 0) {
420  return system::FailWithCode("Failed to set recv buffer size.");
421  }
422  return Success();
423  }
424 
425  [[nodiscard]] Result SendBufSize(std::int32_t *n_bytes) {
426  socklen_t optlen;
427  auto rc = getsockopt(this->Handle(), SOL_SOCKET, SO_SNDBUF, reinterpret_cast<char *>(n_bytes),
428  &optlen);
429  if (rc != 0 || optlen != sizeof(std::int32_t)) {
430  return system::FailWithCode("getsockopt");
431  }
432  return Success();
433  }
434  [[nodiscard]] Result RecvBufSize(std::int32_t *n_bytes) {
435  socklen_t optlen;
436  auto rc = getsockopt(this->Handle(), SOL_SOCKET, SO_RCVBUF, reinterpret_cast<char *>(n_bytes),
437  &optlen);
438  if (rc != 0 || optlen != sizeof(std::int32_t)) {
439  return system::FailWithCode("getsockopt");
440  }
441  return Success();
442  }
443 #if defined(__linux__)
444  [[nodiscard]] Result PendingSendSize(std::int32_t *n_bytes) const {
445  return ioctl(this->Handle(), TIOCOUTQ, n_bytes) == 0 ? Success()
446  : system::FailWithCode("ioctl");
447  }
448  [[nodiscard]] Result PendingRecvSize(std::int32_t *n_bytes) const {
449  return ioctl(this->Handle(), FIONREAD, n_bytes) == 0 ? Success()
450  : system::FailWithCode("ioctl");
451  }
452 #endif // defined(__linux__)
453 
454  [[nodiscard]] Result SetKeepAlive() {
455  std::int32_t keepalive = 1;
456  auto rc = setsockopt(handle_, SOL_SOCKET, SO_KEEPALIVE, reinterpret_cast<char *>(&keepalive),
457  sizeof(keepalive));
458  if (rc != 0) {
459  return system::FailWithCode("Failed to set TCP keeaplive.");
460  }
461  return Success();
462  }
463 
464  [[nodiscard]] Result SetNoDelay(std::int32_t no_delay = 1) {
465  auto rc = setsockopt(handle_, IPPROTO_TCP, TCP_NODELAY, reinterpret_cast<char *>(&no_delay),
466  sizeof(no_delay));
467  if (rc != 0) {
468  return system::FailWithCode("Failed to set TCP no delay.");
469  }
470  return Success();
471  }
472 
477  SockAddress addr;
478  TCPSocket newsock;
479  auto rc = this->Accept(&newsock, &addr);
480  SafeColl(rc);
481  return newsock;
482  }
483 
484  [[nodiscard]] Result Accept(TCPSocket *out, SockAddress *addr) {
485 #if defined(_WIN32)
486  auto interrupt = WSAEINTR;
487 #else
488  auto interrupt = EINTR;
489 #endif
490  if (this->Domain() == SockDomain::kV4) {
491  struct sockaddr_in caddr;
492  socklen_t caddr_len = sizeof(caddr);
493  HandleT newfd = accept(Handle(), reinterpret_cast<sockaddr *>(&caddr), &caddr_len);
494  if (newfd == InvalidSocket() && system::LastError() != interrupt) {
495  return system::FailWithCode("Failed to accept.");
496  }
497  *addr = SockAddress{SockAddrV4{caddr}};
498  *out = TCPSocket{newfd};
499  } else {
500  struct sockaddr_in6 caddr;
501  socklen_t caddr_len = sizeof(caddr);
502  HandleT newfd = accept(Handle(), reinterpret_cast<sockaddr *>(&caddr), &caddr_len);
503  if (newfd == InvalidSocket() && system::LastError() != interrupt) {
504  return system::FailWithCode("Failed to accept.");
505  }
506  *addr = SockAddress{SockAddrV6{caddr}};
507  *out = TCPSocket{newfd};
508  }
509  // On MacOS, this is automatically set to async socket if the parent socket is async
510  // We make sure all socket are blocking by default.
511  //
512  // On Windows, a closed socket is returned during shutdown. We guard against it when
513  // setting non-blocking.
514  if (!out->IsClosed()) {
515  return out->NonBlocking(false);
516  }
517  return Success();
518  }
519 
521  if (!IsClosed()) {
522  auto rc = this->Close();
523  if (!rc.OK()) {
524  LOG(WARNING) << rc.Report();
525  }
526  }
527  }
528 
529  TCPSocket(TCPSocket const &that) = delete;
530  TCPSocket(TCPSocket &&that) noexcept(true) { std::swap(this->handle_, that.handle_); }
531  TCPSocket &operator=(TCPSocket const &that) = delete;
532  TCPSocket &operator=(TCPSocket &&that) noexcept(true) {
533  std::swap(this->handle_, that.handle_);
534  return *this;
535  }
539  [[nodiscard]] HandleT const &Handle() const { return handle_; }
545  [[nodiscard]] Result Listen(std::int32_t backlog = 256);
549  [[nodiscard]] Result BindHost(std::int32_t* p_out) {
550  // Use int32 instead of in_port_t for consistency. We take port as parameter from
551  // users using other languages, the port is usually stored and passed around as int.
552  if (Domain() == SockDomain::kV6) {
553  auto addr = SockAddrV6::InaddrAny();
554  auto handle = reinterpret_cast<sockaddr const *>(&addr.Handle());
555  if (bind(handle_, handle, sizeof(std::remove_reference_t<decltype(addr.Handle())>)) != 0) {
556  return system::FailWithCode("bind failed.");
557  }
558 
559  sockaddr_in6 res_addr;
560  socklen_t addrlen = sizeof(res_addr);
561  if (getsockname(handle_, reinterpret_cast<sockaddr *>(&res_addr), &addrlen) != 0) {
562  return system::FailWithCode("getsockname failed.");
563  }
564  *p_out = ntohs(res_addr.sin6_port);
565  } else {
566  auto addr = SockAddrV4::InaddrAny();
567  auto handle = reinterpret_cast<sockaddr const *>(&addr.Handle());
568  if (bind(handle_, handle, sizeof(std::remove_reference_t<decltype(addr.Handle())>)) != 0) {
569  return system::FailWithCode("bind failed.");
570  }
571 
572  sockaddr_in res_addr;
573  socklen_t addrlen = sizeof(res_addr);
574  if (getsockname(handle_, reinterpret_cast<sockaddr *>(&res_addr), &addrlen) != 0) {
575  return system::FailWithCode("getsockname failed.");
576  }
577  *p_out = ntohs(res_addr.sin_port);
578  }
579 
580  return Success();
581  }
582 
583  [[nodiscard]] auto Port() const {
584  if (this->Domain() == SockDomain::kV4) {
585  sockaddr_in res_addr;
586  socklen_t addrlen = sizeof(res_addr);
587  auto code = getsockname(handle_, reinterpret_cast<sockaddr *>(&res_addr), &addrlen);
588  if (code != 0) {
589  return std::make_pair(system::FailWithCode("getsockname"), std::int32_t{0});
590  }
591  return std::make_pair(Success(), std::int32_t{ntohs(res_addr.sin_port)});
592  } else {
593  sockaddr_in6 res_addr;
594  socklen_t addrlen = sizeof(res_addr);
595  auto code = getsockname(handle_, reinterpret_cast<sockaddr *>(&res_addr), &addrlen);
596  if (code != 0) {
597  return std::make_pair(system::FailWithCode("getsockname"), std::int32_t{0});
598  }
599  return std::make_pair(Success(), std::int32_t{ntohs(res_addr.sin6_port)});
600  }
601  }
608  [[nodiscard]] Result Bind(StringView ip, std::int32_t *port) {
609  // bind socket handle_ to ip
610  auto addr = MakeSockAddress(ip, *port);
611  std::int32_t errc{0};
612  if (addr.IsV4()) {
613  auto handle = reinterpret_cast<sockaddr const *>(&addr.V4().Handle());
614  errc = bind(handle_, handle, sizeof(std::remove_reference_t<decltype(addr.V4().Handle())>));
615  } else {
616  auto handle = reinterpret_cast<sockaddr const *>(&addr.V6().Handle());
617  errc = bind(handle_, handle, sizeof(std::remove_reference_t<decltype(addr.V6().Handle())>));
618  }
619  if (errc != 0) {
620  return system::FailWithCode("Failed to bind socket.");
621  }
622  auto [rc, new_port] = this->Port();
623  if (!rc.OK()) {
624  return std::move(rc);
625  }
626  if (*port == 0) {
627  *port = new_port;
628  return Success();
629  }
630  if (*port != new_port) {
631  return Fail("Got an invalid port from bind.");
632  }
633  return Success();
634  }
635 
639  [[nodiscard]] Result SendAll(void const *buf, std::size_t len, std::size_t *n_sent) {
640  char const *_buf = reinterpret_cast<const char *>(buf);
641  std::size_t &ndone = *n_sent;
642  ndone = 0;
643  while (ndone < len) {
644  ssize_t ret = send(handle_, _buf, len - ndone, 0);
645  if (ret == -1) {
647  return Success();
648  }
649  return system::FailWithCode("send");
650  }
651  _buf += ret;
652  ndone += ret;
653  }
654  return Success();
655  }
659  [[nodiscard]] Result RecvAll(void *buf, std::size_t len, std::size_t *n_recv) {
660  char *_buf = reinterpret_cast<char *>(buf);
661  std::size_t &ndone = *n_recv;
662  ndone = 0;
663  while (ndone < len) {
664  ssize_t ret = recv(handle_, _buf, len - ndone, MSG_WAITALL);
665  if (ret == -1) {
667  return Success();
668  }
669  return system::FailWithCode("recv");
670  }
671  if (ret == 0) {
672  return Success();
673  }
674  _buf += ret;
675  ndone += ret;
676  }
677  return Success();
678  }
686  auto Send(const void *buf_, std::size_t len, std::int32_t flags = 0) {
687  const char *buf = reinterpret_cast<const char *>(buf_);
688  return send(handle_, buf, len, flags);
689  }
697  auto Recv(void *buf, std::size_t len, std::int32_t flags = 0) {
698  char *_buf = static_cast<char *>(buf);
699  // See https://github.com/llvm/llvm-project/issues/104241 for skipped tidy analysis
700  // NOLINTBEGIN(clang-analyzer-unix.BlockInCriticalSection)
701  return recv(handle_, _buf, len, flags);
702  // NOLINTEND(clang-analyzer-unix.BlockInCriticalSection)
703  }
707  std::size_t Send(StringView str);
711  [[nodiscard]] Result Recv(std::string *p_str);
715  [[nodiscard]] Result Close() {
716  if (InvalidSocket() != handle_) {
717  auto rc = system::CloseSocket(handle_);
718 #if defined(_WIN32)
719  // it's possible that we close TCP sockets after finalizing WSA due to detached thread.
720  if (rc != 0 && system::LastError() != WSANOTINITIALISED) {
721  return system::FailWithCode("Failed to close the socket.");
722  }
723 #else
724  if (rc != 0) {
725  return system::FailWithCode("Failed to close the socket.");
726  }
727 #endif
728  handle_ = InvalidSocket();
729  }
730  return Success();
731  }
735  [[nodiscard]] Result Shutdown() {
736  if (this->IsClosed()) {
737  return Success();
738  }
739  auto rc = system::ShutdownSocket(this->Handle());
740 #if defined(_WIN32)
741  // Windows cannot shutdown a socket if it's not connected.
742  if (rc == -1 && system::LastError() == WSAENOTCONN) {
743  return Success();
744  }
745 #endif
746  if (rc != 0) {
747  return system::FailWithCode("Failed to shutdown socket.");
748  }
749  return Success();
750  }
751 
755  static TCPSocket Create(SockDomain domain) {
756 #if defined(xgboost_IS_MINGW)
757  MingWError();
758  return {};
759 #else
760  auto fd = socket(static_cast<std::int32_t>(domain), SOCK_STREAM, 0);
761  if (fd == InvalidSocket()) {
762  system::ThrowAtError("socket");
763  }
764 
765  TCPSocket socket{fd};
766 #if defined(__APPLE__)
767  socket.domain_ = domain;
768 #endif // defined(__APPLE__)
769  return socket;
770 #endif // defined(xgboost_IS_MINGW)
771  }
772 
773  static TCPSocket *CreatePtr(SockDomain domain) {
774 #if defined(xgboost_IS_MINGW)
775  MingWError();
776  return nullptr;
777 #else
778  auto fd = socket(static_cast<std::int32_t>(domain), SOCK_STREAM, 0);
779  if (fd == InvalidSocket()) {
780  system::ThrowAtError("socket");
781  }
782  auto socket = new TCPSocket{fd};
783 
784 #if defined(__APPLE__)
785  socket->domain_ = domain;
786 #endif // defined(__APPLE__)
787  return socket;
788 #endif // defined(xgboost_IS_MINGW)
789  }
790 };
791 
804 [[nodiscard]] Result Connect(xgboost::StringView host, std::int32_t port, std::int32_t retry,
805  std::chrono::seconds timeout,
807 
811 [[nodiscard]] Result GetHostName(std::string *p_out);
812 
816 template <typename H>
817 Result INetNToP(H const &host, std::string *p_out) {
818  std::string &ip = *p_out;
819  switch (host->h_addrtype) {
820  case AF_INET: {
821  auto addr = reinterpret_cast<struct in_addr *>(host->h_addr_list[0]);
822  char str[INET_ADDRSTRLEN];
823  inet_ntop(AF_INET, addr, str, INET_ADDRSTRLEN);
824  ip = str;
825  break;
826  }
827  case AF_INET6: {
828  auto addr = reinterpret_cast<struct in6_addr *>(host->h_addr_list[0]);
829  char str[INET6_ADDRSTRLEN];
830  inet_ntop(AF_INET6, addr, str, INET6_ADDRSTRLEN);
831  ip = str;
832  break;
833  }
834  default: {
835  return Fail("Invalid address type.");
836  }
837  }
838  return Success();
839 }
840 } // namespace collective
841 } // namespace xgboost
842 
843 #undef xgboost_CHECK_SYS_CALL
Defines configuration macros and basic types for xgboost.
Definition: socket.h:216
SockAddrV4(sockaddr_in addr)
Definition: socket.h:221
static SockAddrV4 InaddrAny()
in_port_t Port() const
Definition: socket.h:227
sockaddr_in const & Handle() const
Definition: socket.h:238
std::string Addr() const
Definition: socket.h:229
static SockAddrV4 Loopback()
SockAddrV4()
Definition: socket.h:222
Definition: socket.h:192
static SockAddrV6 InaddrAny()
SockAddrV6()
Definition: socket.h:197
sockaddr_in6 const & Handle() const
Definition: socket.h:213
in_port_t Port() const
Definition: socket.h:202
SockAddrV6(sockaddr_in6 addr)
Definition: socket.h:196
std::string Addr() const
Definition: socket.h:204
static SockAddrV6 Loopback()
Address for TCP socket, can be either IPv4 or IPv6.
Definition: socket.h:244
bool IsV6() const
Definition: socket.h:258
auto const & V6() const
Definition: socket.h:261
bool IsV4() const
Definition: socket.h:257
auto Domain() const
Definition: socket.h:255
SockAddress(SockAddrV4 const &addr)
Definition: socket.h:253
auto const & V4() const
Definition: socket.h:260
SockAddress(SockAddrV6 const &addr)
Definition: socket.h:252
TCP socket for simple communication.
Definition: socket.h:267
Result GetSockError() const
get last error code if any
Definition: socket.h:338
Result Recv(std::string *p_str)
Receive string, format is matched with the Python socket wrapper in RABIT.
Result RecvTimeout(std::chrono::seconds timeout)
Definition: socket.h:392
HandleT const & Handle() const
Return the native socket file descriptor.
Definition: socket.h:539
Result SetNoDelay(std::int32_t no_delay=1)
Definition: socket.h:464
TCPSocket & operator=(TCPSocket const &that)=delete
Result BindHost(std::int32_t *p_out)
Bind socket to INADDR_ANY, return the port selected by the OS.
Definition: socket.h:549
Result Listen(std::int32_t backlog=256)
Listen to incoming requests. Should be called after bind.
static TCPSocket * CreatePtr(SockDomain domain)
Definition: socket.h:773
Result Shutdown()
Call shutdown on the socket.
Definition: socket.h:735
system::SocketT HandleT
Definition: socket.h:269
Result Bind(StringView ip, std::int32_t *port)
Bind the socket to the address.
Definition: socket.h:608
TCPSocket & operator=(TCPSocket &&that) noexcept(true)
Definition: socket.h:532
auto Port() const
Definition: socket.h:583
Result SetKeepAlive()
Definition: socket.h:454
Result Accept(TCPSocket *out, SockAddress *addr)
Definition: socket.h:484
TCPSocket(TCPSocket const &that)=delete
Result RecvBufSize(std::int32_t *n_bytes)
Definition: socket.h:434
Result SendBufSize(std::int32_t *n_bytes)
Definition: socket.h:425
std::size_t Send(StringView str)
Send string, format is matched with the Python socket wrapper in RABIT.
auto Domain() const -> SockDomain
Return the socket domain.
Definition: socket.h:289
static TCPSocket Create(SockDomain domain)
Create a TCP socket on specified domain.
Definition: socket.h:755
Result RecvAll(void *buf, std::size_t len, std::size_t *n_recv)
Receive data; if no error is reported, all requested data has been received.
Definition: socket.h:659
auto Recv(void *buf, std::size_t len, std::int32_t flags=0)
receive data using the socket
Definition: socket.h:697
bool NonBlocking() const
Definition: socket.h:391
bool IsClosed() const
Definition: socket.h:335
Result NonBlocking(bool non_block)
Definition: socket.h:366
TCPSocket Accept()
Accept new connection, returns a new TCP socket for the new connection.
Definition: socket.h:476
~TCPSocket()
Definition: socket.h:520
auto Send(const void *buf_, std::size_t len, std::int32_t flags=0)
Send data using the socket.
Definition: socket.h:686
Result SendAll(void const *buf, std::size_t len, std::size_t *n_sent)
Send data; if no error is reported, all data has been sent.
Definition: socket.h:639
bool BadSocket() const
check if anything bad happens
Definition: socket.h:354
TCPSocket(TCPSocket &&that) noexcept(true)
Definition: socket.h:530
Result SetBufSize(std::int32_t n_bytes)
Definition: socket.h:411
Result Close()
Close the socket, called automatically in destructor if the socket is not closed.
Definition: socket.h:715
void swap(xgboost::IntrusivePtr< T > &x, xgboost::IntrusivePtr< T > &y) noexcept
Definition: intrusive_ptr.h:209
Result Connect(xgboost::StringView host, std::int32_t port, std::int32_t retry, std::chrono::seconds timeout, xgboost::collective::TCPSocket *out_conn)
Connect to a remote address; returns the error code on failure.
Result GetHostName(std::string *p_out)
Get the local host name.
SockAddress MakeSockAddress(StringView host, in_port_t port)
Parse host address and return a SockAddress instance. Supports IPv4 and IPv6 host.
SockDomain
Definition: socket.h:184
auto Fail(std::string msg, char const *file=__builtin_FILE(), std::int32_t line=__builtin_LINE())
Return failure.
Definition: result.h:124
void SafeColl(Result const &rc)
Result INetNToP(H const &host, std::string *p_out)
inet_ntop
Definition: socket.h:817
auto Success() noexcept(true)
Return success.
Definition: result.h:120
bool ErrorWouldBlock(std::int32_t errsv) noexcept(true)
Definition: socket.h:137
auto ThrowAtError(StringView fn_name, std::int32_t errsv=LastError())
Definition: socket.h:92
void SocketStartup()
Definition: socket.h:150
std::int32_t CloseSocket(SocketT fd)
Definition: socket.h:114
bool LastErrorWouldBlock()
Definition: socket.h:145
std::int32_t LastError()
Definition: socket.h:69
void SocketFinalize()
Definition: socket.h:163
std::int32_t ShutdownSocket(SocketT fd)
Definition: socket.h:122
collective::Result FailWithCode(std::string msg)
Definition: socket.h:78
int SocketT
Definition: socket.h:101
Core data structure for multi-target trees.
Definition: base.h:89
int SOCKET
Definition: poll_utils.h:40
#define __builtin_LINE()
Definition: result.h:57
#define __builtin_FILE()
Definition: result.h:56
#define INVALID_SOCKET
Definition: socket.h:102
#define xgboost_CHECK_SYS_CALL(exp, expected)
Definition: socket.h:106
Definition: string_view.h:16
An error type that's easier to handle than throwing dmlc exception. We can record and propagate the s...
Definition: result.h:67