6 #if !defined(NOMINMAX) && defined(_WIN32)
17 #include <system_error>
20 #if !defined(xgboost_IS_MINGW)
22 #if defined(__MINGW32__)
23 #define xgboost_IS_MINGW 1
33 using in_port_t = std::uint16_t;
36 #pragma comment(lib, "Ws2_32.lib")
39 #if !defined(xgboost_IS_MINGW)
45 #include <arpa/inet.h>
47 #include <netinet/in.h>
48 #include <netinet/in.h>
49 #include <netinet/tcp.h>
50 #include <sys/socket.h>
53 #if defined(__sun) || defined(sun)
54 #include <sys/sockio.h>
60 #include "xgboost/logging.h"
63 #if !defined(HOST_NAME_MAX)
64 #define HOST_NAME_MAX 256
69 #if defined(xgboost_IS_MINGW)
71 inline void MingWError() { LOG(FATAL) <<
"Distributed training on mingw is not supported."; }
77 return WSAGetLastError();
84 #if defined(__GLIBC__)
86 std::int32_t line = __builtin_LINE(),
87 char const *file = __builtin_FILE()) {
88 auto err = std::error_code{errsv, std::system_category()};
90 << file <<
"(" << line <<
"): Failed to call `" << fn_name <<
"`: " << err.message()
95 auto err = std::error_code{errsv, std::system_category()};
96 LOG(FATAL) <<
"Failed to call `" << fn_name <<
"`: " << err.message() << std::endl;
106 #if !defined(xgboost_CHECK_SYS_CALL)
107 #define xgboost_CHECK_SYS_CALL(exp, expected) \
109 if (XGBOOST_EXPECT((exp) != (expected), false)) { \
110 ::xgboost::system::ThrowAtError(#exp); \
117 return closesocket(fd);
126 return errsv == WSAEWOULDBLOCK;
128 return errsv == EAGAIN || errsv == EWOULDBLOCK;
135 if (WSAStartup(MAKEWORD(2, 2), &wsa_data) == -1) {
138 if (LOBYTE(wsa_data.wVersion) != 2 || HIBYTE(wsa_data.wVersion) != 2) {
140 LOG(FATAL) <<
"Could not find a usable version of Winsock.dll";
151 #if defined(_WIN32) && defined(xgboost_IS_MINGW)
153 inline const char *inet_ntop(
int,
const void *,
char *, socklen_t) {
163 namespace collective {
184 in_port_t
Port()
const {
return ntohs(addr_.sin6_port); }
187 char buf[INET6_ADDRSTRLEN];
188 auto const *s = system::inet_ntop(
static_cast<std::int32_t
>(
SockDomain::kV6), &addr_.sin6_addr,
189 buf, INET6_ADDRSTRLEN);
195 sockaddr_in6
const &
Handle()
const {
return addr_; }
209 in_port_t
Port()
const {
return ntohs(addr_.sin_port); }
212 char buf[INET_ADDRSTRLEN];
213 auto const *s = system::inet_ntop(
static_cast<std::int32_t
>(
SockDomain::kV4), &addr_.sin_addr,
214 buf, INET_ADDRSTRLEN);
220 sockaddr_in
const &
Handle()
const {
return addr_; }
242 auto const &
V4()
const {
return v4_; }
243 auto const &
V6()
const {
return v6_; }
254 HandleT handle_{InvalidSocket()};
257 #if defined(__APPLE__)
261 constexpr
static HandleT InvalidSocket() {
return -1; }
271 auto ret_iafamily = [](std::int32_t domain) {
278 LOG(FATAL) <<
"Unknown IA family.";
285 WSAPROTOCOL_INFOA info;
286 socklen_t len =
sizeof(info);
288 getsockopt(handle_, SOL_SOCKET, SO_PROTOCOL_INFO,
reinterpret_cast<char *
>(&info), &len),
290 return ret_iafamily(info.iAddressFamily);
291 #elif defined(__APPLE__)
293 #elif defined(__unix__)
296 socklen_t len =
sizeof(domain);
298 getsockopt(handle_, SOL_SOCKET, SO_DOMAIN,
reinterpret_cast<char *
>(&domain), &len), 0);
299 return ret_iafamily(domain);
302 socklen_t sizeofsa =
sizeof(sa);
304 if (sizeofsa <
sizeof(uchar_t) * 2) {
305 return ret_iafamily(AF_INET);
307 return ret_iafamily(sa.sa_family);
310 LOG(FATAL) <<
"Unknown platform.";
311 return ret_iafamily(AF_INET);
315 bool IsClosed()
const {
return handle_ == InvalidSocket(); }
319 std::int32_t error = 0;
320 socklen_t len =
sizeof(error);
322 getsockopt(handle_, SOL_SOCKET, SO_ERROR,
reinterpret_cast<char *
>(&error), &len), 0);
329 if (err == EBADF || err == EINTR)
return true;
334 bool non_block{
true};
336 u_long mode = non_block ? 1 : 0;
339 std::int32_t flag = fcntl(handle_, F_GETFL, 0);
348 if (fcntl(handle_, F_SETFL, flag) == -1) {
355 std::int32_t keepalive = 1;
357 reinterpret_cast<char *
>(&keepalive),
sizeof(keepalive)),
362 std::int32_t tcp_no_delay = 1;
364 setsockopt(handle_, IPPROTO_TCP, TCP_NODELAY,
reinterpret_cast<char *
>(&tcp_no_delay),
365 sizeof(tcp_no_delay)),
373 HandleT newfd = accept(handle_,
nullptr,
nullptr);
374 if (newfd == InvalidSocket()) {
408 auto handle =
reinterpret_cast<sockaddr
const *
>(&addr.Handle());
410 bind(handle_, handle,
sizeof(std::remove_reference_t<decltype(addr.Handle())>)), 0);
412 sockaddr_in6 res_addr;
413 socklen_t addrlen =
sizeof(res_addr);
415 getsockname(handle_,
reinterpret_cast<sockaddr *
>(&res_addr), &addrlen), 0);
416 return ntohs(res_addr.sin6_port);
419 auto handle =
reinterpret_cast<sockaddr
const *
>(&addr.Handle());
421 bind(handle_, handle,
sizeof(std::remove_reference_t<decltype(addr.Handle())>)), 0);
423 sockaddr_in res_addr;
424 socklen_t addrlen =
sizeof(res_addr);
426 getsockname(handle_,
reinterpret_cast<sockaddr *
>(&res_addr), &addrlen), 0);
427 return ntohs(res_addr.sin_port);
433 auto SendAll(
void const *buf, std::size_t len) {
434 char const *_buf =
reinterpret_cast<const char *
>(buf);
435 std::size_t ndone = 0;
436 while (ndone < len) {
437 ssize_t ret = send(handle_, _buf, len - ndone, 0);
453 char *_buf =
reinterpret_cast<char *
>(buf);
454 std::size_t ndone = 0;
455 while (ndone < len) {
456 ssize_t ret = recv(handle_, _buf, len - ndone, MSG_WAITALL);
478 auto Send(
const void *buf_, std::size_t len, std::int32_t flags = 0) {
479 const char *buf =
reinterpret_cast<const char *
>(buf_);
480 return send(handle_, buf, len, flags);
489 auto Recv(
void *buf, std::size_t len, std::int32_t flags = 0) {
490 char *_buf =
reinterpret_cast<char *
>(buf);
491 return recv(handle_, _buf, len, flags);
500 std::size_t
Recv(std::string *p_str);
505 if (InvalidSocket() != handle_) {
507 handle_ = InvalidSocket();
514 #if defined(xgboost_IS_MINGW)
518 auto fd = socket(
static_cast<std::int32_t
>(domain), SOCK_STREAM, 0);
519 if (fd == InvalidSocket()) {
524 #if defined(__APPLE__)
525 socket.domain_ = domain;
549 #undef xgboost_CHECK_SYS_CALL
551 #if defined(xgboost_IS_MINGW)
552 #undef xgboost_IS_MINGW
Defines configuration macros and basic types for xgboost.
SockAddrV4(sockaddr_in addr)
Definition: socket.h:203
static SockAddrV4 InaddrAny()
in_port_t Port() const
Definition: socket.h:209
sockaddr_in const & Handle() const
Definition: socket.h:220
std::string Addr() const
Definition: socket.h:211
static SockAddrV4 Loopback()
SockAddrV4()
Definition: socket.h:204
static SockAddrV6 InaddrAny()
SockAddrV6()
Definition: socket.h:179
sockaddr_in6 const & Handle() const
Definition: socket.h:195
in_port_t Port() const
Definition: socket.h:184
SockAddrV6(sockaddr_in6 addr)
Definition: socket.h:178
std::string Addr() const
Definition: socket.h:186
static SockAddrV6 Loopback()
Address for TCP socket, can be either IPv4 or IPv6.
Definition: socket.h:226
bool IsV6() const
Definition: socket.h:240
auto const & V6() const
Definition: socket.h:243
bool IsV4() const
Definition: socket.h:239
auto Domain() const
Definition: socket.h:237
SockAddress(SockAddrV4 const &addr)
Definition: socket.h:235
auto const & V4() const
Definition: socket.h:242
SockAddress(SockAddrV6 const &addr)
Definition: socket.h:234
TCP socket for simple communication.
Definition: socket.h:249
HandleT const & Handle() const
Return the native socket file descriptor.
Definition: socket.h:397
void SetNoDelay()
Definition: socket.h:361
std::int32_t GetSockError() const
get last error code if any
Definition: socket.h:318
TCPSocket & operator=(TCPSocket const &that)=delete
auto SendAll(void const *buf, std::size_t len)
Send data, without error then all data should be sent.
Definition: socket.h:433
system::SocketT HandleT
Definition: socket.h:251
void SetNonBlock()
Definition: socket.h:333
void Close()
Close the socket, called automatically in destructor if the socket is not closed.
Definition: socket.h:504
void Listen(std::int32_t backlog=16)
Listen to incoming requests. Should be called after bind.
Definition: socket.h:401
TCPSocket(TCPSocket const &that)=delete
std::size_t Send(StringView str)
Send string, format is matched with the Python socket wrapper in RABIT.
TCPSocket & operator=(TCPSocket &&that)
Definition: socket.h:390
auto Domain() const -> SockDomain
Return the socket domain.
Definition: socket.h:270
static TCPSocket Create(SockDomain domain)
Create a TCP socket on specified domain.
Definition: socket.h:513
auto Recv(void *buf, std::size_t len, std::int32_t flags=0)
receive data using the socket
Definition: socket.h:489
auto RecvAll(void *buf, std::size_t len)
Receive data, without error then all data should be received.
Definition: socket.h:452
bool IsClosed() const
Definition: socket.h:315
TCPSocket Accept()
Accept new connection, returns a new TCP socket for the new connection.
Definition: socket.h:372
~TCPSocket()
Definition: socket.h:381
auto Send(const void *buf_, std::size_t len, std::int32_t flags=0)
Send data using the socket.
Definition: socket.h:478
in_port_t BindHost()
Bind socket to INADDR_ANY, return the port selected by the OS.
Definition: socket.h:405
void SetKeepAlive()
Definition: socket.h:354
std::size_t Recv(std::string *p_str)
Receive string, format is matched with the Python socket wrapper in RABIT.
bool BadSocket() const
check if anything bad happens
Definition: socket.h:326
TCPSocket(TCPSocket &&that) noexcept(true)
Definition: socket.h:388
void swap(xgboost::IntrusivePtr< T > &x, xgboost::IntrusivePtr< T > &y) noexcept
Definition: intrusive_ptr.h:209
std::error_code Connect(SockAddress const &addr, TCPSocket *out)
Connect to remote address, returns the error code if failed (no exception is raised so that we can re...
std::string GetHostName()
Get the local host name.
Definition: socket.h:541
SockAddress MakeSockAddress(StringView host, in_port_t port)
Parse host address and return a SockAddress instance. Supports IPv4 and IPv6 host.
SockDomain
Definition: socket.h:166
auto ThrowAtError(StringView fn_name, std::int32_t errsv=LastError())
Definition: socket.h:94
void SocketStartup()
Definition: socket.h:132
std::int32_t CloseSocket(SocketT fd)
Definition: socket.h:115
bool LastErrorWouldBlock()
Definition: socket.h:123
std::int32_t LastError()
Definition: socket.h:75
void SocketFinalize()
Definition: socket.h:145
int SocketT
Definition: socket.h:103
namespace of xgboost
Definition: base.h:90
#define xgboost_CHECK_SYS_CALL(exp, expected)
Definition: socket.h:107
#define HOST_NAME_MAX
Definition: socket.h:64
Definition: string_view.h:15