12 #include <system_error>
15 #if defined(__linux__)
16 #include <sys/ioctl.h>
26 using in_port_t = std::uint16_t;
29 #pragma comment(lib, "Ws2_32.lib")
32 #if !defined(xgboost_IS_MINGW)
38 #include <arpa/inet.h>
40 #include <netinet/in.h>
41 #include <netinet/in.h>
42 #include <netinet/tcp.h>
43 #include <sys/socket.h>
46 #if defined(__sun) || defined(sun)
47 #include <sys/sockio.h>
54 #include "xgboost/logging.h"
57 #if !defined(HOST_NAME_MAX)
58 #define HOST_NAME_MAX 256
63 #if defined(xgboost_IS_MINGW)
65 inline void MingWError() { LOG(FATAL) <<
"Distributed training on mingw is not supported."; }
71 return WSAGetLastError();
82 #if defined(__GLIBC__)
86 auto err = std::error_code{errsv, std::system_category()};
88 << file <<
"(" << line <<
"): Failed to call `" << fn_name <<
"`: " << err.message()
93 auto err = std::error_code{errsv, std::system_category()};
94 LOG(FATAL) <<
"Failed to call `" << fn_name <<
"`: " << err.message() << std::endl;
102 #define INVALID_SOCKET -1
105 #if !defined(xgboost_CHECK_SYS_CALL)
106 #define xgboost_CHECK_SYS_CALL(exp, expected) \
108 if (XGBOOST_EXPECT((exp) != (expected), false)) { \
109 ::xgboost::system::ThrowAtError(#exp); \
116 return closesocket(fd);
124 auto rc = shutdown(fd, SD_BOTH);
125 if (rc != 0 &&
LastError() == WSANOTINITIALISED) {
129 auto rc = shutdown(fd, SHUT_RDWR);
130 if (rc != 0 &&
LastError() == ENOTCONN) {
139 return errsv == WSAEWOULDBLOCK;
141 return errsv == EAGAIN || errsv == EWOULDBLOCK || errsv == EINPROGRESS;
153 if (WSAStartup(MAKEWORD(2, 2), &wsa_data) == -1) {
156 if (LOBYTE(wsa_data.wVersion) != 2 || HIBYTE(wsa_data.wVersion) != 2) {
158 LOG(FATAL) <<
"Could not find a usable version of Winsock.dll";
169 #if defined(_WIN32) && defined(xgboost_IS_MINGW)
171 inline const char *inet_ntop(
int,
const void *,
char *, socklen_t) {
181 namespace collective {
202 in_port_t
Port()
const {
return ntohs(addr_.sin6_port); }
205 char buf[INET6_ADDRSTRLEN];
206 auto const *s = system::inet_ntop(
static_cast<std::int32_t
>(
SockDomain::kV6), &addr_.sin6_addr,
207 buf, INET6_ADDRSTRLEN);
213 sockaddr_in6
const &
Handle()
const {
return addr_; }
227 [[nodiscard]] in_port_t
Port()
const {
return ntohs(addr_.sin_port); }
229 [[nodiscard]] std::string
Addr()
const {
230 char buf[INET_ADDRSTRLEN];
231 auto const *s = system::inet_ntop(
static_cast<std::int32_t
>(
SockDomain::kV4), &addr_.sin_addr,
232 buf, INET_ADDRSTRLEN);
238 [[nodiscard]] sockaddr_in
const &
Handle()
const {
return addr_; }
255 [[nodiscard]]
auto Domain()
const {
return domain_; }
258 [[nodiscard]]
bool IsV6()
const {
return !
IsV4(); }
260 [[nodiscard]]
auto const &
V4()
const {
return v4_; }
261 [[nodiscard]]
auto const &
V6()
const {
return v6_; }
272 HandleT handle_{InvalidSocket()};
273 bool non_blocking_{
false};
276 #if defined(__APPLE__)
290 auto ret_iafamily = [](std::int32_t domain) {
297 LOG(FATAL) <<
"Unknown IA family.";
304 WSAPROTOCOL_INFOA info;
305 socklen_t len =
sizeof(info);
307 getsockopt(handle_, SOL_SOCKET, SO_PROTOCOL_INFO,
reinterpret_cast<char *
>(&info), &len),
309 return ret_iafamily(info.iAddressFamily);
310 #elif defined(__APPLE__)
312 #elif defined(__unix__)
315 socklen_t len =
sizeof(domain);
317 getsockopt(this->
Handle(), SOL_SOCKET, SO_DOMAIN,
reinterpret_cast<char *
>(&domain), &len),
319 return ret_iafamily(domain);
322 socklen_t sizeofsa =
sizeof(sa);
324 if (sizeofsa <
sizeof(uchar_t) * 2) {
325 return ret_iafamily(AF_INET);
327 return ret_iafamily(sa.sa_family);
330 LOG(FATAL) <<
"Unknown platform.";
331 return ret_iafamily(AF_INET);
335 [[nodiscard]]
bool IsClosed()
const {
return handle_ == InvalidSocket(); }
339 std::int32_t optval = 0;
340 socklen_t len =
sizeof(optval);
341 auto ret = getsockopt(handle_, SOL_SOCKET, SO_ERROR,
reinterpret_cast<char *
>(&optval), &len);
344 return Fail(
"Failed to retrieve socket error.", std::move(errc));
347 auto errc = std::error_code{optval, std::system_category()};
348 return Fail(
"Socket error.", std::move(errc));
359 if (err.Code() == std::error_code{EBADF, std::system_category()} ||
360 err.Code() == std::error_code{EINTR, std::system_category()}) {
368 u_long mode = non_block ? 1 : 0;
369 if (ioctlsocket(handle_, FIONBIO, &mode) != NO_ERROR) {
373 std::int32_t flag = fcntl(handle_, F_GETFL, 0);
383 rc = fcntl(handle_, F_SETFL, flag);
388 non_blocking_ = non_block;
395 DWORD tv = timeout.count() * 1000;
397 setsockopt(
Handle(), SOL_SOCKET, SO_RCVTIMEO,
reinterpret_cast<char *
>(&tv),
sizeof(tv));
400 tv.tv_sec = timeout.count();
402 auto rc = setsockopt(
Handle(), SOL_SOCKET, SO_RCVTIMEO,
reinterpret_cast<char const *
>(&tv),
412 auto rc = setsockopt(this->
Handle(), SOL_SOCKET, SO_SNDBUF,
reinterpret_cast<char *
>(&n_bytes),
417 rc = setsockopt(this->
Handle(), SOL_SOCKET, SO_RCVBUF,
reinterpret_cast<char *
>(&n_bytes),
427 auto rc = getsockopt(this->
Handle(), SOL_SOCKET, SO_SNDBUF,
reinterpret_cast<char *
>(n_bytes),
429 if (rc != 0 || optlen !=
sizeof(std::int32_t)) {
436 auto rc = getsockopt(this->
Handle(), SOL_SOCKET, SO_RCVBUF,
reinterpret_cast<char *
>(n_bytes),
438 if (rc != 0 || optlen !=
sizeof(std::int32_t)) {
443 #if defined(__linux__)
444 [[nodiscard]]
Result PendingSendSize(std::int32_t *n_bytes)
const {
445 return ioctl(this->
Handle(), TIOCOUTQ, n_bytes) == 0 ?
Success()
448 [[nodiscard]] Result PendingRecvSize(std::int32_t *n_bytes)
const {
449 return ioctl(this->
Handle(), FIONREAD, n_bytes) == 0 ?
Success()
455 std::int32_t keepalive = 1;
456 auto rc = setsockopt(handle_, SOL_SOCKET, SO_KEEPALIVE,
reinterpret_cast<char *
>(&keepalive),
465 auto rc = setsockopt(handle_, IPPROTO_TCP, TCP_NODELAY,
reinterpret_cast<char *
>(&no_delay),
479 auto rc = this->
Accept(&newsock, &addr);
486 auto interrupt = WSAEINTR;
488 auto interrupt = EINTR;
491 struct sockaddr_in caddr;
492 socklen_t caddr_len =
sizeof(caddr);
493 HandleT newfd = accept(
Handle(),
reinterpret_cast<sockaddr *
>(&caddr), &caddr_len);
500 struct sockaddr_in6 caddr;
501 socklen_t caddr_len =
sizeof(caddr);
502 HandleT newfd = accept(
Handle(),
reinterpret_cast<sockaddr *
>(&caddr), &caddr_len);
522 auto rc = this->
Close();
524 LOG(WARNING) << rc.Report();
554 auto handle =
reinterpret_cast<sockaddr
const *
>(&addr.Handle());
555 if (bind(handle_, handle,
sizeof(std::remove_reference_t<decltype(addr.Handle())>)) != 0) {
559 sockaddr_in6 res_addr;
560 socklen_t addrlen =
sizeof(res_addr);
561 if (getsockname(handle_,
reinterpret_cast<sockaddr *
>(&res_addr), &addrlen) != 0) {
564 *p_out = ntohs(res_addr.sin6_port);
567 auto handle =
reinterpret_cast<sockaddr
const *
>(&addr.Handle());
568 if (bind(handle_, handle,
sizeof(std::remove_reference_t<decltype(addr.Handle())>)) != 0) {
572 sockaddr_in res_addr;
573 socklen_t addrlen =
sizeof(res_addr);
574 if (getsockname(handle_,
reinterpret_cast<sockaddr *
>(&res_addr), &addrlen) != 0) {
577 *p_out = ntohs(res_addr.sin_port);
583 [[nodiscard]]
auto Port()
const {
585 sockaddr_in res_addr;
586 socklen_t addrlen =
sizeof(res_addr);
587 auto code = getsockname(handle_,
reinterpret_cast<sockaddr *
>(&res_addr), &addrlen);
591 return std::make_pair(
Success(), std::int32_t{ntohs(res_addr.sin_port)});
593 sockaddr_in6 res_addr;
594 socklen_t addrlen =
sizeof(res_addr);
595 auto code = getsockname(handle_,
reinterpret_cast<sockaddr *
>(&res_addr), &addrlen);
599 return std::make_pair(
Success(), std::int32_t{ntohs(res_addr.sin6_port)});
611 std::int32_t errc{0};
613 auto handle =
reinterpret_cast<sockaddr
const *
>(&addr.V4().Handle());
614 errc = bind(handle_, handle,
sizeof(std::remove_reference_t<decltype(addr.V4().Handle())>));
616 auto handle =
reinterpret_cast<sockaddr
const *
>(&addr.V6().Handle());
617 errc = bind(handle_, handle,
sizeof(std::remove_reference_t<decltype(addr.V6().Handle())>));
622 auto [rc, new_port] = this->
Port();
624 return std::move(rc);
630 if (*port != new_port) {
631 return Fail(
"Got an invalid port from bind.");
639 [[nodiscard]]
Result SendAll(
void const *buf, std::size_t len, std::size_t *n_sent) {
640 char const *_buf =
reinterpret_cast<const char *
>(buf);
641 std::size_t &ndone = *n_sent;
643 while (ndone < len) {
644 ssize_t ret = send(handle_, _buf, len - ndone, 0);
659 [[nodiscard]]
Result RecvAll(
void *buf, std::size_t len, std::size_t *n_recv) {
660 char *_buf =
reinterpret_cast<char *
>(buf);
661 std::size_t &ndone = *n_recv;
663 while (ndone < len) {
664 ssize_t ret = recv(handle_, _buf, len - ndone, MSG_WAITALL);
686 auto Send(
const void *buf_, std::size_t len, std::int32_t flags = 0) {
687 const char *buf =
reinterpret_cast<const char *
>(buf_);
688 return send(handle_, buf, len, flags);
697 auto Recv(
void *buf, std::size_t len, std::int32_t flags = 0) {
698 char *_buf =
static_cast<char *
>(buf);
701 return recv(handle_, _buf, len, flags);
716 if (InvalidSocket() != handle_) {
728 handle_ = InvalidSocket();
756 #if defined(xgboost_IS_MINGW)
760 auto fd = socket(
static_cast<std::int32_t
>(domain), SOCK_STREAM, 0);
761 if (fd == InvalidSocket()) {
766 #if defined(__APPLE__)
767 socket.domain_ = domain;
774 #if defined(xgboost_IS_MINGW)
778 auto fd = socket(
static_cast<std::int32_t
>(domain), SOCK_STREAM, 0);
779 if (fd == InvalidSocket()) {
784 #if defined(__APPLE__)
785 socket->domain_ = domain;
805 std::chrono::seconds timeout,
816 template <
typename H>
818 std::string &ip = *p_out;
819 switch (host->h_addrtype) {
821 auto addr =
reinterpret_cast<struct in_addr *
>(host->h_addr_list[0]);
822 char str[INET_ADDRSTRLEN];
823 inet_ntop(AF_INET, addr, str, INET_ADDRSTRLEN);
828 auto addr =
reinterpret_cast<struct in6_addr *
>(host->h_addr_list[0]);
829 char str[INET6_ADDRSTRLEN];
830 inet_ntop(AF_INET6, addr, str, INET6_ADDRSTRLEN);
835 return Fail(
"Invalid address type.");
843 #undef xgboost_CHECK_SYS_CALL
Defines configuration macros and basic types for xgboost.
SockAddrV4(sockaddr_in addr)
Definition: socket.h:221
static SockAddrV4 InaddrAny()
in_port_t Port() const
Definition: socket.h:227
sockaddr_in const & Handle() const
Definition: socket.h:238
std::string Addr() const
Definition: socket.h:229
static SockAddrV4 Loopback()
SockAddrV4()
Definition: socket.h:222
static SockAddrV6 InaddrAny()
SockAddrV6()
Definition: socket.h:197
sockaddr_in6 const & Handle() const
Definition: socket.h:213
in_port_t Port() const
Definition: socket.h:202
SockAddrV6(sockaddr_in6 addr)
Definition: socket.h:196
std::string Addr() const
Definition: socket.h:204
static SockAddrV6 Loopback()
Address for TCP socket, can be either IPv4 or IPv6.
Definition: socket.h:244
bool IsV6() const
Definition: socket.h:258
auto const & V6() const
Definition: socket.h:261
bool IsV4() const
Definition: socket.h:257
auto Domain() const
Definition: socket.h:255
SockAddress(SockAddrV4 const &addr)
Definition: socket.h:253
auto const & V4() const
Definition: socket.h:260
SockAddress(SockAddrV6 const &addr)
Definition: socket.h:252
TCP socket for simple communication.
Definition: socket.h:267
Result GetSockError() const
Get the last error code, if any.
Definition: socket.h:338
Result Recv(std::string *p_str)
Receive string, format is matched with the Python socket wrapper in RABIT.
Result RecvTimeout(std::chrono::seconds timeout)
Definition: socket.h:392
HandleT const & Handle() const
Return the native socket file descriptor.
Definition: socket.h:539
Result SetNoDelay(std::int32_t no_delay=1)
Definition: socket.h:464
TCPSocket & operator=(TCPSocket const &that)=delete
Result BindHost(std::int32_t *p_out)
Bind socket to INADDR_ANY, return the port selected by the OS.
Definition: socket.h:549
Result Listen(std::int32_t backlog=256)
Listen to incoming requests. Should be called after bind.
static TCPSocket * CreatePtr(SockDomain domain)
Definition: socket.h:773
Result Shutdown()
Call shutdown on the socket.
Definition: socket.h:735
system::SocketT HandleT
Definition: socket.h:269
Result Bind(StringView ip, std::int32_t *port)
Bind the socket to the address.
Definition: socket.h:608
TCPSocket & operator=(TCPSocket &&that) noexcept(true)
Definition: socket.h:532
auto Port() const
Definition: socket.h:583
Result SetKeepAlive()
Definition: socket.h:454
Result Accept(TCPSocket *out, SockAddress *addr)
Definition: socket.h:484
TCPSocket(TCPSocket const &that)=delete
Result RecvBufSize(std::int32_t *n_bytes)
Definition: socket.h:434
Result SendBufSize(std::int32_t *n_bytes)
Definition: socket.h:425
std::size_t Send(StringView str)
Send string, format is matched with the Python socket wrapper in RABIT.
auto Domain() const -> SockDomain
Return the socket domain.
Definition: socket.h:289
static TCPSocket Create(SockDomain domain)
Create a TCP socket on specified domain.
Definition: socket.h:755
Result RecvAll(void *buf, std::size_t len, std::size_t *n_recv)
Receive data, without error then all data should be received.
Definition: socket.h:659
auto Recv(void *buf, std::size_t len, std::int32_t flags=0)
Receive data using the socket.
Definition: socket.h:697
bool NonBlocking() const
Definition: socket.h:391
bool IsClosed() const
Definition: socket.h:335
Result NonBlocking(bool non_block)
Definition: socket.h:366
TCPSocket Accept()
Accept new connection, returns a new TCP socket for the new connection.
Definition: socket.h:476
~TCPSocket()
Definition: socket.h:520
auto Send(const void *buf_, std::size_t len, std::int32_t flags=0)
Send data using the socket.
Definition: socket.h:686
Result SendAll(void const *buf, std::size_t len, std::size_t *n_sent)
Send data, without error then all data should be sent.
Definition: socket.h:639
bool BadSocket() const
Check whether anything bad has happened to the socket.
Definition: socket.h:354
TCPSocket(TCPSocket &&that) noexcept(true)
Definition: socket.h:530
Result SetBufSize(std::int32_t n_bytes)
Definition: socket.h:411
Result Close()
Close the socket. This is called automatically in the destructor if the socket is not already closed.
Definition: socket.h:715
void swap(xgboost::IntrusivePtr< T > &x, xgboost::IntrusivePtr< T > &y) noexcept
Definition: intrusive_ptr.h:209
Result Connect(xgboost::StringView host, std::int32_t port, std::int32_t retry, std::chrono::seconds timeout, xgboost::collective::TCPSocket *out_conn)
Connect to remote address, returns the error code if failed.
Result GetHostName(std::string *p_out)
Get the local host name.
SockAddress MakeSockAddress(StringView host, in_port_t port)
Parse host address and return a SockAddress instance. Supports IPv4 and IPv6 host.
SockDomain
Definition: socket.h:184
auto Fail(std::string msg, char const *file=__builtin_FILE(), std::int32_t line=__builtin_LINE())
Return failure.
Definition: result.h:124
void SafeColl(Result const &rc)
Result INetNToP(H const &host, std::string *p_out)
inet_ntop
Definition: socket.h:817
auto Success() noexcept(true)
Return success.
Definition: result.h:120
bool ErrorWouldBlock(std::int32_t errsv) noexcept(true)
Definition: socket.h:137
auto ThrowAtError(StringView fn_name, std::int32_t errsv=LastError())
Definition: socket.h:92
void SocketStartup()
Definition: socket.h:150
std::int32_t CloseSocket(SocketT fd)
Definition: socket.h:114
bool LastErrorWouldBlock()
Definition: socket.h:145
std::int32_t LastError()
Definition: socket.h:69
void SocketFinalize()
Definition: socket.h:163
std::int32_t ShutdownSocket(SocketT fd)
Definition: socket.h:122
collective::Result FailWithCode(std::string msg)
Definition: socket.h:78
int SocketT
Definition: socket.h:101
Core data structure for multi-target trees.
Definition: base.h:89
int SOCKET
Definition: poll_utils.h:40
#define __builtin_LINE()
Definition: result.h:57
#define __builtin_FILE()
Definition: result.h:56
#define INVALID_SOCKET
Definition: socket.h:102
#define xgboost_CHECK_SYS_CALL(exp, expected)
Definition: socket.h:106
Definition: string_view.h:16
An error type that's easier to handle than throwing dmlc exception. We can record and propagate the s...
Definition: result.h:67