xgboost
socket.h
Go to the documentation of this file.
1 
4 #pragma once
5 
6 #if !defined(NOMINMAX) && defined(_WIN32)
7 #define NOMINMAX
8 #endif // !defined(NOMINMAX)
9 
10 #include <cerrno> // errno, EINTR, EBADF
11 #include <climits> // HOST_NAME_MAX
12 #include <cstddef> // std::size_t
13 #include <cstdint> // std::int32_t, std::uint16_t
14 #include <cstring> // memset
15 #include <limits> // std::numeric_limits
16 #include <string> // std::string
17 #include <system_error> // std::error_code, std::system_category
18 #include <utility> // std::swap
19 
20 #if !defined(xgboost_IS_MINGW)
21 #define xgboost_IS_MINGW() defined(__MINGW32__)
22 #endif // xgboost_IS_MINGW
23 
24 #if defined(_WIN32)
25 
26 #include <winsock2.h>
27 #include <ws2tcpip.h>
28 
29 using in_port_t = std::uint16_t;
30 
31 #ifdef _MSC_VER
32 #pragma comment(lib, "Ws2_32.lib")
33 #endif // _MSC_VER
34 
35 #if !xgboost_IS_MINGW()
36 using ssize_t = int;
37 #endif // !xgboost_IS_MINGW()
38 
39 #else // UNIX
40 
41 #include <arpa/inet.h> // inet_ntop
42 #include <fcntl.h> // fcntl, F_GETFL, O_NONBLOCK
43 #include <netinet/in.h> // sockaddr_in6, sockaddr_in, in_port_t, INET6_ADDRSTRLEN, INET_ADDRSTRLEN
44 #include <netinet/in.h> // IPPROTO_TCP
45 #include <netinet/tcp.h> // TCP_NODELAY
46 #include <sys/socket.h> // socket, SOL_SOCKET, SO_ERROR, MSG_WAITALL, recv, send, AF_INET6, AF_INET
47 #include <unistd.h> // close
48 
49 #if defined(__sun) || defined(sun)
50 #include <sys/sockio.h>
51 #endif // defined(__sun) || defined(sun)
52 
53 #endif // defined(_WIN32)
54 
55 #include "xgboost/base.h" // XGBOOST_EXPECT
56 #include "xgboost/logging.h" // LOG
57 #include "xgboost/string_view.h" // StringView
58 
59 #if !defined(HOST_NAME_MAX)
60 #define HOST_NAME_MAX 256 // macos
61 #endif
62 
63 namespace xgboost {
64 
65 #if xgboost_IS_MINGW()
66 // see the dummy implementation of `poll` in rabit for more info.
67 inline void MingWError() { LOG(FATAL) << "Distributed training on mingw is not supported."; }
68 #endif // xgboost_IS_MINGW()
69 
70 namespace system {
/**
 * @brief Error code left behind by the most recent failed system/socket call.
 *
 * @return `errno` on POSIX platforms, `WSAGetLastError()` on Windows.
 */
inline std::int32_t LastError() {
#if !defined(_WIN32)
  int const saved_errno = errno;
  return saved_errno;
#else
  return WSAGetLastError();
#endif
}
79 
80 #if defined(__GLIBC__)
81 inline auto ThrowAtError(StringView fn_name, std::int32_t errsv = LastError(),
82  std::int32_t line = __builtin_LINE(),
83  char const *file = __builtin_FILE()) {
84  auto err = std::error_code{errsv, std::system_category()};
85  LOG(FATAL) << "\n"
86  << file << "(" << line << "): Failed to call `" << fn_name << "`: " << err.message()
87  << std::endl;
88 }
89 #else
90 inline auto ThrowAtError(StringView fn_name, std::int32_t errsv = LastError()) {
91  auto err = std::error_code{errsv, std::system_category()};
92  LOG(FATAL) << "Failed to call `" << fn_name << "`: " << err.message() << std::endl;
93 }
94 #endif // defined(__GLIBC__)
95 
// Native socket handle type: `SOCKET` on Windows, a plain file descriptor elsewhere.
#if !defined(_WIN32)
using SocketT = int;
#else
using SocketT = SOCKET;
#endif  // defined(_WIN32)

/*
 * xgboost_CHECK_SYS_CALL(exp, expected): evaluate `exp` once; when the result differs from
 * `expected` (treated as the unlikely branch through XGBOOST_EXPECT), call ThrowAtError with
 * the stringified expression as the function name. Only defined if not already provided;
 * this header #undef's it at the end.
 */
#ifndef xgboost_CHECK_SYS_CALL
#define xgboost_CHECK_SYS_CALL(exp, expected)         \
  do {                                                \
    if (XGBOOST_EXPECT((exp) != (expected), false)) { \
      ::xgboost::system::ThrowAtError(#exp);          \
    }                                                 \
  } while (false)
#endif  // xgboost_CHECK_SYS_CALL
110 
111 inline std::int32_t CloseSocket(SocketT fd) {
112 #if defined(_WIN32)
113  return closesocket(fd);
114 #else
115  return close(fd);
116 #endif
117 }
118 
119 inline bool LastErrorWouldBlock() {
120  int errsv = LastError();
121 #ifdef _WIN32
122  return errsv == WSAEWOULDBLOCK;
123 #else
124  return errsv == EAGAIN || errsv == EWOULDBLOCK;
125 #endif // _WIN32
126 }
127 
/**
 * @brief Initialise the platform socket library.
 *
 * On Windows this requests Winsock 2.2 and reports a fatal error if it cannot be loaded or
 * the negotiated version is not 2.2. On other platforms it does nothing. SocketFinalize()
 * performs the matching WSACleanup().
 */
inline void SocketStartup() {
#if defined(_WIN32)
  WSADATA wsa_data;
  // WSAStartup returns 0 on success and the error code itself on failure (it never returns
  // -1, and WSAGetLastError() must not be used for it), so forward the returned code.
  auto const rc = WSAStartup(MAKEWORD(2, 2), &wsa_data);
  if (rc != 0) {
    ThrowAtError("WSAStartup", rc);
  }
  if (LOBYTE(wsa_data.wVersion) != 2 || HIBYTE(wsa_data.wVersion) != 2) {
    WSACleanup();
    LOG(FATAL) << "Could not find a usable version of Winsock.dll";
  }
#endif  // defined(_WIN32)
}
140 
/**
 * @brief Release the platform socket library: WSACleanup() on Windows, nothing elsewhere.
 */
inline void SocketFinalize() {
#if !defined(_WIN32)
  // No socket library state to tear down on non-Windows platforms.
#else
  WSACleanup();
#endif  // defined(_WIN32)
}
146 
147 #if defined(_WIN32) && xgboost_IS_MINGW()
148 // dummy definition for old mysys32.
149 inline const char *inet_ntop(int, const void *, char *, socklen_t) { // NOLINT
150  MingWError();
151  return nullptr;
152 }
153 #else
154 using ::inet_ntop;
155 #endif
156 
157 } // namespace system
158 
159 namespace collective {
160 class SockAddress;
161 
162 enum class SockDomain : std::int32_t { kV4 = AF_INET, kV6 = AF_INET6 };
163 
168 SockAddress MakeSockAddress(StringView host, in_port_t port);
169 
170 class SockAddrV6 {
171  sockaddr_in6 addr_;
172 
173  public:
174  explicit SockAddrV6(sockaddr_in6 addr) : addr_{addr} {}
175  SockAddrV6() { std::memset(&addr_, '\0', sizeof(addr_)); }
176 
179 
180  in_port_t Port() const { return ntohs(addr_.sin6_port); }
181 
182  std::string Addr() const {
183  char buf[INET6_ADDRSTRLEN];
184  auto const *s = system::inet_ntop(static_cast<std::int32_t>(SockDomain::kV6), &addr_.sin6_addr,
185  buf, INET6_ADDRSTRLEN);
186  if (s == nullptr) {
187  system::ThrowAtError("inet_ntop");
188  }
189  return {buf};
190  }
191  sockaddr_in6 const &Handle() const { return addr_; }
192 };
193 
194 class SockAddrV4 {
195  private:
196  sockaddr_in addr_;
197 
198  public:
199  explicit SockAddrV4(sockaddr_in addr) : addr_{addr} {}
200  SockAddrV4() { std::memset(&addr_, '\0', sizeof(addr_)); }
201 
204 
205  in_port_t Port() const { return ntohs(addr_.sin_port); }
206 
207  std::string Addr() const {
208  char buf[INET_ADDRSTRLEN];
209  auto const *s = system::inet_ntop(static_cast<std::int32_t>(SockDomain::kV4), &addr_.sin_addr,
210  buf, INET_ADDRSTRLEN);
211  if (s == nullptr) {
212  system::ThrowAtError("inet_ntop");
213  }
214  return {buf};
215  }
216  sockaddr_in const &Handle() const { return addr_; }
217 };
218 
222 class SockAddress {
223  private:
224  SockAddrV6 v6_;
225  SockAddrV4 v4_;
226  SockDomain domain_{SockDomain::kV4};
227 
228  public:
229  SockAddress() = default;
230  explicit SockAddress(SockAddrV6 const &addr) : v6_{addr}, domain_{SockDomain::kV6} {}
231  explicit SockAddress(SockAddrV4 const &addr) : v4_{addr} {}
232 
233  auto Domain() const { return domain_; }
234 
235  bool IsV4() const { return Domain() == SockDomain::kV4; }
236  bool IsV6() const { return !IsV4(); }
237 
238  auto const &V4() const { return v4_; }
239  auto const &V6() const { return v6_; }
240 };
241 
245 class TCPSocket {
246  public:
248 
249  private:
250  HandleT handle_{InvalidSocket()};
251  // There's reliable no way to extract domain from a socket without first binding that
252  // socket on macos.
253 #if defined(__APPLE__)
254  SockDomain domain_{SockDomain::kV4};
255 #endif
256 
257  constexpr static HandleT InvalidSocket() { return -1; }
258 
259  explicit TCPSocket(HandleT newfd) : handle_{newfd} {}
260 
261  public:
262  TCPSocket() = default;
266  auto Domain() const -> SockDomain {
267  auto ret_iafamily = [](std::int32_t domain) {
268  switch (domain) {
269  case AF_INET:
270  return SockDomain::kV4;
271  case AF_INET6:
272  return SockDomain::kV6;
273  default: {
274  LOG(FATAL) << "Unknown IA family.";
275  }
276  }
277  return SockDomain::kV4;
278  };
279 
280 #if defined(_WIN32)
281  WSAPROTOCOL_INFOA info;
282  socklen_t len = sizeof(info);
284  getsockopt(handle_, SOL_SOCKET, SO_PROTOCOL_INFO, reinterpret_cast<char *>(&info), &len),
285  0);
286  return ret_iafamily(info.iAddressFamily);
287 #elif defined(__APPLE__)
288  return domain_;
289 #elif defined(__unix__)
290 #ifndef __PASE__
291  std::int32_t domain;
292  socklen_t len = sizeof(domain);
294  getsockopt(handle_, SOL_SOCKET, SO_DOMAIN, reinterpret_cast<char *>(&domain), &len), 0);
295  return ret_iafamily(domain);
296 #else
297  struct sockaddr sa;
298  socklen_t sizeofsa = sizeof(sa);
300  getsockname(handle_, &sa, &sizeofsa), 0);
301  if (sizeofsa < sizeof(uchar_t)*2) {
302  return ret_iafamily(AF_INET);
303  }
304  return ret_iafamily(sa.sa_family);
305 #endif // __PASE__
306 #else
307  LOG(FATAL) << "Unknown platform.";
308  return ret_iafamily(AF_INET);
309 #endif // platforms
310  }
311 
312  bool IsClosed() const { return handle_ == InvalidSocket(); }
313 
315  std::int32_t GetSockError() const {
316  std::int32_t error = 0;
317  socklen_t len = sizeof(error);
319  getsockopt(handle_, SOL_SOCKET, SO_ERROR, reinterpret_cast<char *>(&error), &len), 0);
320  return error;
321  }
323  bool BadSocket() const {
324  if (IsClosed()) return true;
325  std::int32_t err = GetSockError();
326  if (err == EBADF || err == EINTR) return true;
327  return false;
328  }
329 
330  void SetNonBlock() {
331  bool non_block{true};
332 #if defined(_WIN32)
333  u_long mode = non_block ? 1 : 0;
334  xgboost_CHECK_SYS_CALL(ioctlsocket(handle_, FIONBIO, &mode), NO_ERROR);
335 #else
336  std::int32_t flag = fcntl(handle_, F_GETFL, 0);
337  if (flag == -1) {
338  system::ThrowAtError("fcntl");
339  }
340  if (non_block) {
341  flag |= O_NONBLOCK;
342  } else {
343  flag &= ~O_NONBLOCK;
344  }
345  if (fcntl(handle_, F_SETFL, flag) == -1) {
346  system::ThrowAtError("fcntl");
347  }
348 #endif // _WIN32
349  }
350 
351  void SetKeepAlive() {
352  std::int32_t keepalive = 1;
353  xgboost_CHECK_SYS_CALL(setsockopt(handle_, SOL_SOCKET, SO_KEEPALIVE,
354  reinterpret_cast<char *>(&keepalive), sizeof(keepalive)),
355  0);
356  }
357 
358  void SetNoDelay() {
359  std::int32_t tcp_no_delay = 1;
361  setsockopt(handle_, IPPROTO_TCP, TCP_NODELAY, reinterpret_cast<char *>(&tcp_no_delay),
362  sizeof(tcp_no_delay)),
363  0);
364  }
365 
370  HandleT newfd = accept(handle_, nullptr, nullptr);
371  if (newfd == InvalidSocket()) {
372  system::ThrowAtError("accept");
373  }
374  TCPSocket newsock{newfd};
375  return newsock;
376  }
377 
379  if (!IsClosed()) {
380  Close();
381  }
382  }
383 
384  TCPSocket(TCPSocket const &that) = delete;
385  TCPSocket(TCPSocket &&that) noexcept(true) { std::swap(this->handle_, that.handle_); }
386  TCPSocket &operator=(TCPSocket const &that) = delete;
388  std::swap(this->handle_, that.handle_);
389  return *this;
390  }
394  HandleT const &Handle() const { return handle_; }
398  void Listen(std::int32_t backlog = 16) { xgboost_CHECK_SYS_CALL(listen(handle_, backlog), 0); }
402  in_port_t BindHost() {
403  if (Domain() == SockDomain::kV6) {
404  auto addr = SockAddrV6::InaddrAny();
405  auto handle = reinterpret_cast<sockaddr const *>(&addr.Handle());
407  bind(handle_, handle, sizeof(std::remove_reference_t<decltype(addr.Handle())>)), 0);
408 
409  sockaddr_in6 res_addr;
410  socklen_t addrlen = sizeof(res_addr);
412  getsockname(handle_, reinterpret_cast<sockaddr *>(&res_addr), &addrlen), 0);
413  return ntohs(res_addr.sin6_port);
414  } else {
415  auto addr = SockAddrV4::InaddrAny();
416  auto handle = reinterpret_cast<sockaddr const *>(&addr.Handle());
418  bind(handle_, handle, sizeof(std::remove_reference_t<decltype(addr.Handle())>)), 0);
419 
420  sockaddr_in res_addr;
421  socklen_t addrlen = sizeof(res_addr);
423  getsockname(handle_, reinterpret_cast<sockaddr *>(&res_addr), &addrlen), 0);
424  return ntohs(res_addr.sin_port);
425  }
426  }
430  auto SendAll(void const *buf, std::size_t len) {
431  char const *_buf = reinterpret_cast<const char *>(buf);
432  std::size_t ndone = 0;
433  while (ndone < len) {
434  ssize_t ret = send(handle_, _buf, len - ndone, 0);
435  if (ret == -1) {
437  return ndone;
438  }
439  system::ThrowAtError("send");
440  }
441  _buf += ret;
442  ndone += ret;
443  }
444  return ndone;
445  }
449  auto RecvAll(void *buf, std::size_t len) {
450  char *_buf = reinterpret_cast<char *>(buf);
451  std::size_t ndone = 0;
452  while (ndone < len) {
453  ssize_t ret = recv(handle_, _buf, len - ndone, MSG_WAITALL);
454  if (ret == -1) {
456  return ndone;
457  }
458  system::ThrowAtError("recv");
459  }
460  if (ret == 0) {
461  return ndone;
462  }
463  _buf += ret;
464  ndone += ret;
465  }
466  return ndone;
467  }
475  auto Send(const void *buf_, std::size_t len, std::int32_t flags = 0) {
476  const char *buf = reinterpret_cast<const char *>(buf_);
477  return send(handle_, buf, len, flags);
478  }
486  auto Recv(void *buf, std::size_t len, std::int32_t flags = 0) {
487  char *_buf = reinterpret_cast<char *>(buf);
488  return recv(handle_, _buf, len, flags);
489  }
493  std::size_t Send(StringView str);
497  std::size_t Recv(std::string *p_str);
501  void Close() {
502  if (InvalidSocket() != handle_) {
504  handle_ = InvalidSocket();
505  }
506  }
510  static TCPSocket Create(SockDomain domain) {
511 #if xgboost_IS_MINGW()
512  MingWError();
513  return {};
514 #else
515  auto fd = socket(static_cast<std::int32_t>(domain), SOCK_STREAM, 0);
516  if (fd == InvalidSocket()) {
517  system::ThrowAtError("socket");
518  }
519 
520  TCPSocket socket{fd};
521 #if defined(__APPLE__)
522  socket.domain_ = domain;
523 #endif // defined(__APPLE__)
524  return socket;
525 #endif // xgboost_IS_MINGW()
526  }
527 };
528 
533 std::error_code Connect(SockAddress const &addr, TCPSocket *out);
534 
538 inline std::string GetHostName() {
539  char buf[HOST_NAME_MAX];
540  xgboost_CHECK_SYS_CALL(gethostname(&buf[0], HOST_NAME_MAX), 0);
541  return buf;
542 }
543 } // namespace collective
544 } // namespace xgboost
545 
546 #undef xgboost_CHECK_SYS_CALL
547 #undef xgboost_IS_MINGW
defines configuration macros of xgboost.
Definition: socket.h:194
SockAddrV4(sockaddr_in addr)
Definition: socket.h:199
static SockAddrV4 InaddrAny()
in_port_t Port() const
Definition: socket.h:205
sockaddr_in const & Handle() const
Definition: socket.h:216
std::string Addr() const
Definition: socket.h:207
static SockAddrV4 Loopback()
SockAddrV4()
Definition: socket.h:200
Definition: socket.h:170
static SockAddrV6 InaddrAny()
SockAddrV6()
Definition: socket.h:175
sockaddr_in6 const & Handle() const
Definition: socket.h:191
in_port_t Port() const
Definition: socket.h:180
SockAddrV6(sockaddr_in6 addr)
Definition: socket.h:174
std::string Addr() const
Definition: socket.h:182
static SockAddrV6 Loopback()
Address for TCP socket, can be either IPv4 or IPv6.
Definition: socket.h:222
bool IsV6() const
Definition: socket.h:236
auto const & V6() const
Definition: socket.h:239
bool IsV4() const
Definition: socket.h:235
auto Domain() const
Definition: socket.h:233
SockAddress(SockAddrV4 const &addr)
Definition: socket.h:231
auto const & V4() const
Definition: socket.h:238
SockAddress(SockAddrV6 const &addr)
Definition: socket.h:230
TCP socket for simple communication.
Definition: socket.h:245
HandleT const & Handle() const
Return the native socket file descriptor.
Definition: socket.h:394
void SetNoDelay()
Definition: socket.h:358
std::int32_t GetSockError() const
get last error code if any
Definition: socket.h:315
TCPSocket & operator=(TCPSocket const &that)=delete
auto SendAll(void const *buf, std::size_t len)
Send data, without error then all data should be sent.
Definition: socket.h:430
system::SocketT HandleT
Definition: socket.h:247
void SetNonBlock()
Definition: socket.h:330
void Close()
Close the socket, called automatically in destructor if the socket is not closed.
Definition: socket.h:501
void Listen(std::int32_t backlog=16)
Listen to incoming requests. Should be called after bind.
Definition: socket.h:398
TCPSocket(TCPSocket const &that)=delete
std::size_t Send(StringView str)
Send string, format is matched with the Python socket wrapper in RABIT.
TCPSocket & operator=(TCPSocket &&that)
Definition: socket.h:387
auto Domain() const -> SockDomain
Return the socket domain.
Definition: socket.h:266
static TCPSocket Create(SockDomain domain)
Create a TCP socket on specified domain.
Definition: socket.h:510
auto Recv(void *buf, std::size_t len, std::int32_t flags=0)
receive data using the socket
Definition: socket.h:486
auto RecvAll(void *buf, std::size_t len)
Receive data, without error then all data should be received.
Definition: socket.h:449
bool IsClosed() const
Definition: socket.h:312
TCPSocket Accept()
Accept new connection, returns a new TCP socket for the new connection.
Definition: socket.h:369
~TCPSocket()
Definition: socket.h:378
auto Send(const void *buf_, std::size_t len, std::int32_t flags=0)
Send data using the socket.
Definition: socket.h:475
in_port_t BindHost()
Bind socket to INADDR_ANY, return the port selected by the OS.
Definition: socket.h:402
void SetKeepAlive()
Definition: socket.h:351
std::size_t Recv(std::string *p_str)
Receive string, format is matched with the Python socket wrapper in RABIT.
bool BadSocket() const
check if anything bad happens
Definition: socket.h:323
TCPSocket(TCPSocket &&that) noexcept(true)
Definition: socket.h:385
void swap(xgboost::IntrusivePtr< T > &x, xgboost::IntrusivePtr< T > &y) noexcept
Definition: intrusive_ptr.h:209
std::error_code Connect(SockAddress const &addr, TCPSocket *out)
Connect to remote address, returns the error code if failed (no exception is raised so that we can re...
std::string GetHostName()
Get the local host name.
Definition: socket.h:538
SockAddress MakeSockAddress(StringView host, in_port_t port)
Parse host address and return a SockAddress instance. Supports IPv4 and IPv6 host.
SockDomain
Definition: socket.h:162
auto ThrowAtError(StringView fn_name, std::int32_t errsv=LastError())
Definition: socket.h:90
void SocketStartup()
Definition: socket.h:128
std::int32_t CloseSocket(SocketT fd)
Definition: socket.h:111
bool LastErrorWouldBlock()
Definition: socket.h:119
std::int32_t LastError()
Definition: socket.h:71
void SocketFinalize()
Definition: socket.h:141
int SocketT
Definition: socket.h:99
namespace of xgboost
Definition: base.h:110
#define xgboost_CHECK_SYS_CALL(exp, expected)
Definition: socket.h:103
#define HOST_NAME_MAX
Definition: socket.h:60
Definition: string_view.h:15