6 #if !defined(NOMINMAX) && defined(_WIN32)
16 #include <system_error>
19 #if defined(__linux__)
20 #include <sys/ioctl.h>
23 #if !defined(xgboost_IS_MINGW)
25 #if defined(__MINGW32__)
26 #define xgboost_IS_MINGW 1
36 using in_port_t = std::uint16_t;
39 #pragma comment(lib, "Ws2_32.lib")
42 #if !defined(xgboost_IS_MINGW)
48 #include <arpa/inet.h>
50 #include <netinet/in.h>
51 #include <netinet/in.h>
52 #include <netinet/tcp.h>
53 #include <sys/socket.h>
56 #if defined(__sun) || defined(sun)
57 #include <sys/sockio.h>
64 #include "xgboost/logging.h"
67 #if !defined(HOST_NAME_MAX)
68 #define HOST_NAME_MAX 256
73 #if defined(xgboost_IS_MINGW)
75 inline void MingWError() { LOG(FATAL) <<
"Distributed training on mingw is not supported."; }
81 return WSAGetLastError();
92 #if defined(__GLIBC__)
96 auto err = std::error_code{errsv, std::system_category()};
98 << file <<
"(" << line <<
"): Failed to call `" << fn_name <<
"`: " << err.message()
103 auto err = std::error_code{errsv, std::system_category()};
104 LOG(FATAL) <<
"Failed to call `" << fn_name <<
"`: " << err.message() << std::endl;
114 #if !defined(xgboost_CHECK_SYS_CALL)
115 #define xgboost_CHECK_SYS_CALL(exp, expected) \
117 if (XGBOOST_EXPECT((exp) != (expected), false)) { \
118 ::xgboost::system::ThrowAtError(#exp); \
125 return closesocket(fd);
133 auto rc = shutdown(fd, SD_BOTH);
134 if (rc != 0 &&
LastError() == WSANOTINITIALISED) {
138 auto rc = shutdown(fd, SHUT_RDWR);
139 if (rc != 0 &&
LastError() == ENOTCONN) {
148 return errsv == WSAEWOULDBLOCK;
150 return errsv == EAGAIN || errsv == EWOULDBLOCK || errsv == EINPROGRESS;
162 if (WSAStartup(MAKEWORD(2, 2), &wsa_data) == -1) {
165 if (LOBYTE(wsa_data.wVersion) != 2 || HIBYTE(wsa_data.wVersion) != 2) {
167 LOG(FATAL) <<
"Could not find a usable version of Winsock.dll";
178 #if defined(_WIN32) && defined(xgboost_IS_MINGW)
180 inline const char *inet_ntop(
int,
const void *,
char *, socklen_t) {
190 namespace collective {
211 in_port_t
Port()
const {
return ntohs(addr_.sin6_port); }
214 char buf[INET6_ADDRSTRLEN];
215 auto const *s = system::inet_ntop(
static_cast<std::int32_t
>(
SockDomain::kV6), &addr_.sin6_addr,
216 buf, INET6_ADDRSTRLEN);
222 sockaddr_in6
const &
Handle()
const {
return addr_; }
236 [[nodiscard]] in_port_t
Port()
const {
return ntohs(addr_.sin_port); }
238 [[nodiscard]] std::string
Addr()
const {
239 char buf[INET_ADDRSTRLEN];
240 auto const *s = system::inet_ntop(
static_cast<std::int32_t
>(
SockDomain::kV4), &addr_.sin_addr,
241 buf, INET_ADDRSTRLEN);
247 [[nodiscard]] sockaddr_in
const &
Handle()
const {
return addr_; }
264 [[nodiscard]]
auto Domain()
const {
return domain_; }
267 [[nodiscard]]
bool IsV6()
const {
return !
IsV4(); }
269 [[nodiscard]]
auto const &
V4()
const {
return v4_; }
270 [[nodiscard]]
auto const &
V6()
const {
return v6_; }
281 HandleT handle_{InvalidSocket()};
282 bool non_blocking_{
false};
285 #if defined(__APPLE__)
289 constexpr
static HandleT InvalidSocket() {
return -1; }
299 auto ret_iafamily = [](std::int32_t domain) {
306 LOG(FATAL) <<
"Unknown IA family.";
313 WSAPROTOCOL_INFOA info;
314 socklen_t len =
sizeof(info);
316 getsockopt(handle_, SOL_SOCKET, SO_PROTOCOL_INFO,
reinterpret_cast<char *
>(&info), &len),
318 return ret_iafamily(info.iAddressFamily);
319 #elif defined(__APPLE__)
321 #elif defined(__unix__)
324 socklen_t len =
sizeof(domain);
326 getsockopt(this->
Handle(), SOL_SOCKET, SO_DOMAIN,
reinterpret_cast<char *
>(&domain), &len),
328 return ret_iafamily(domain);
331 socklen_t sizeofsa =
sizeof(sa);
333 if (sizeofsa <
sizeof(uchar_t) * 2) {
334 return ret_iafamily(AF_INET);
336 return ret_iafamily(sa.sa_family);
339 LOG(FATAL) <<
"Unknown platform.";
340 return ret_iafamily(AF_INET);
344 [[nodiscard]]
bool IsClosed()
const {
return handle_ == InvalidSocket(); }
348 std::int32_t optval = 0;
349 socklen_t len =
sizeof(optval);
350 auto ret = getsockopt(handle_, SOL_SOCKET, SO_ERROR,
reinterpret_cast<char *
>(&optval), &len);
353 return Fail(
"Failed to retrieve socket error.", std::move(errc));
356 auto errc = std::error_code{optval, std::system_category()};
357 return Fail(
"Socket error.", std::move(errc));
368 if (err.Code() == std::error_code{EBADF, std::system_category()} ||
369 err.Code() == std::error_code{EINTR, std::system_category()}) {
377 u_long mode = non_block ? 1 : 0;
378 if (ioctlsocket(handle_, FIONBIO, &mode) != NO_ERROR) {
382 std::int32_t flag = fcntl(handle_, F_GETFL, 0);
392 rc = fcntl(handle_, F_SETFL, flag);
397 non_blocking_ = non_block;
404 DWORD tv = timeout.count() * 1000;
406 setsockopt(
Handle(), SOL_SOCKET, SO_RCVTIMEO,
reinterpret_cast<char *
>(&tv),
sizeof(tv));
409 tv.tv_sec = timeout.count();
411 auto rc = setsockopt(
Handle(), SOL_SOCKET, SO_RCVTIMEO,
reinterpret_cast<char const *
>(&tv),
421 auto rc = setsockopt(this->
Handle(), SOL_SOCKET, SO_SNDBUF,
reinterpret_cast<char *
>(&n_bytes),
426 rc = setsockopt(this->
Handle(), SOL_SOCKET, SO_RCVBUF,
reinterpret_cast<char *
>(&n_bytes),
436 auto rc = getsockopt(this->
Handle(), SOL_SOCKET, SO_SNDBUF,
reinterpret_cast<char *
>(n_bytes),
438 if (rc != 0 || optlen !=
sizeof(std::int32_t)) {
445 auto rc = getsockopt(this->
Handle(), SOL_SOCKET, SO_RCVBUF,
reinterpret_cast<char *
>(n_bytes),
447 if (rc != 0 || optlen !=
sizeof(std::int32_t)) {
452 #if defined(__linux__)
453 [[nodiscard]]
Result PendingSendSize(std::int32_t *n_bytes)
const {
454 return ioctl(this->
Handle(), TIOCOUTQ, n_bytes) == 0 ?
Success()
457 [[nodiscard]] Result PendingRecvSize(std::int32_t *n_bytes)
const {
458 return ioctl(this->
Handle(), FIONREAD, n_bytes) == 0 ?
Success()
464 std::int32_t keepalive = 1;
465 auto rc = setsockopt(handle_, SOL_SOCKET, SO_KEEPALIVE,
reinterpret_cast<char *
>(&keepalive),
474 auto rc = setsockopt(handle_, IPPROTO_TCP, TCP_NODELAY,
reinterpret_cast<char *
>(&no_delay),
488 auto rc = this->
Accept(&newsock, &addr);
495 auto interrupt = WSAEINTR;
497 auto interrupt = EINTR;
500 struct sockaddr_in caddr;
501 socklen_t caddr_len =
sizeof(caddr);
502 HandleT newfd = accept(
Handle(),
reinterpret_cast<sockaddr *
>(&caddr), &caddr_len);
509 struct sockaddr_in6 caddr;
510 socklen_t caddr_len =
sizeof(caddr);
511 HandleT newfd = accept(
Handle(),
reinterpret_cast<sockaddr *
>(&caddr), &caddr_len);
531 auto rc = this->
Close();
533 LOG(WARNING) << rc.Report();
553 if (listen(handle_, backlog) != 0) {
566 auto handle =
reinterpret_cast<sockaddr
const *
>(&addr.Handle());
567 if (bind(handle_, handle,
sizeof(std::remove_reference_t<decltype(addr.Handle())>)) != 0) {
571 sockaddr_in6 res_addr;
572 socklen_t addrlen =
sizeof(res_addr);
573 if (getsockname(handle_,
reinterpret_cast<sockaddr *
>(&res_addr), &addrlen) != 0) {
576 *p_out = ntohs(res_addr.sin6_port);
579 auto handle =
reinterpret_cast<sockaddr
const *
>(&addr.Handle());
580 if (bind(handle_, handle,
sizeof(std::remove_reference_t<decltype(addr.Handle())>)) != 0) {
584 sockaddr_in res_addr;
585 socklen_t addrlen =
sizeof(res_addr);
586 if (getsockname(handle_,
reinterpret_cast<sockaddr *
>(&res_addr), &addrlen) != 0) {
589 *p_out = ntohs(res_addr.sin_port);
595 [[nodiscard]]
auto Port()
const {
597 sockaddr_in res_addr;
598 socklen_t addrlen =
sizeof(res_addr);
599 auto code = getsockname(handle_,
reinterpret_cast<sockaddr *
>(&res_addr), &addrlen);
603 return std::make_pair(
Success(), std::int32_t{ntohs(res_addr.sin_port)});
605 sockaddr_in6 res_addr;
606 socklen_t addrlen =
sizeof(res_addr);
607 auto code = getsockname(handle_,
reinterpret_cast<sockaddr *
>(&res_addr), &addrlen);
611 return std::make_pair(
Success(), std::int32_t{ntohs(res_addr.sin6_port)});
618 std::int32_t errc{0};
620 auto handle =
reinterpret_cast<sockaddr
const *
>(&addr.V4().Handle());
621 errc = bind(handle_, handle,
sizeof(std::remove_reference_t<decltype(addr.V4().Handle())>));
623 auto handle =
reinterpret_cast<sockaddr
const *
>(&addr.V6().Handle());
624 errc = bind(handle_, handle,
sizeof(std::remove_reference_t<decltype(addr.V6().Handle())>));
629 auto [rc, new_port] = this->
Port();
631 return std::move(rc);
640 [[nodiscard]]
Result SendAll(
void const *buf, std::size_t len, std::size_t *n_sent) {
641 char const *_buf =
reinterpret_cast<const char *
>(buf);
642 std::size_t &ndone = *n_sent;
644 while (ndone < len) {
645 ssize_t ret = send(handle_, _buf, len - ndone, 0);
660 [[nodiscard]]
Result RecvAll(
void *buf, std::size_t len, std::size_t *n_recv) {
661 char *_buf =
reinterpret_cast<char *
>(buf);
662 std::size_t &ndone = *n_recv;
664 while (ndone < len) {
665 ssize_t ret = recv(handle_, _buf, len - ndone, MSG_WAITALL);
687 auto Send(
const void *buf_, std::size_t len, std::int32_t flags = 0) {
688 const char *buf =
reinterpret_cast<const char *
>(buf_);
689 return send(handle_, buf, len, flags);
698 auto Recv(
void *buf, std::size_t len, std::int32_t flags = 0) {
699 char *_buf =
reinterpret_cast<char *
>(buf);
700 return recv(handle_, _buf, len, flags);
714 if (InvalidSocket() != handle_) {
726 handle_ = InvalidSocket();
754 #if defined(xgboost_IS_MINGW)
758 auto fd = socket(
static_cast<std::int32_t
>(domain), SOCK_STREAM, 0);
759 if (fd == InvalidSocket()) {
764 #if defined(__APPLE__)
765 socket.domain_ = domain;
772 #if defined(xgboost_IS_MINGW)
776 auto fd = socket(
static_cast<std::int32_t
>(domain), SOCK_STREAM, 0);
777 if (fd == InvalidSocket()) {
782 #if defined(__APPLE__)
783 socket->domain_ = domain;
803 std::chrono::seconds timeout,
814 template <
typename H>
816 std::string &ip = *p_out;
817 switch (host->h_addrtype) {
819 auto addr =
reinterpret_cast<struct in_addr *
>(host->h_addr_list[0]);
820 char str[INET_ADDRSTRLEN];
821 inet_ntop(AF_INET, addr, str, INET_ADDRSTRLEN);
826 auto addr =
reinterpret_cast<struct in6_addr *
>(host->h_addr_list[0]);
827 char str[INET6_ADDRSTRLEN];
828 inet_ntop(AF_INET6, addr, str, INET6_ADDRSTRLEN);
833 return Fail(
"Invalid address type.");
841 #undef xgboost_CHECK_SYS_CALL
843 #if defined(xgboost_IS_MINGW)
844 #undef xgboost_IS_MINGW
Defines configuration macros and basic types for xgboost.
SockAddrV4(sockaddr_in addr)
Definition: socket.h:230
static SockAddrV4 InaddrAny()
in_port_t Port() const
Definition: socket.h:236
sockaddr_in const & Handle() const
Definition: socket.h:247
std::string Addr() const
Definition: socket.h:238
static SockAddrV4 Loopback()
SockAddrV4()
Definition: socket.h:231
static SockAddrV6 InaddrAny()
SockAddrV6()
Definition: socket.h:206
sockaddr_in6 const & Handle() const
Definition: socket.h:222
in_port_t Port() const
Definition: socket.h:211
SockAddrV6(sockaddr_in6 addr)
Definition: socket.h:205
std::string Addr() const
Definition: socket.h:213
static SockAddrV6 Loopback()
Address for TCP socket, can be either IPv4 or IPv6.
Definition: socket.h:253
bool IsV6() const
Definition: socket.h:267
auto const & V6() const
Definition: socket.h:270
bool IsV4() const
Definition: socket.h:266
auto Domain() const
Definition: socket.h:264
SockAddress(SockAddrV4 const &addr)
Definition: socket.h:262
auto const & V4() const
Definition: socket.h:269
SockAddress(SockAddrV6 const &addr)
Definition: socket.h:261
TCP socket for simple communication.
Definition: socket.h:276
Result GetSockError() const
get last error code if any
Definition: socket.h:347
Result Recv(std::string *p_str)
Receive string, format is matched with the Python socket wrapper in RABIT.
Result RecvTimeout(std::chrono::seconds timeout)
Definition: socket.h:401
HandleT const & Handle() const
Return the native socket file descriptor.
Definition: socket.h:548
Result SetNoDelay(std::int32_t no_delay=1)
Definition: socket.h:473
TCPSocket & operator=(TCPSocket const &that)=delete
Result BindHost(std::int32_t *p_out)
Bind socket to INADDR_ANY, return the port selected by the OS.
Definition: socket.h:561
static TCPSocket * CreatePtr(SockDomain domain)
Definition: socket.h:771
Result Shutdown()
Call shutdown on the socket.
Definition: socket.h:733
system::SocketT HandleT
Definition: socket.h:278
Result Bind(StringView ip, std::int32_t *port)
Definition: socket.h:615
TCPSocket & operator=(TCPSocket &&that) noexcept(true)
Definition: socket.h:541
auto Port() const
Definition: socket.h:595
Result SetKeepAlive()
Definition: socket.h:463
Result Accept(TCPSocket *out, SockAddress *addr)
Definition: socket.h:493
TCPSocket(TCPSocket const &that)=delete
Result RecvBufSize(std::int32_t *n_bytes)
Definition: socket.h:443
Result SendBufSize(std::int32_t *n_bytes)
Definition: socket.h:434
std::size_t Send(StringView str)
Send string, format is matched with the Python socket wrapper in RABIT.
auto Domain() const -> SockDomain
Return the socket domain.
Definition: socket.h:298
Result Listen(std::int32_t backlog=16)
Listen to incoming requests. Should be called after bind.
Definition: socket.h:552
static TCPSocket Create(SockDomain domain)
Create a TCP socket on specified domain.
Definition: socket.h:753
Result RecvAll(void *buf, std::size_t len, std::size_t *n_recv)
Receive data, without error then all data should be received.
Definition: socket.h:660
auto Recv(void *buf, std::size_t len, std::int32_t flags=0)
receive data using the socket
Definition: socket.h:698
bool NonBlocking() const
Definition: socket.h:400
bool IsClosed() const
Definition: socket.h:344
Result NonBlocking(bool non_block)
Definition: socket.h:375
TCPSocket Accept()
Accept new connection, returns a new TCP socket for the new connection.
Definition: socket.h:485
~TCPSocket()
Definition: socket.h:529
auto Send(const void *buf_, std::size_t len, std::int32_t flags=0)
Send data using the socket.
Definition: socket.h:687
Result SendAll(void const *buf, std::size_t len, std::size_t *n_sent)
Send data, without error then all data should be sent.
Definition: socket.h:640
bool BadSocket() const
check if anything bad happens
Definition: socket.h:363
TCPSocket(TCPSocket &&that) noexcept(true)
Definition: socket.h:539
Result SetBufSize(std::int32_t n_bytes)
Definition: socket.h:420
Result Close()
Close the socket, called automatically in destructor if the socket is not closed.
Definition: socket.h:713
void swap(xgboost::IntrusivePtr< T > &x, xgboost::IntrusivePtr< T > &y) noexcept
Definition: intrusive_ptr.h:209
Result Connect(xgboost::StringView host, std::int32_t port, std::int32_t retry, std::chrono::seconds timeout, xgboost::collective::TCPSocket *out_conn)
Connect to remote address, returns the error code if failed.
Result GetHostName(std::string *p_out)
Get the local host name.
SockAddress MakeSockAddress(StringView host, in_port_t port)
Parse host address and return a SockAddress instance. Supports IPv4 and IPv6 host.
SockDomain
Definition: socket.h:193
auto Fail(std::string msg, char const *file=__builtin_FILE(), std::int32_t line=__builtin_LINE())
Return failure.
Definition: result.h:124
void SafeColl(Result const &rc)
Result INetNToP(H const &host, std::string *p_out)
inet_ntop
Definition: socket.h:815
auto Success() noexcept(true)
Return success.
Definition: result.h:120
bool ErrorWouldBlock(std::int32_t errsv) noexcept(true)
Definition: socket.h:146
auto ThrowAtError(StringView fn_name, std::int32_t errsv=LastError())
Definition: socket.h:102
void SocketStartup()
Definition: socket.h:159
std::int32_t CloseSocket(SocketT fd)
Definition: socket.h:123
bool LastErrorWouldBlock()
Definition: socket.h:154
std::int32_t LastError()
Definition: socket.h:79
void SocketFinalize()
Definition: socket.h:172
std::int32_t ShutdownSocket(SocketT fd)
Definition: socket.h:131
collective::Result FailWithCode(std::string msg)
Definition: socket.h:88
int SocketT
Definition: socket.h:111
Core data structure for multi-target trees.
Definition: base.h:87
#define __builtin_LINE()
Definition: result.h:57
#define __builtin_FILE()
Definition: result.h:56
#define xgboost_CHECK_SYS_CALL(exp, expected)
Definition: socket.h:115
Definition: string_view.h:16
An error type that's easier to handle than throwing dmlc exception. We can record and propagate the s...
Definition: result.h:67