xgboost
socket.h
Go to the documentation of this file.
1 
4 #pragma once
5 
6 #if !defined(NOMINMAX) && defined(_WIN32)
7 #define NOMINMAX
8 #endif // !defined(NOMINMAX)
9 
10 #include <cerrno> // errno, EINTR, EBADF
11 #include <climits> // HOST_NAME_MAX
12 #include <cstddef> // std::size_t
13 #include <cstdint> // std::int32_t, std::uint16_t
14 #include <cstring> // memset
15 #include <string> // std::string
16 #include <system_error> // std::error_code, std::system_category
17 #include <utility> // std::swap
18 
19 #if defined(__linux__)
20 #include <sys/ioctl.h> // for TIOCOUTQ, FIONREAD
21 #endif // defined(__linux__)
22 
23 #if !defined(xgboost_IS_MINGW)
24 
25 #if defined(__MINGW32__)
26 #define xgboost_IS_MINGW 1
27 #endif // defined(__MINGW32__)
28 
29 #endif // xgboost_IS_MINGW
30 
31 #if defined(_WIN32)
32 
33 #include <winsock2.h>
34 #include <ws2tcpip.h>
35 
36 using in_port_t = std::uint16_t;
37 
38 #ifdef _MSC_VER
39 #pragma comment(lib, "Ws2_32.lib")
40 #endif // _MSC_VER
41 
42 #if !defined(xgboost_IS_MINGW)
43 using ssize_t = int;
44 #endif // !xgboost_IS_MINGW()
45 
46 #else // UNIX
47 
48 #include <arpa/inet.h> // inet_ntop
49 #include <fcntl.h> // fcntl, F_GETFL, O_NONBLOCK
50 #include <netinet/in.h> // sockaddr_in6, sockaddr_in, in_port_t, INET6_ADDRSTRLEN, INET_ADDRSTRLEN
51 #include <netinet/in.h> // IPPROTO_TCP
52 #include <netinet/tcp.h> // TCP_NODELAY
53 #include <sys/socket.h> // socket, SOL_SOCKET, SO_ERROR, MSG_WAITALL, recv, send, AF_INET6, AF_INET
54 #include <unistd.h> // close
55 
56 #if defined(__sun) || defined(sun)
57 #include <sys/sockio.h>
58 #endif // defined(__sun) || defined(sun)
59 
60 #endif // defined(_WIN32)
61 
62 #include "xgboost/base.h" // XGBOOST_EXPECT
63 #include "xgboost/collective/result.h" // for Result
64 #include "xgboost/logging.h" // LOG
65 #include "xgboost/string_view.h" // StringView
66 
67 #if !defined(HOST_NAME_MAX)
68 #define HOST_NAME_MAX 256 // macos
69 #endif
70 
71 namespace xgboost {
72 
#if defined(xgboost_IS_MINGW)
/**
 * @brief Abort with a fatal log message; MinGW builds do not support distributed training.
 *
 * See the dummy implementation of `poll` in rabit for more info.
 */
inline void MingWError() {
  LOG(FATAL) << "Distributed training on mingw is not supported.";
}
#endif  // defined(xgboost_IS_MINGW)
77 
78 namespace system {
/**
 * @brief Return the most recent socket error code for the calling thread.
 *
 * Winsock reports errors through WSAGetLastError(); every other platform uses errno.
 */
inline std::int32_t LastError() {
#if !defined(_WIN32)
  return errno;
#else
  return WSAGetLastError();
#endif  // !defined(_WIN32)
}
87 
88 [[nodiscard]] inline collective::Result FailWithCode(std::string msg) {
89  return collective::Fail(std::move(msg), std::error_code{LastError(), std::system_category()});
90 }
91 
92 #if defined(__GLIBC__)
93 inline auto ThrowAtError(StringView fn_name, std::int32_t errsv = LastError(),
94  std::int32_t line = __builtin_LINE(),
95  char const *file = __builtin_FILE()) {
96  auto err = std::error_code{errsv, std::system_category()};
97  LOG(FATAL) << "\n"
98  << file << "(" << line << "): Failed to call `" << fn_name << "`: " << err.message()
99  << std::endl;
100 }
101 #else
102 inline auto ThrowAtError(StringView fn_name, std::int32_t errsv = LastError()) {
103  auto err = std::error_code{errsv, std::system_category()};
104  LOG(FATAL) << "Failed to call `" << fn_name << "`: " << err.message() << std::endl;
105 }
106 #endif // defined(__GLIBC__)
107 
// Native socket handle type: a file descriptor on POSIX, a SOCKET on Winsock.
#if !defined(_WIN32)
using SocketT = int;
#else
using SocketT = SOCKET;
#endif  // !defined(_WIN32)
113 
#ifndef xgboost_CHECK_SYS_CALL
// Evaluate `exp` once and log a fatal error (including the stringified expression) when
// the result differs from `expected`.
#define xgboost_CHECK_SYS_CALL(exp, expected)                          \
  do {                                                                 \
    bool const xgboost_sys_call_failed = ((exp) != (expected));        \
    if (XGBOOST_EXPECT(xgboost_sys_call_failed, false)) {              \
      ::xgboost::system::ThrowAtError(#exp);                           \
    }                                                                  \
  } while (false)
#endif  // xgboost_CHECK_SYS_CALL
122 
123 inline std::int32_t CloseSocket(SocketT fd) {
124 #if defined(_WIN32)
125  return closesocket(fd);
126 #else
127  return close(fd);
128 #endif
129 }
130 
131 inline std::int32_t ShutdownSocket(SocketT fd) {
132 #if defined(_WIN32)
133  auto rc = shutdown(fd, SD_BOTH);
134  if (rc != 0 && LastError() == WSANOTINITIALISED) {
135  return 0;
136  }
137 #else
138  auto rc = shutdown(fd, SHUT_RDWR);
139  if (rc != 0 && LastError() == ENOTCONN) {
140  return 0;
141  }
142 #endif
143  return rc;
144 }
145 
/**
 * @brief Whether `errsv` means a non-blocking operation could not complete immediately.
 *
 * On POSIX this covers EAGAIN/EWOULDBLOCK and EINPROGRESS (pending connect); on Windows
 * only WSAEWOULDBLOCK.
 */
inline bool ErrorWouldBlock(std::int32_t errsv) noexcept(true) {
#ifndef _WIN32
  bool const would_block = (errsv == EAGAIN) || (errsv == EWOULDBLOCK);
  bool const in_progress = (errsv == EINPROGRESS);
  return would_block || in_progress;
#else
  return errsv == WSAEWOULDBLOCK;
#endif  // _WIN32
}
153 
154 inline bool LastErrorWouldBlock() {
155  int errsv = LastError();
156  return ErrorWouldBlock(errsv);
157 }
158 
/**
 * @brief Initialize the platform socket library.
 *
 * On Windows this starts Winsock 2.2 and logs a fatal error if it cannot be initialized or
 * if the negotiated version is not 2.2. On other platforms it does nothing.
 */
inline void SocketStartup() {
#if defined(_WIN32)
  WSADATA wsa_data;
  // WSAStartup returns the error code directly (non-zero on failure, never -1), and
  // WSAGetLastError() must not be consulted because Winsock is not initialized when
  // WSAStartup fails. Forward the returned code explicitly.
  auto const rc = WSAStartup(MAKEWORD(2, 2), &wsa_data);
  if (rc != 0) {
    ThrowAtError("WSAStartup", rc);
  }
  if (LOBYTE(wsa_data.wVersion) != 2 || HIBYTE(wsa_data.wVersion) != 2) {
    WSACleanup();
    LOG(FATAL) << "Could not find a usable version of Winsock.dll";
  }
#endif  // defined(_WIN32)
}
171 
/**
 * @brief Release the platform socket library (Winsock on Windows; no-op elsewhere).
 */
inline void SocketFinalize() {
#if !defined(_WIN32)
  // POSIX sockets need no global teardown.
#else
  WSACleanup();
#endif  // !defined(_WIN32)
}
177 
#if !defined(_WIN32) || !defined(xgboost_IS_MINGW)
// Every platform except MinGW uses the native implementation.
using ::inet_ntop;
#else
// dummy definition for old mysys32.
inline const char *inet_ntop(int, const void *, char *, socklen_t) {  // NOLINT
  MingWError();
  return nullptr;
}
#endif  // !defined(_WIN32) || !defined(xgboost_IS_MINGW)
187 
188 } // namespace system
189 
190 namespace collective {
class SockAddress;

/**
 * @brief Address family of a socket. The enumerator values are the native AF_* constants,
 *        so they can be passed straight to socket APIs via static_cast.
 */
enum class SockDomain : std::int32_t {
  kV4 = AF_INET,
  kV6 = AF_INET6,
};
193 enum class SockDomain : std::int32_t { kV4 = AF_INET, kV6 = AF_INET6 };
194 
199 SockAddress MakeSockAddress(StringView host, in_port_t port);
200 
201 class SockAddrV6 {
202  sockaddr_in6 addr_;
203 
204  public:
205  explicit SockAddrV6(sockaddr_in6 addr) : addr_{addr} {}
206  SockAddrV6() { std::memset(&addr_, '\0', sizeof(addr_)); }
207 
210 
211  in_port_t Port() const { return ntohs(addr_.sin6_port); }
212 
213  std::string Addr() const {
214  char buf[INET6_ADDRSTRLEN];
215  auto const *s = system::inet_ntop(static_cast<std::int32_t>(SockDomain::kV6), &addr_.sin6_addr,
216  buf, INET6_ADDRSTRLEN);
217  if (s == nullptr) {
218  system::ThrowAtError("inet_ntop");
219  }
220  return {buf};
221  }
222  sockaddr_in6 const &Handle() const { return addr_; }
223 };
224 
225 class SockAddrV4 {
226  private:
227  sockaddr_in addr_;
228 
229  public:
230  explicit SockAddrV4(sockaddr_in addr) : addr_{addr} {}
231  SockAddrV4() { std::memset(&addr_, '\0', sizeof(addr_)); }
232 
235 
236  [[nodiscard]] in_port_t Port() const { return ntohs(addr_.sin_port); }
237 
238  [[nodiscard]] std::string Addr() const {
239  char buf[INET_ADDRSTRLEN];
240  auto const *s = system::inet_ntop(static_cast<std::int32_t>(SockDomain::kV4), &addr_.sin_addr,
241  buf, INET_ADDRSTRLEN);
242  if (s == nullptr) {
243  system::ThrowAtError("inet_ntop");
244  }
245  return {buf};
246  }
247  [[nodiscard]] sockaddr_in const &Handle() const { return addr_; }
248 };
249 
253 class SockAddress {
254  private:
255  SockAddrV6 v6_;
256  SockAddrV4 v4_;
257  SockDomain domain_{SockDomain::kV4};
258 
259  public:
260  SockAddress() = default;
261  explicit SockAddress(SockAddrV6 const &addr) : v6_{addr}, domain_{SockDomain::kV6} {}
262  explicit SockAddress(SockAddrV4 const &addr) : v4_{addr} {}
263 
264  [[nodiscard]] auto Domain() const { return domain_; }
265 
266  [[nodiscard]] bool IsV4() const { return Domain() == SockDomain::kV4; }
267  [[nodiscard]] bool IsV6() const { return !IsV4(); }
268 
269  [[nodiscard]] auto const &V4() const { return v4_; }
270  [[nodiscard]] auto const &V6() const { return v6_; }
271 };
272 
276 class TCPSocket {
277  public:
279 
280  private:
281  HandleT handle_{InvalidSocket()};
282  bool non_blocking_{false};
283  // There's reliable no way to extract domain from a socket without first binding that
284  // socket on macos.
285 #if defined(__APPLE__)
286  SockDomain domain_{SockDomain::kV4};
287 #endif
288 
289  constexpr static HandleT InvalidSocket() { return -1; }
290 
291  explicit TCPSocket(HandleT newfd) : handle_{newfd} {}
292 
293  public:
294  TCPSocket() = default;
298  [[nodiscard]] auto Domain() const -> SockDomain {
299  auto ret_iafamily = [](std::int32_t domain) {
300  switch (domain) {
301  case AF_INET:
302  return SockDomain::kV4;
303  case AF_INET6:
304  return SockDomain::kV6;
305  default: {
306  LOG(FATAL) << "Unknown IA family.";
307  }
308  }
309  return SockDomain::kV4;
310  };
311 
312 #if defined(_WIN32)
313  WSAPROTOCOL_INFOA info;
314  socklen_t len = sizeof(info);
316  getsockopt(handle_, SOL_SOCKET, SO_PROTOCOL_INFO, reinterpret_cast<char *>(&info), &len),
317  0);
318  return ret_iafamily(info.iAddressFamily);
319 #elif defined(__APPLE__)
320  return domain_;
321 #elif defined(__unix__)
322 #ifndef __PASE__
323  std::int32_t domain;
324  socklen_t len = sizeof(domain);
326  getsockopt(this->Handle(), SOL_SOCKET, SO_DOMAIN, reinterpret_cast<char *>(&domain), &len),
327  0);
328  return ret_iafamily(domain);
329 #else
330  struct sockaddr sa;
331  socklen_t sizeofsa = sizeof(sa);
332  xgboost_CHECK_SYS_CALL(getsockname(handle_, &sa, &sizeofsa), 0);
333  if (sizeofsa < sizeof(uchar_t) * 2) {
334  return ret_iafamily(AF_INET);
335  }
336  return ret_iafamily(sa.sa_family);
337 #endif // __PASE__
338 #else
339  LOG(FATAL) << "Unknown platform.";
340  return ret_iafamily(AF_INET);
341 #endif // platforms
342  }
343 
344  [[nodiscard]] bool IsClosed() const { return handle_ == InvalidSocket(); }
345 
347  [[nodiscard]] Result GetSockError() const {
348  std::int32_t optval = 0;
349  socklen_t len = sizeof(optval);
350  auto ret = getsockopt(handle_, SOL_SOCKET, SO_ERROR, reinterpret_cast<char *>(&optval), &len);
351  if (ret != 0) {
352  auto errc = std::error_code{system::LastError(), std::system_category()};
353  return Fail("Failed to retrieve socket error.", std::move(errc));
354  }
355  if (optval != 0) {
356  auto errc = std::error_code{optval, std::system_category()};
357  return Fail("Socket error.", std::move(errc));
358  }
359  return Success();
360  }
361 
363  [[nodiscard]] bool BadSocket() const {
364  if (IsClosed()) {
365  return true;
366  }
367  auto err = GetSockError();
368  if (err.Code() == std::error_code{EBADF, std::system_category()} || // NOLINT
369  err.Code() == std::error_code{EINTR, std::system_category()}) { // NOLINT
370  return true;
371  }
372  return false;
373  }
374 
375  [[nodiscard]] Result NonBlocking(bool non_block) {
376 #if defined(_WIN32)
377  u_long mode = non_block ? 1 : 0;
378  if (ioctlsocket(handle_, FIONBIO, &mode) != NO_ERROR) {
379  return system::FailWithCode("Failed to set socket to non-blocking.");
380  }
381 #else
382  std::int32_t flag = fcntl(handle_, F_GETFL, 0);
383  auto rc = flag;
384  if (rc == -1) {
385  return system::FailWithCode("Failed to get socket flag.");
386  }
387  if (non_block) {
388  flag |= O_NONBLOCK;
389  } else {
390  flag &= ~O_NONBLOCK;
391  }
392  rc = fcntl(handle_, F_SETFL, flag);
393  if (rc == -1) {
394  return system::FailWithCode("Failed to set socket to non-blocking.");
395  }
396 #endif // _WIN32
397  non_blocking_ = non_block;
398  return Success();
399  }
400  [[nodiscard]] bool NonBlocking() const { return non_blocking_; }
401  [[nodiscard]] Result RecvTimeout(std::chrono::seconds timeout) {
402  // https://stackoverflow.com/questions/2876024/linux-is-there-a-read-or-recv-from-socket-with-timeout
403 #if defined(_WIN32)
404  DWORD tv = timeout.count() * 1000;
405  auto rc =
406  setsockopt(Handle(), SOL_SOCKET, SO_RCVTIMEO, reinterpret_cast<char *>(&tv), sizeof(tv));
407 #else
408  struct timeval tv;
409  tv.tv_sec = timeout.count();
410  tv.tv_usec = 0;
411  auto rc = setsockopt(Handle(), SOL_SOCKET, SO_RCVTIMEO, reinterpret_cast<char const *>(&tv),
412  sizeof(tv));
413 #endif
414  if (rc != 0) {
415  return system::FailWithCode("Failed to set timeout on recv.");
416  }
417  return Success();
418  }
419 
420  [[nodiscard]] Result SetBufSize(std::int32_t n_bytes) {
421  auto rc = setsockopt(this->Handle(), SOL_SOCKET, SO_SNDBUF, reinterpret_cast<char *>(&n_bytes),
422  sizeof(n_bytes));
423  if (rc != 0) {
424  return system::FailWithCode("Failed to set send buffer size.");
425  }
426  rc = setsockopt(this->Handle(), SOL_SOCKET, SO_RCVBUF, reinterpret_cast<char *>(&n_bytes),
427  sizeof(n_bytes));
428  if (rc != 0) {
429  return system::FailWithCode("Failed to set recv buffer size.");
430  }
431  return Success();
432  }
433 
434  [[nodiscard]] Result SendBufSize(std::int32_t *n_bytes) {
435  socklen_t optlen;
436  auto rc = getsockopt(this->Handle(), SOL_SOCKET, SO_SNDBUF, reinterpret_cast<char *>(n_bytes),
437  &optlen);
438  if (rc != 0 || optlen != sizeof(std::int32_t)) {
439  return system::FailWithCode("getsockopt");
440  }
441  return Success();
442  }
443  [[nodiscard]] Result RecvBufSize(std::int32_t *n_bytes) {
444  socklen_t optlen;
445  auto rc = getsockopt(this->Handle(), SOL_SOCKET, SO_RCVBUF, reinterpret_cast<char *>(n_bytes),
446  &optlen);
447  if (rc != 0 || optlen != sizeof(std::int32_t)) {
448  return system::FailWithCode("getsockopt");
449  }
450  return Success();
451  }
452 #if defined(__linux__)
453  [[nodiscard]] Result PendingSendSize(std::int32_t *n_bytes) const {
454  return ioctl(this->Handle(), TIOCOUTQ, n_bytes) == 0 ? Success()
455  : system::FailWithCode("ioctl");
456  }
457  [[nodiscard]] Result PendingRecvSize(std::int32_t *n_bytes) const {
458  return ioctl(this->Handle(), FIONREAD, n_bytes) == 0 ? Success()
459  : system::FailWithCode("ioctl");
460  }
461 #endif // defined(__linux__)
462 
463  [[nodiscard]] Result SetKeepAlive() {
464  std::int32_t keepalive = 1;
465  auto rc = setsockopt(handle_, SOL_SOCKET, SO_KEEPALIVE, reinterpret_cast<char *>(&keepalive),
466  sizeof(keepalive));
467  if (rc != 0) {
468  return system::FailWithCode("Failed to set TCP keeaplive.");
469  }
470  return Success();
471  }
472 
473  [[nodiscard]] Result SetNoDelay(std::int32_t no_delay = 1) {
474  auto rc = setsockopt(handle_, IPPROTO_TCP, TCP_NODELAY, reinterpret_cast<char *>(&no_delay),
475  sizeof(no_delay));
476  if (rc != 0) {
477  return system::FailWithCode("Failed to set TCP no delay.");
478  }
479  return Success();
480  }
481 
486  SockAddress addr;
487  TCPSocket newsock;
488  auto rc = this->Accept(&newsock, &addr);
489  SafeColl(rc);
490  return newsock;
491  }
492 
493  [[nodiscard]] Result Accept(TCPSocket *out, SockAddress *addr) {
494 #if defined(_WIN32)
495  auto interrupt = WSAEINTR;
496 #else
497  auto interrupt = EINTR;
498 #endif
499  if (this->Domain() == SockDomain::kV4) {
500  struct sockaddr_in caddr;
501  socklen_t caddr_len = sizeof(caddr);
502  HandleT newfd = accept(Handle(), reinterpret_cast<sockaddr *>(&caddr), &caddr_len);
503  if (newfd == InvalidSocket() && system::LastError() != interrupt) {
504  return system::FailWithCode("Failed to accept.");
505  }
506  *addr = SockAddress{SockAddrV4{caddr}};
507  *out = TCPSocket{newfd};
508  } else {
509  struct sockaddr_in6 caddr;
510  socklen_t caddr_len = sizeof(caddr);
511  HandleT newfd = accept(Handle(), reinterpret_cast<sockaddr *>(&caddr), &caddr_len);
512  if (newfd == InvalidSocket() && system::LastError() != interrupt) {
513  return system::FailWithCode("Failed to accept.");
514  }
515  *addr = SockAddress{SockAddrV6{caddr}};
516  *out = TCPSocket{newfd};
517  }
518  // On MacOS, this is automatically set to async socket if the parent socket is async
519  // We make sure all socket are blocking by default.
520  //
521  // On Windows, a closed socket is returned during shutdown. We guard against it when
522  // setting non-blocking.
523  if (!out->IsClosed()) {
524  return out->NonBlocking(false);
525  }
526  return Success();
527  }
528 
530  if (!IsClosed()) {
531  auto rc = this->Close();
532  if (!rc.OK()) {
533  LOG(WARNING) << rc.Report();
534  }
535  }
536  }
537 
538  TCPSocket(TCPSocket const &that) = delete;
539  TCPSocket(TCPSocket &&that) noexcept(true) { std::swap(this->handle_, that.handle_); }
540  TCPSocket &operator=(TCPSocket const &that) = delete;
541  TCPSocket &operator=(TCPSocket &&that) noexcept(true) {
542  std::swap(this->handle_, that.handle_);
543  return *this;
544  }
548  [[nodiscard]] HandleT const &Handle() const { return handle_; }
552  [[nodiscard]] Result Listen(std::int32_t backlog = 16) {
553  if (listen(handle_, backlog) != 0) {
554  return system::FailWithCode("Failed to listen.");
555  }
556  return Success();
557  }
561  [[nodiscard]] Result BindHost(std::int32_t* p_out) {
562  // Use int32 instead of in_port_t for consistency. We take port as parameter from
563  // users using other languages, the port is usually stored and passed around as int.
564  if (Domain() == SockDomain::kV6) {
565  auto addr = SockAddrV6::InaddrAny();
566  auto handle = reinterpret_cast<sockaddr const *>(&addr.Handle());
567  if (bind(handle_, handle, sizeof(std::remove_reference_t<decltype(addr.Handle())>)) != 0) {
568  return system::FailWithCode("bind failed.");
569  }
570 
571  sockaddr_in6 res_addr;
572  socklen_t addrlen = sizeof(res_addr);
573  if (getsockname(handle_, reinterpret_cast<sockaddr *>(&res_addr), &addrlen) != 0) {
574  return system::FailWithCode("getsockname failed.");
575  }
576  *p_out = ntohs(res_addr.sin6_port);
577  } else {
578  auto addr = SockAddrV4::InaddrAny();
579  auto handle = reinterpret_cast<sockaddr const *>(&addr.Handle());
580  if (bind(handle_, handle, sizeof(std::remove_reference_t<decltype(addr.Handle())>)) != 0) {
581  return system::FailWithCode("bind failed.");
582  }
583 
584  sockaddr_in res_addr;
585  socklen_t addrlen = sizeof(res_addr);
586  if (getsockname(handle_, reinterpret_cast<sockaddr *>(&res_addr), &addrlen) != 0) {
587  return system::FailWithCode("getsockname failed.");
588  }
589  *p_out = ntohs(res_addr.sin_port);
590  }
591 
592  return Success();
593  }
594 
595  [[nodiscard]] auto Port() const {
596  if (this->Domain() == SockDomain::kV4) {
597  sockaddr_in res_addr;
598  socklen_t addrlen = sizeof(res_addr);
599  auto code = getsockname(handle_, reinterpret_cast<sockaddr *>(&res_addr), &addrlen);
600  if (code != 0) {
601  return std::make_pair(system::FailWithCode("getsockname"), std::int32_t{0});
602  }
603  return std::make_pair(Success(), std::int32_t{ntohs(res_addr.sin_port)});
604  } else {
605  sockaddr_in6 res_addr;
606  socklen_t addrlen = sizeof(res_addr);
607  auto code = getsockname(handle_, reinterpret_cast<sockaddr *>(&res_addr), &addrlen);
608  if (code != 0) {
609  return std::make_pair(system::FailWithCode("getsockname"), std::int32_t{0});
610  }
611  return std::make_pair(Success(), std::int32_t{ntohs(res_addr.sin6_port)});
612  }
613  }
614 
615  [[nodiscard]] Result Bind(StringView ip, std::int32_t *port) {
616  // bind socket handle_ to ip
617  auto addr = MakeSockAddress(ip, 0);
618  std::int32_t errc{0};
619  if (addr.IsV4()) {
620  auto handle = reinterpret_cast<sockaddr const *>(&addr.V4().Handle());
621  errc = bind(handle_, handle, sizeof(std::remove_reference_t<decltype(addr.V4().Handle())>));
622  } else {
623  auto handle = reinterpret_cast<sockaddr const *>(&addr.V6().Handle());
624  errc = bind(handle_, handle, sizeof(std::remove_reference_t<decltype(addr.V6().Handle())>));
625  }
626  if (errc != 0) {
627  return system::FailWithCode("Failed to bind socket.");
628  }
629  auto [rc, new_port] = this->Port();
630  if (!rc.OK()) {
631  return std::move(rc);
632  }
633  *port = new_port;
634  return Success();
635  }
636 
640  [[nodiscard]] Result SendAll(void const *buf, std::size_t len, std::size_t *n_sent) {
641  char const *_buf = reinterpret_cast<const char *>(buf);
642  std::size_t &ndone = *n_sent;
643  ndone = 0;
644  while (ndone < len) {
645  ssize_t ret = send(handle_, _buf, len - ndone, 0);
646  if (ret == -1) {
648  return Success();
649  }
650  return system::FailWithCode("send");
651  }
652  _buf += ret;
653  ndone += ret;
654  }
655  return Success();
656  }
660  [[nodiscard]] Result RecvAll(void *buf, std::size_t len, std::size_t *n_recv) {
661  char *_buf = reinterpret_cast<char *>(buf);
662  std::size_t &ndone = *n_recv;
663  ndone = 0;
664  while (ndone < len) {
665  ssize_t ret = recv(handle_, _buf, len - ndone, MSG_WAITALL);
666  if (ret == -1) {
668  return Success();
669  }
670  return system::FailWithCode("recv");
671  }
672  if (ret == 0) {
673  return Success();
674  }
675  _buf += ret;
676  ndone += ret;
677  }
678  return Success();
679  }
687  auto Send(const void *buf_, std::size_t len, std::int32_t flags = 0) {
688  const char *buf = reinterpret_cast<const char *>(buf_);
689  return send(handle_, buf, len, flags);
690  }
698  auto Recv(void *buf, std::size_t len, std::int32_t flags = 0) {
699  char *_buf = reinterpret_cast<char *>(buf);
700  return recv(handle_, _buf, len, flags);
701  }
705  std::size_t Send(StringView str);
709  [[nodiscard]] Result Recv(std::string *p_str);
713  [[nodiscard]] Result Close() {
714  if (InvalidSocket() != handle_) {
715  auto rc = system::CloseSocket(handle_);
716 #if defined(_WIN32)
717  // it's possible that we close TCP sockets after finalizing WSA due to detached thread.
718  if (rc != 0 && system::LastError() != WSANOTINITIALISED) {
719  return system::FailWithCode("Failed to close the socket.");
720  }
721 #else
722  if (rc != 0) {
723  return system::FailWithCode("Failed to close the socket.");
724  }
725 #endif
726  handle_ = InvalidSocket();
727  }
728  return Success();
729  }
733  [[nodiscard]] Result Shutdown() {
734  if (this->IsClosed()) {
735  return Success();
736  }
737  auto rc = system::ShutdownSocket(this->Handle());
738 #if defined(_WIN32)
739  // Windows cannot shutdown a socket if it's not connected.
740  if (rc == -1 && system::LastError() == WSAENOTCONN) {
741  return Success();
742  }
743 #endif
744  if (rc != 0) {
745  return system::FailWithCode("Failed to shutdown socket.");
746  }
747  return Success();
748  }
749 
753  static TCPSocket Create(SockDomain domain) {
754 #if defined(xgboost_IS_MINGW)
755  MingWError();
756  return {};
757 #else
758  auto fd = socket(static_cast<std::int32_t>(domain), SOCK_STREAM, 0);
759  if (fd == InvalidSocket()) {
760  system::ThrowAtError("socket");
761  }
762 
763  TCPSocket socket{fd};
764 #if defined(__APPLE__)
765  socket.domain_ = domain;
766 #endif // defined(__APPLE__)
767  return socket;
768 #endif // defined(xgboost_IS_MINGW)
769  }
770 
771  static TCPSocket *CreatePtr(SockDomain domain) {
772 #if defined(xgboost_IS_MINGW)
773  MingWError();
774  return nullptr;
775 #else
776  auto fd = socket(static_cast<std::int32_t>(domain), SOCK_STREAM, 0);
777  if (fd == InvalidSocket()) {
778  system::ThrowAtError("socket");
779  }
780  auto socket = new TCPSocket{fd};
781 
782 #if defined(__APPLE__)
783  socket->domain_ = domain;
784 #endif // defined(__APPLE__)
785  return socket;
786 #endif // defined(xgboost_IS_MINGW)
787  }
788 };
789 
802 [[nodiscard]] Result Connect(xgboost::StringView host, std::int32_t port, std::int32_t retry,
803  std::chrono::seconds timeout,
805 
809 [[nodiscard]] Result GetHostName(std::string *p_out);
810 
814 template <typename H>
815 Result INetNToP(H const &host, std::string *p_out) {
816  std::string &ip = *p_out;
817  switch (host->h_addrtype) {
818  case AF_INET: {
819  auto addr = reinterpret_cast<struct in_addr *>(host->h_addr_list[0]);
820  char str[INET_ADDRSTRLEN];
821  inet_ntop(AF_INET, addr, str, INET_ADDRSTRLEN);
822  ip = str;
823  break;
824  }
825  case AF_INET6: {
826  auto addr = reinterpret_cast<struct in6_addr *>(host->h_addr_list[0]);
827  char str[INET6_ADDRSTRLEN];
828  inet_ntop(AF_INET6, addr, str, INET6_ADDRSTRLEN);
829  ip = str;
830  break;
831  }
832  default: {
833  return Fail("Invalid address type.");
834  }
835  }
836  return Success();
837 }
838 } // namespace collective
839 } // namespace xgboost
840 
841 #undef xgboost_CHECK_SYS_CALL
842 
843 #if defined(xgboost_IS_MINGW)
844 #undef xgboost_IS_MINGW
845 #endif
Defines configuration macros and basic types for xgboost.
Definition: socket.h:225
SockAddrV4(sockaddr_in addr)
Definition: socket.h:230
static SockAddrV4 InaddrAny()
in_port_t Port() const
Definition: socket.h:236
sockaddr_in const & Handle() const
Definition: socket.h:247
std::string Addr() const
Definition: socket.h:238
static SockAddrV4 Loopback()
SockAddrV4()
Definition: socket.h:231
Definition: socket.h:201
static SockAddrV6 InaddrAny()
SockAddrV6()
Definition: socket.h:206
sockaddr_in6 const & Handle() const
Definition: socket.h:222
in_port_t Port() const
Definition: socket.h:211
SockAddrV6(sockaddr_in6 addr)
Definition: socket.h:205
std::string Addr() const
Definition: socket.h:213
static SockAddrV6 Loopback()
Address for TCP socket, can be either IPv4 or IPv6.
Definition: socket.h:253
bool IsV6() const
Definition: socket.h:267
auto const & V6() const
Definition: socket.h:270
bool IsV4() const
Definition: socket.h:266
auto Domain() const
Definition: socket.h:264
SockAddress(SockAddrV4 const &addr)
Definition: socket.h:262
auto const & V4() const
Definition: socket.h:269
SockAddress(SockAddrV6 const &addr)
Definition: socket.h:261
TCP socket for simple communication.
Definition: socket.h:276
Result GetSockError() const
Get the last error code, if any.
Definition: socket.h:347
Result Recv(std::string *p_str)
Receive a string; the format matches the Python socket wrapper in RABIT.
Result RecvTimeout(std::chrono::seconds timeout)
Definition: socket.h:401
HandleT const & Handle() const
Return the native socket file descriptor.
Definition: socket.h:548
Result SetNoDelay(std::int32_t no_delay=1)
Definition: socket.h:473
TCPSocket & operator=(TCPSocket const &that)=delete
Result BindHost(std::int32_t *p_out)
Bind socket to INADDR_ANY, return the port selected by the OS.
Definition: socket.h:561
static TCPSocket * CreatePtr(SockDomain domain)
Definition: socket.h:771
Result Shutdown()
Call shutdown on the socket.
Definition: socket.h:733
system::SocketT HandleT
Definition: socket.h:278
Result Bind(StringView ip, std::int32_t *port)
Definition: socket.h:615
TCPSocket & operator=(TCPSocket &&that) noexcept(true)
Definition: socket.h:541
auto Port() const
Definition: socket.h:595
Result SetKeepAlive()
Definition: socket.h:463
Result Accept(TCPSocket *out, SockAddress *addr)
Definition: socket.h:493
TCPSocket(TCPSocket const &that)=delete
Result RecvBufSize(std::int32_t *n_bytes)
Definition: socket.h:443
Result SendBufSize(std::int32_t *n_bytes)
Definition: socket.h:434
std::size_t Send(StringView str)
Send a string; the format matches the Python socket wrapper in RABIT.
auto Domain() const -> SockDomain
Return the socket domain.
Definition: socket.h:298
Result Listen(std::int32_t backlog=16)
Listen to incoming requests. Should be called after bind.
Definition: socket.h:552
static TCPSocket Create(SockDomain domain)
Create a TCP socket on specified domain.
Definition: socket.h:753
Result RecvAll(void *buf, std::size_t len, std::size_t *n_recv)
Receive data. If no error is returned, all data should have been received.
Definition: socket.h:660
auto Recv(void *buf, std::size_t len, std::int32_t flags=0)
receive data using the socket
Definition: socket.h:698
bool NonBlocking() const
Definition: socket.h:400
bool IsClosed() const
Definition: socket.h:344
Result NonBlocking(bool non_block)
Definition: socket.h:375
TCPSocket Accept()
Accept new connection, returns a new TCP socket for the new connection.
Definition: socket.h:485
~TCPSocket()
Definition: socket.h:529
auto Send(const void *buf_, std::size_t len, std::int32_t flags=0)
Send data using the socket.
Definition: socket.h:687
Result SendAll(void const *buf, std::size_t len, std::size_t *n_sent)
Send data. If no error is returned, all data should have been sent.
Definition: socket.h:640
bool BadSocket() const
Check whether anything bad has happened to the socket.
Definition: socket.h:363
TCPSocket(TCPSocket &&that) noexcept(true)
Definition: socket.h:539
Result SetBufSize(std::int32_t n_bytes)
Definition: socket.h:420
Result Close()
Close the socket, called automatically in destructor if the socket is not closed.
Definition: socket.h:713
void swap(xgboost::IntrusivePtr< T > &x, xgboost::IntrusivePtr< T > &y) noexcept
Definition: intrusive_ptr.h:209
Result Connect(xgboost::StringView host, std::int32_t port, std::int32_t retry, std::chrono::seconds timeout, xgboost::collective::TCPSocket *out_conn)
Connect to remote address, returns the error code if failed.
Result GetHostName(std::string *p_out)
Get the local host name.
SockAddress MakeSockAddress(StringView host, in_port_t port)
Parse host address and return a SockAddress instance. Supports IPv4 and IPv6 host.
SockDomain
Definition: socket.h:193
auto Fail(std::string msg, char const *file=__builtin_FILE(), std::int32_t line=__builtin_LINE())
Return failure.
Definition: result.h:124
void SafeColl(Result const &rc)
Result INetNToP(H const &host, std::string *p_out)
inet_ntop
Definition: socket.h:815
auto Success() noexcept(true)
Return success.
Definition: result.h:120
bool ErrorWouldBlock(std::int32_t errsv) noexcept(true)
Definition: socket.h:146
auto ThrowAtError(StringView fn_name, std::int32_t errsv=LastError())
Definition: socket.h:102
void SocketStartup()
Definition: socket.h:159
std::int32_t CloseSocket(SocketT fd)
Definition: socket.h:123
bool LastErrorWouldBlock()
Definition: socket.h:154
std::int32_t LastError()
Definition: socket.h:79
void SocketFinalize()
Definition: socket.h:172
std::int32_t ShutdownSocket(SocketT fd)
Definition: socket.h:131
collective::Result FailWithCode(std::string msg)
Definition: socket.h:88
int SocketT
Definition: socket.h:111
Core data structure for multi-target trees.
Definition: base.h:87
#define __builtin_LINE()
Definition: result.h:57
#define __builtin_FILE()
Definition: result.h:56
#define xgboost_CHECK_SYS_CALL(exp, expected)
Definition: socket.h:115
Definition: string_view.h:16
An error type that's easier to handle than throwing dmlc exception. We can record and propagate the s...
Definition: result.h:67