6 #if !defined(NOMINMAX) && defined(_WIN32)
17 #include <system_error>
20 #if !defined(xgboost_IS_MINGW)
21 #define xgboost_IS_MINGW() defined(__MINGW32__)
29 using in_port_t = std::uint16_t;
32 #pragma comment(lib, "Ws2_32.lib")
35 #if !xgboost_IS_MINGW()
41 #include <arpa/inet.h>
43 #include <netinet/in.h>
44 #include <netinet/in.h>
45 #include <netinet/tcp.h>
46 #include <sys/socket.h>
49 #if defined(__sun) || defined(sun)
50 #include <sys/sockio.h>
56 #include "xgboost/logging.h"
59 #if !defined(HOST_NAME_MAX)
60 #define HOST_NAME_MAX 256
65 #if xgboost_IS_MINGW()
67 inline void MingWError() { LOG(FATAL) <<
"Distributed training on mingw is not supported."; }
73 return WSAGetLastError();
80 #if defined(__GLIBC__)
82 std::int32_t line = __builtin_LINE(),
83 char const *file = __builtin_FILE()) {
84 auto err = std::error_code{errsv, std::system_category()};
86 << file <<
"(" << line <<
"): Failed to call `" << fn_name <<
"`: " << err.message()
91 auto err = std::error_code{errsv, std::system_category()};
92 LOG(FATAL) <<
"Failed to call `" << fn_name <<
"`: " << err.message() << std::endl;
102 #if !defined(xgboost_CHECK_SYS_CALL)
103 #define xgboost_CHECK_SYS_CALL(exp, expected) \
105 if (XGBOOST_EXPECT((exp) != (expected), false)) { \
106 ::xgboost::system::ThrowAtError(#exp); \
113 return closesocket(fd);
122 return errsv == WSAEWOULDBLOCK;
124 return errsv == EAGAIN || errsv == EWOULDBLOCK;
131 if (WSAStartup(MAKEWORD(2, 2), &wsa_data) == -1) {
134 if (LOBYTE(wsa_data.wVersion) != 2 || HIBYTE(wsa_data.wVersion) != 2) {
136 LOG(FATAL) <<
"Could not find a usable version of Winsock.dll";
147 #if defined(_WIN32) && xgboost_IS_MINGW()
149 inline const char *inet_ntop(
int,
const void *,
char *, socklen_t) {
159 namespace collective {
180 in_port_t
Port()
const {
return ntohs(addr_.sin6_port); }
183 char buf[INET6_ADDRSTRLEN];
184 auto const *s = system::inet_ntop(
static_cast<std::int32_t
>(
SockDomain::kV6), &addr_.sin6_addr,
185 buf, INET6_ADDRSTRLEN);
191 sockaddr_in6
const &
Handle()
const {
return addr_; }
205 in_port_t
Port()
const {
return ntohs(addr_.sin_port); }
208 char buf[INET_ADDRSTRLEN];
209 auto const *s = system::inet_ntop(
static_cast<std::int32_t
>(
SockDomain::kV4), &addr_.sin_addr,
210 buf, INET_ADDRSTRLEN);
216 sockaddr_in
const &
Handle()
const {
return addr_; }
238 auto const &
V4()
const {
return v4_; }
239 auto const &
V6()
const {
return v6_; }
250 HandleT handle_{InvalidSocket()};
253 #if defined(__APPLE__)
257 constexpr
static HandleT InvalidSocket() {
return -1; }
267 auto ret_iafamily = [](std::int32_t domain) {
274 LOG(FATAL) <<
"Unknown IA family.";
281 WSAPROTOCOL_INFOA info;
282 socklen_t len =
sizeof(info);
284 getsockopt(handle_, SOL_SOCKET, SO_PROTOCOL_INFO,
reinterpret_cast<char *
>(&info), &len),
286 return ret_iafamily(info.iAddressFamily);
287 #elif defined(__APPLE__)
289 #elif defined(__unix__)
292 socklen_t len =
sizeof(domain);
294 getsockopt(handle_, SOL_SOCKET, SO_DOMAIN,
reinterpret_cast<char *
>(&domain), &len), 0);
295 return ret_iafamily(domain);
298 socklen_t sizeofsa =
sizeof(sa);
300 getsockname(handle_, &sa, &sizeofsa), 0);
301 if (sizeofsa <
sizeof(uchar_t)*2) {
302 return ret_iafamily(AF_INET);
304 return ret_iafamily(sa.sa_family);
307 LOG(FATAL) <<
"Unknown platform.";
308 return ret_iafamily(AF_INET);
312 bool IsClosed()
const {
return handle_ == InvalidSocket(); }
316 std::int32_t error = 0;
317 socklen_t len =
sizeof(error);
319 getsockopt(handle_, SOL_SOCKET, SO_ERROR,
reinterpret_cast<char *
>(&error), &len), 0);
326 if (err == EBADF || err == EINTR)
return true;
331 bool non_block{
true};
333 u_long mode = non_block ? 1 : 0;
336 std::int32_t flag = fcntl(handle_, F_GETFL, 0);
345 if (fcntl(handle_, F_SETFL, flag) == -1) {
352 std::int32_t keepalive = 1;
354 reinterpret_cast<char *
>(&keepalive),
sizeof(keepalive)),
359 std::int32_t tcp_no_delay = 1;
361 setsockopt(handle_, IPPROTO_TCP, TCP_NODELAY,
reinterpret_cast<char *
>(&tcp_no_delay),
362 sizeof(tcp_no_delay)),
370 HandleT newfd = accept(handle_,
nullptr,
nullptr);
371 if (newfd == InvalidSocket()) {
405 auto handle =
reinterpret_cast<sockaddr
const *
>(&addr.Handle());
407 bind(handle_, handle,
sizeof(std::remove_reference_t<decltype(addr.Handle())>)), 0);
409 sockaddr_in6 res_addr;
410 socklen_t addrlen =
sizeof(res_addr);
412 getsockname(handle_,
reinterpret_cast<sockaddr *
>(&res_addr), &addrlen), 0);
413 return ntohs(res_addr.sin6_port);
416 auto handle =
reinterpret_cast<sockaddr
const *
>(&addr.Handle());
418 bind(handle_, handle,
sizeof(std::remove_reference_t<decltype(addr.Handle())>)), 0);
420 sockaddr_in res_addr;
421 socklen_t addrlen =
sizeof(res_addr);
423 getsockname(handle_,
reinterpret_cast<sockaddr *
>(&res_addr), &addrlen), 0);
424 return ntohs(res_addr.sin_port);
430 auto SendAll(
void const *buf, std::size_t len) {
431 char const *_buf =
reinterpret_cast<const char *
>(buf);
432 std::size_t ndone = 0;
433 while (ndone < len) {
434 ssize_t ret = send(handle_, _buf, len - ndone, 0);
450 char *_buf =
reinterpret_cast<char *
>(buf);
451 std::size_t ndone = 0;
452 while (ndone < len) {
453 ssize_t ret = recv(handle_, _buf, len - ndone, MSG_WAITALL);
475 auto Send(
const void *buf_, std::size_t len, std::int32_t flags = 0) {
476 const char *buf =
reinterpret_cast<const char *
>(buf_);
477 return send(handle_, buf, len, flags);
486 auto Recv(
void *buf, std::size_t len, std::int32_t flags = 0) {
487 char *_buf =
reinterpret_cast<char *
>(buf);
488 return recv(handle_, _buf, len, flags);
497 std::size_t
Recv(std::string *p_str);
502 if (InvalidSocket() != handle_) {
504 handle_ = InvalidSocket();
511 #if xgboost_IS_MINGW()
515 auto fd = socket(
static_cast<std::int32_t
>(domain), SOCK_STREAM, 0);
516 if (fd == InvalidSocket()) {
521 #if defined(__APPLE__)
522 socket.domain_ = domain;
546 #undef xgboost_CHECK_SYS_CALL
547 #undef xgboost_IS_MINGW
defines configuration macros of xgboost.
SockAddrV4(sockaddr_in addr)
Definition: socket.h:199
static SockAddrV4 InaddrAny()
in_port_t Port() const
Definition: socket.h:205
sockaddr_in const & Handle() const
Definition: socket.h:216
std::string Addr() const
Definition: socket.h:207
static SockAddrV4 Loopback()
SockAddrV4()
Definition: socket.h:200
static SockAddrV6 InaddrAny()
SockAddrV6()
Definition: socket.h:175
sockaddr_in6 const & Handle() const
Definition: socket.h:191
in_port_t Port() const
Definition: socket.h:180
SockAddrV6(sockaddr_in6 addr)
Definition: socket.h:174
std::string Addr() const
Definition: socket.h:182
static SockAddrV6 Loopback()
Address for TCP socket, can be either IPv4 or IPv6.
Definition: socket.h:222
bool IsV6() const
Definition: socket.h:236
auto const & V6() const
Definition: socket.h:239
bool IsV4() const
Definition: socket.h:235
auto Domain() const
Definition: socket.h:233
SockAddress(SockAddrV4 const &addr)
Definition: socket.h:231
auto const & V4() const
Definition: socket.h:238
SockAddress(SockAddrV6 const &addr)
Definition: socket.h:230
TCP socket for simple communication.
Definition: socket.h:245
HandleT const & Handle() const
Return the native socket file descriptor.
Definition: socket.h:394
void SetNoDelay()
Definition: socket.h:358
std::int32_t GetSockError() const
get last error code if any
Definition: socket.h:315
TCPSocket & operator=(TCPSocket const &that)=delete
auto SendAll(void const *buf, std::size_t len)
Send data, without error then all data should be sent.
Definition: socket.h:430
system::SocketT HandleT
Definition: socket.h:247
void SetNonBlock()
Definition: socket.h:330
void Close()
Close the socket, called automatically in destructor if the socket is not closed.
Definition: socket.h:501
void Listen(std::int32_t backlog=16)
Listen to incoming requests. Should be called after bind.
Definition: socket.h:398
TCPSocket(TCPSocket const &that)=delete
std::size_t Send(StringView str)
Send string, format is matched with the Python socket wrapper in RABIT.
TCPSocket & operator=(TCPSocket &&that)
Definition: socket.h:387
auto Domain() const -> SockDomain
Return the socket domain.
Definition: socket.h:266
static TCPSocket Create(SockDomain domain)
Create a TCP socket on specified domain.
Definition: socket.h:510
auto Recv(void *buf, std::size_t len, std::int32_t flags=0)
receive data using the socket
Definition: socket.h:486
auto RecvAll(void *buf, std::size_t len)
Receive data, without error then all data should be received.
Definition: socket.h:449
bool IsClosed() const
Definition: socket.h:312
TCPSocket Accept()
Accept new connection, returns a new TCP socket for the new connection.
Definition: socket.h:369
~TCPSocket()
Definition: socket.h:378
auto Send(const void *buf_, std::size_t len, std::int32_t flags=0)
Send data using the socket.
Definition: socket.h:475
in_port_t BindHost()
Bind socket to INADDR_ANY, return the port selected by the OS.
Definition: socket.h:402
void SetKeepAlive()
Definition: socket.h:351
std::size_t Recv(std::string *p_str)
Receive string, format is matched with the Python socket wrapper in RABIT.
bool BadSocket() const
check if anything bad happens
Definition: socket.h:323
TCPSocket(TCPSocket &&that) noexcept(true)
Definition: socket.h:385
void swap(xgboost::IntrusivePtr< T > &x, xgboost::IntrusivePtr< T > &y) noexcept
Definition: intrusive_ptr.h:209
std::error_code Connect(SockAddress const &addr, TCPSocket *out)
Connect to remote address, returns the error code if failed (no exception is raised so that we can re...
std::string GetHostName()
Get the local host name.
Definition: socket.h:538
SockAddress MakeSockAddress(StringView host, in_port_t port)
Parse host address and return a SockAddress instance. Supports IPv4 and IPv6 host.
SockDomain
Definition: socket.h:162
auto ThrowAtError(StringView fn_name, std::int32_t errsv=LastError())
Definition: socket.h:90
void SocketStartup()
Definition: socket.h:128
std::int32_t CloseSocket(SocketT fd)
Definition: socket.h:111
bool LastErrorWouldBlock()
Definition: socket.h:119
std::int32_t LastError()
Definition: socket.h:71
void SocketFinalize()
Definition: socket.h:141
int SocketT
Definition: socket.h:99
namespace of xgboost
Definition: base.h:110
#define xgboost_CHECK_SYS_CALL(exp, expected)
Definition: socket.h:103
#define HOST_NAME_MAX
Definition: socket.h:60
Definition: string_view.h:15