6 #if !defined(NOMINMAX) && defined(_WIN32)
16 #include <system_error>
19 #if !defined(xgboost_IS_MINGW)
21 #if defined(__MINGW32__)
22 #define xgboost_IS_MINGW 1
32 using in_port_t = std::uint16_t;
35 #pragma comment(lib, "Ws2_32.lib")
38 #if !defined(xgboost_IS_MINGW)
44 #include <arpa/inet.h>
46 #include <netinet/in.h>
47 #include <netinet/in.h>
48 #include <netinet/tcp.h>
49 #include <sys/socket.h>
52 #if defined(__sun) || defined(sun)
53 #include <sys/sockio.h>
60 #include "xgboost/logging.h"
63 #if !defined(HOST_NAME_MAX)
64 #define HOST_NAME_MAX 256
69 #if defined(xgboost_IS_MINGW)
71 inline void MingWError() { LOG(FATAL) <<
"Distributed training on mingw is not supported."; }
77 return WSAGetLastError();
88 #if defined(__GLIBC__)
92 auto err = std::error_code{errsv, std::system_category()};
94 << file <<
"(" << line <<
"): Failed to call `" << fn_name <<
"`: " << err.message()
99 auto err = std::error_code{errsv, std::system_category()};
100 LOG(FATAL) <<
"Failed to call `" << fn_name <<
"`: " << err.message() << std::endl;
110 #if !defined(xgboost_CHECK_SYS_CALL)
111 #define xgboost_CHECK_SYS_CALL(exp, expected) \
113 if (XGBOOST_EXPECT((exp) != (expected), false)) { \
114 ::xgboost::system::ThrowAtError(#exp); \
121 return closesocket(fd);
129 auto rc = shutdown(fd, SD_BOTH);
130 if (rc != 0 &&
LastError() == WSANOTINITIALISED) {
134 auto rc = shutdown(fd, SHUT_RDWR);
135 if (rc != 0 &&
LastError() == ENOTCONN) {
144 return errsv == WSAEWOULDBLOCK;
146 return errsv == EAGAIN || errsv == EWOULDBLOCK || errsv == EINPROGRESS;
158 if (WSAStartup(MAKEWORD(2, 2), &wsa_data) == -1) {
161 if (LOBYTE(wsa_data.wVersion) != 2 || HIBYTE(wsa_data.wVersion) != 2) {
163 LOG(FATAL) <<
"Could not find a usable version of Winsock.dll";
174 #if defined(_WIN32) && defined(xgboost_IS_MINGW)
176 inline const char *inet_ntop(
int,
const void *,
char *, socklen_t) {
186 namespace collective {
207 in_port_t
Port()
const {
return ntohs(addr_.sin6_port); }
210 char buf[INET6_ADDRSTRLEN];
211 auto const *s = system::inet_ntop(
static_cast<std::int32_t
>(
SockDomain::kV6), &addr_.sin6_addr,
212 buf, INET6_ADDRSTRLEN);
218 sockaddr_in6
const &
Handle()
const {
return addr_; }
232 [[nodiscard]] in_port_t
Port()
const {
return ntohs(addr_.sin_port); }
234 [[nodiscard]] std::string
Addr()
const {
235 char buf[INET_ADDRSTRLEN];
236 auto const *s = system::inet_ntop(
static_cast<std::int32_t
>(
SockDomain::kV4), &addr_.sin_addr,
237 buf, INET_ADDRSTRLEN);
243 [[nodiscard]] sockaddr_in
const &
Handle()
const {
return addr_; }
260 [[nodiscard]]
auto Domain()
const {
return domain_; }
263 [[nodiscard]]
bool IsV6()
const {
return !
IsV4(); }
265 [[nodiscard]]
auto const &
V4()
const {
return v4_; }
266 [[nodiscard]]
auto const &
V6()
const {
return v6_; }
277 HandleT handle_{InvalidSocket()};
278 bool non_blocking_{
false};
281 #if defined(__APPLE__)
285 constexpr
static HandleT InvalidSocket() {
return -1; }
295 auto ret_iafamily = [](std::int32_t domain) {
302 LOG(FATAL) <<
"Unknown IA family.";
309 WSAPROTOCOL_INFOA info;
310 socklen_t len =
sizeof(info);
312 getsockopt(handle_, SOL_SOCKET, SO_PROTOCOL_INFO,
reinterpret_cast<char *
>(&info), &len),
314 return ret_iafamily(info.iAddressFamily);
315 #elif defined(__APPLE__)
317 #elif defined(__unix__)
320 socklen_t len =
sizeof(domain);
322 getsockopt(handle_, SOL_SOCKET, SO_DOMAIN,
reinterpret_cast<char *
>(&domain), &len), 0);
323 return ret_iafamily(domain);
326 socklen_t sizeofsa =
sizeof(sa);
328 if (sizeofsa <
sizeof(uchar_t) * 2) {
329 return ret_iafamily(AF_INET);
331 return ret_iafamily(sa.sa_family);
334 LOG(FATAL) <<
"Unknown platform.";
335 return ret_iafamily(AF_INET);
339 [[nodiscard]]
bool IsClosed()
const {
return handle_ == InvalidSocket(); }
343 std::int32_t optval = 0;
344 socklen_t len =
sizeof(optval);
345 auto ret = getsockopt(handle_, SOL_SOCKET, SO_ERROR,
reinterpret_cast<char *
>(&optval), &len);
348 return Fail(
"Failed to retrieve socket error.", std::move(errc));
351 auto errc = std::error_code{optval, std::system_category()};
352 return Fail(
"Socket error.", std::move(errc));
363 if (err.Code() == std::error_code{EBADF, std::system_category()} ||
364 err.Code() == std::error_code{EINTR, std::system_category()}) {
372 u_long mode = non_block ? 1 : 0;
373 if (ioctlsocket(handle_, FIONBIO, &mode) != NO_ERROR) {
377 std::int32_t flag = fcntl(handle_, F_GETFL, 0);
387 rc = fcntl(handle_, F_SETFL, flag);
392 non_blocking_ = non_block;
399 DWORD tv = timeout.count() * 1000;
401 setsockopt(
Handle(), SOL_SOCKET, SO_RCVTIMEO,
reinterpret_cast<char *
>(&tv),
sizeof(tv));
404 tv.tv_sec = timeout.count();
406 auto rc = setsockopt(
Handle(), SOL_SOCKET, SO_RCVTIMEO,
reinterpret_cast<char const *
>(&tv),
416 auto rc = setsockopt(this->
Handle(), SOL_SOCKET, SO_SNDBUF,
reinterpret_cast<char *
>(&n_bytes),
421 rc = setsockopt(this->
Handle(), SOL_SOCKET, SO_RCVBUF,
reinterpret_cast<char *
>(&n_bytes),
430 std::int32_t keepalive = 1;
431 auto rc = setsockopt(handle_, SOL_SOCKET, SO_KEEPALIVE,
reinterpret_cast<char *
>(&keepalive),
440 std::int32_t tcp_no_delay = 1;
441 auto rc = setsockopt(handle_, IPPROTO_TCP, TCP_NODELAY,
reinterpret_cast<char *
>(&tcp_no_delay),
442 sizeof(tcp_no_delay));
455 auto rc = this->
Accept(&newsock, &addr);
462 auto interrupt = WSAEINTR;
464 auto interrupt = EINTR;
467 struct sockaddr_in caddr;
468 socklen_t caddr_len =
sizeof(caddr);
469 HandleT newfd = accept(
Handle(),
reinterpret_cast<sockaddr *
>(&caddr), &caddr_len);
476 struct sockaddr_in6 caddr;
477 socklen_t caddr_len =
sizeof(caddr);
478 HandleT newfd = accept(
Handle(),
reinterpret_cast<sockaddr *
>(&caddr), &caddr_len);
498 auto rc = this->
Close();
500 LOG(WARNING) << rc.Report();
520 if (listen(handle_, backlog) != 0) {
533 auto handle =
reinterpret_cast<sockaddr
const *
>(&addr.Handle());
534 if (bind(handle_, handle,
sizeof(std::remove_reference_t<decltype(addr.Handle())>)) != 0) {
538 sockaddr_in6 res_addr;
539 socklen_t addrlen =
sizeof(res_addr);
540 if (getsockname(handle_,
reinterpret_cast<sockaddr *
>(&res_addr), &addrlen) != 0) {
543 *p_out = ntohs(res_addr.sin6_port);
546 auto handle =
reinterpret_cast<sockaddr
const *
>(&addr.Handle());
547 if (bind(handle_, handle,
sizeof(std::remove_reference_t<decltype(addr.Handle())>)) != 0) {
551 sockaddr_in res_addr;
552 socklen_t addrlen =
sizeof(res_addr);
553 if (getsockname(handle_,
reinterpret_cast<sockaddr *
>(&res_addr), &addrlen) != 0) {
556 *p_out = ntohs(res_addr.sin_port);
562 [[nodiscard]]
auto Port()
const {
564 sockaddr_in res_addr;
565 socklen_t addrlen =
sizeof(res_addr);
566 auto code = getsockname(handle_,
reinterpret_cast<sockaddr *
>(&res_addr), &addrlen);
570 return std::make_pair(
Success(), std::int32_t{ntohs(res_addr.sin_port)});
572 sockaddr_in6 res_addr;
573 socklen_t addrlen =
sizeof(res_addr);
574 auto code = getsockname(handle_,
reinterpret_cast<sockaddr *
>(&res_addr), &addrlen);
578 return std::make_pair(
Success(), std::int32_t{ntohs(res_addr.sin6_port)});
585 std::int32_t errc{0};
587 auto handle =
reinterpret_cast<sockaddr
const *
>(&addr.V4().Handle());
588 errc = bind(handle_, handle,
sizeof(std::remove_reference_t<decltype(addr.V4().Handle())>));
590 auto handle =
reinterpret_cast<sockaddr
const *
>(&addr.V6().Handle());
591 errc = bind(handle_, handle,
sizeof(std::remove_reference_t<decltype(addr.V6().Handle())>));
596 auto [rc, new_port] = this->
Port();
598 return std::move(rc);
607 [[nodiscard]]
auto SendAll(
void const *buf, std::size_t len) {
608 char const *_buf =
reinterpret_cast<const char *
>(buf);
609 std::size_t ndone = 0;
610 while (ndone < len) {
611 ssize_t ret = send(handle_, _buf, len - ndone, 0);
626 [[nodiscard]]
auto RecvAll(
void *buf, std::size_t len) {
627 char *_buf =
reinterpret_cast<char *
>(buf);
628 std::size_t ndone = 0;
629 while (ndone < len) {
630 ssize_t ret = recv(handle_, _buf, len - ndone, MSG_WAITALL);
652 auto Send(
const void *buf_, std::size_t len, std::int32_t flags = 0) {
653 const char *buf =
reinterpret_cast<const char *
>(buf_);
654 return send(handle_, buf, len, flags);
663 auto Recv(
void *buf, std::size_t len, std::int32_t flags = 0) {
664 char *_buf =
reinterpret_cast<char *
>(buf);
665 return recv(handle_, _buf, len, flags);
679 if (InvalidSocket() != handle_) {
691 handle_ = InvalidSocket();
719 #if defined(xgboost_IS_MINGW)
723 auto fd = socket(
static_cast<std::int32_t
>(domain), SOCK_STREAM, 0);
724 if (fd == InvalidSocket()) {
729 #if defined(__APPLE__)
730 socket.domain_ = domain;
737 #if defined(xgboost_IS_MINGW)
741 auto fd = socket(
static_cast<std::int32_t
>(domain), SOCK_STREAM, 0);
742 if (fd == InvalidSocket()) {
747 #if defined(__APPLE__)
748 socket->domain_ = domain;
768 std::chrono::seconds timeout,
779 template <
typename H>
781 std::string &ip = *p_out;
782 switch (host->h_addrtype) {
784 auto addr =
reinterpret_cast<struct in_addr *
>(host->h_addr_list[0]);
785 char str[INET_ADDRSTRLEN];
786 inet_ntop(AF_INET, addr, str, INET_ADDRSTRLEN);
791 auto addr =
reinterpret_cast<struct in6_addr *
>(host->h_addr_list[0]);
792 char str[INET6_ADDRSTRLEN];
793 inet_ntop(AF_INET6, addr, str, INET6_ADDRSTRLEN);
798 return Fail(
"Invalid address type.");
806 #undef xgboost_CHECK_SYS_CALL
808 #if defined(xgboost_IS_MINGW)
809 #undef xgboost_IS_MINGW
Defines configuration macros and basic types for xgboost.
SockAddrV4(sockaddr_in addr)
Definition: socket.h:226
static SockAddrV4 InaddrAny()
in_port_t Port() const
Definition: socket.h:232
sockaddr_in const & Handle() const
Definition: socket.h:243
std::string Addr() const
Definition: socket.h:234
static SockAddrV4 Loopback()
SockAddrV4()
Definition: socket.h:227
static SockAddrV6 InaddrAny()
SockAddrV6()
Definition: socket.h:202
sockaddr_in6 const & Handle() const
Definition: socket.h:218
in_port_t Port() const
Definition: socket.h:207
SockAddrV6(sockaddr_in6 addr)
Definition: socket.h:201
std::string Addr() const
Definition: socket.h:209
static SockAddrV6 Loopback()
Address for TCP socket, can be either IPv4 or IPv6.
Definition: socket.h:249
bool IsV6() const
Definition: socket.h:263
auto const & V6() const
Definition: socket.h:266
bool IsV4() const
Definition: socket.h:262
auto Domain() const
Definition: socket.h:260
SockAddress(SockAddrV4 const &addr)
Definition: socket.h:258
auto const & V4() const
Definition: socket.h:265
SockAddress(SockAddrV6 const &addr)
Definition: socket.h:257
TCP socket for simple communication.
Definition: socket.h:272
Result GetSockError() const
Get the last error code, if any.
Definition: socket.h:342
Result Recv(std::string *p_str)
Receive a string; the format matches the Python socket wrapper in RABIT.
Result RecvTimeout(std::chrono::seconds timeout)
Definition: socket.h:396
HandleT const & Handle() const
Return the native socket file descriptor.
Definition: socket.h:515
TCPSocket & operator=(TCPSocket const &that)=delete
auto SendAll(void const *buf, std::size_t len)
Send data; if no error occurs, all of the data is sent.
Definition: socket.h:607
Result BindHost(std::int32_t *p_out)
Bind socket to INADDR_ANY, return the port selected by the OS.
Definition: socket.h:528
static TCPSocket * CreatePtr(SockDomain domain)
Definition: socket.h:736
Result SetNoDelay()
Definition: socket.h:439
Result Shutdown()
Call shutdown on the socket.
Definition: socket.h:698
system::SocketT HandleT
Definition: socket.h:274
Result Bind(StringView ip, std::int32_t *port)
Definition: socket.h:582
TCPSocket & operator=(TCPSocket &&that) noexcept(true)
Definition: socket.h:508
auto Port() const
Definition: socket.h:562
Result SetKeepAlive()
Definition: socket.h:429
Result Accept(TCPSocket *out, SockAddress *addr)
Definition: socket.h:460
TCPSocket(TCPSocket const &that)=delete
std::size_t Send(StringView str)
Send a string; the format matches the Python socket wrapper in RABIT.
auto Domain() const -> SockDomain
Return the socket domain.
Definition: socket.h:294
Result Listen(std::int32_t backlog=16)
Listen to incoming requests. Should be called after bind.
Definition: socket.h:519
static TCPSocket Create(SockDomain domain)
Create a TCP socket on specified domain.
Definition: socket.h:718
auto Recv(void *buf, std::size_t len, std::int32_t flags=0)
Receive data using the socket.
Definition: socket.h:663
auto RecvAll(void *buf, std::size_t len)
Receive data; if no error occurs, all of the data is received.
Definition: socket.h:626
bool NonBlocking() const
Definition: socket.h:395
bool IsClosed() const
Definition: socket.h:339
Result NonBlocking(bool non_block)
Definition: socket.h:370
TCPSocket Accept()
Accept a new connection and return a new TCP socket for it.
Definition: socket.h:452
~TCPSocket()
Definition: socket.h:496
auto Send(const void *buf_, std::size_t len, std::int32_t flags=0)
Send data using the socket.
Definition: socket.h:652
bool BadSocket() const
Check whether anything bad has happened to the socket.
Definition: socket.h:358
TCPSocket(TCPSocket &&that) noexcept(true)
Definition: socket.h:506
Result SetBufSize(std::int32_t n_bytes)
Definition: socket.h:415
Result Close()
Close the socket, called automatically in destructor if the socket is not closed.
Definition: socket.h:678
void swap(xgboost::IntrusivePtr< T > &x, xgboost::IntrusivePtr< T > &y) noexcept
Definition: intrusive_ptr.h:209
Result Connect(xgboost::StringView host, std::int32_t port, std::int32_t retry, std::chrono::seconds timeout, xgboost::collective::TCPSocket *out_conn)
Connect to a remote address; return the error code on failure.
Result GetHostName(std::string *p_out)
Get the local host name.
SockAddress MakeSockAddress(StringView host, in_port_t port)
Parse a host address and return a SockAddress instance. Supports IPv4 and IPv6 hosts.
SockDomain
Definition: socket.h:189
auto Fail(std::string msg, char const *file=__builtin_FILE(), std::int32_t line=__builtin_LINE())
Return failure.
Definition: result.h:125
void SafeColl(Result const &rc)
Result INetNToP(H const &host, std::string *p_out)
inet_ntop
Definition: socket.h:780
auto Success() noexcept(true)
Return success.
Definition: result.h:121
bool ErrorWouldBlock(std::int32_t errsv) noexcept(true)
Definition: socket.h:142
auto ThrowAtError(StringView fn_name, std::int32_t errsv=LastError())
Definition: socket.h:98
void SocketStartup()
Definition: socket.h:155
std::int32_t CloseSocket(SocketT fd)
Definition: socket.h:119
bool LastErrorWouldBlock()
Definition: socket.h:150
std::int32_t LastError()
Definition: socket.h:75
void SocketFinalize()
Definition: socket.h:168
std::int32_t ShutdownSocket(SocketT fd)
Definition: socket.h:127
collective::Result FailWithCode(std::string msg)
Definition: socket.h:84
int SocketT
Definition: socket.h:107
Core data structure for multi-target trees.
Definition: base.h:87
#define __builtin_LINE()
Definition: result.h:57
#define __builtin_FILE()
Definition: result.h:56
#define xgboost_CHECK_SYS_CALL(exp, expected)
Definition: socket.h:111
Definition: string_view.h:16
An error type that's easier to handle than throwing dmlc exception. We can record and propagate the s...
Definition: result.h:68