#include "Socket.hh"
#include "../Engine.hh"
#include <cerrno>
#include <cstring>
#include <sstream>
/**
 * Format a short description of an addrinfo entry for log messages.
 *
 * The result has the form "ai_family=<n>, ai_canonname=<name>: <strerror(errno)>".
 * ai_canonname is commonly NULL (it is only filled in when AI_CANONNAME was
 * requested), so a "(null)" placeholder is printed in that case instead of
 * streaming a null pointer, which would be undefined behaviour and would
 * truncate the message.
 *
 * @param r addrinfo entry to describe; must not be NULL
 * @return the formatted description
 */
static std::string dump_addrinfo (addrinfo *r) {
    // snapshot errno first so that nothing below can disturb it
    const int saved_errno = errno;
    const char *canonname = r->ai_canonname ? r->ai_canonname : "(null)";

    std::stringstream ss;
    ss << "ai_family=" << r->ai_family << ", ai_canonname=" << canonname << ": " << strerror(saved_errno);

    return ss.str();
}
/**
 * Return the strerror() text for the current value of errno.
 */
static std::string dump_errno (void) {
    const char *message = strerror(errno);

    return std::string(message);
}
/**
 * Construct a socket object without an underlying descriptor (fd is -1).
 *
 * The given family/socktype/protocol are stored both as the requested
 * sock_type and as the current type. If no reactor is given,
 * NetworkReactor::current is used.
 */
NetworkSocket::NetworkSocket (int family, int socktype, int protocol, NetworkReactor *reactor) :
    sock_type(family, socktype, protocol),
    fd(-1),
    type(family, socktype, protocol),
    registered(0),
    reactor(!reactor ? NetworkReactor::current : reactor)
{
    // start out unbound, with no read/write interest
    reset();
}
/**
 * Construct a socket object around an existing descriptor.
 *
 * The given socket_type is stored both as the requested sock_type and as
 * the current type. If no reactor is given, NetworkReactor::current is used.
 */
NetworkSocket::NetworkSocket (int fd, socket_type type, NetworkReactor *reactor) :
    sock_type(type),
    fd(fd),
    type(type),
    registered(0),
    reactor(!reactor ? NetworkReactor::current : reactor)
{
    // start out unbound, with no read/write interest
    reset();
}
/**
 * Close the descriptor if one is still open, then remove this socket from
 * its reactor if it was registered.
 */
NetworkSocket::~NetworkSocket (void) {
    const bool still_open = (fd >= 0);

    // force_close() only logs on failure, so it is safe to use here
    if (still_open) {
        force_close();
    }

    // drop our registration, if any
    if (registered) {
        reactor->remove_socket(this);
    }
}
/**
 * Clear the per-descriptor state flags: bound, want_read and want_write.
 */
void NetworkSocket::reset (void) {
    bound = want_read = want_write = false;
}
/**
 * Create the underlying socket descriptor if we do not have one yet.
 *
 * Does nothing when fd is already valid. Any non-zero field of the
 * requested sock_type must match the corresponding argument.
 *
 * @throws NetworkSocketError if family/socktype/protocol conflict with sock_type
 * @throws NetworkSocketErrno if ::socket() fails
 */
void NetworkSocket::lazy_socket (int family, int socktype, int protocol) {
    // nothing to do if a descriptor already exists
    if (fd >= 0) {
        return;
    }

    // a zero field in sock_type means "no preference"
    const bool family_ok = !sock_type.family || family == sock_type.family;
    const bool socktype_ok = !sock_type.socktype || socktype == sock_type.socktype;
    const bool protocol_ok = !sock_type.protocol || protocol == sock_type.protocol;

    if (!(family_ok && socktype_ok && protocol_ok)) {
        throw NetworkSocketError(*this, "socket.create", "family/socktype/protocol mismatch");
    }

    // create the descriptor
    fd = ::socket(family, socktype, protocol);

    if (fd < 0) {
        throw NetworkSocketErrno(*this, "socket");
    }

    // remember what kind of socket we actually created
    type.family = family;
    type.socktype = socktype;
    type.protocol = protocol;
}
/**
 * Close the descriptor without throwing: a closesocket() failure is only
 * logged as a warning. Afterwards fd is -1 and the state flags are reset.
 */
void NetworkSocket::force_close (void) {
    const bool failed = (::closesocket(fd) != 0);

    if (failed) {
        Engine::log(WARN, "socket.force_close") << "error closing socket: " << dump_errno();
    }

    // forget the descriptor regardless of the outcome
    fd = -1;
    reset();
}
/**
 * Bind the socket to the first usable address resolved from addr.
 *
 * Each addrinfo candidate (looked up with AI_PASSIVE) is tried in turn: a
 * socket is created if needed and ::bind() is attempted. Failures are logged
 * as warnings and the next candidate is tried.
 *
 * @throws NetworkSocketError if no candidate could be bound
 */
void NetworkSocket::bind (const NetworkAddress &addr) {
    addrinfo *candidates = addr.get_addrinfo(type.family, type.socktype, type.protocol, AI_PASSIVE);
    bool success = false;

    for (addrinfo *ai = candidates; ai; ai = ai->ai_next) {
        // make sure we have a descriptor suitable for this candidate
        try {
            lazy_socket(ai->ai_family, ai->ai_socktype, ai->ai_protocol);

        } catch (NetworkSocketError &e) {
            Engine::log(WARN, "socket.bind") << "unable to create socket for " << dump_addrinfo(ai) << ": " << e.what();
            continue;
        }

        if (::bind(fd, ai->ai_addr, ai->ai_addrlen) == 0) {
            // got one, stop looking
            success = true;
            break;
        }

        // log first, while errno still describes the bind() failure
        Engine::log(WARN, "socket.bind") << "unable to bind on " << dump_addrinfo(ai) << ": " << dump_errno();

        // throw away the unusable socket before trying the next candidate
        force_close();
    }

    // the candidate list is no longer needed
    freeaddrinfo(candidates);

    if (!success) {
        throw NetworkSocketError(*this, "bind", "unable to bind on any addresses");
    }

    bound = true;
}
/**
 * Call ::listen() on the descriptor with the given backlog.
 *
 * @throws NetworkSocketErrno if ::listen() fails
 */
void NetworkSocket::listen (int backlog) {
    const int rc = ::listen(fd, backlog);

    if (rc != 0) {
        throw NetworkSocketErrno(*this, "listen");
    }
}
/**
 * Query the local address of the socket via getsockname().
 *
 * @return a NetworkAddress built from the returned sockaddr
 * @throws NetworkSocketErrno if getsockname() fails
 */
NetworkAddress NetworkSocket::get_local_address (void) {
    sockaddr_storage storage;
    socklen_t length = sizeof(storage);
    sockaddr *generic = (sockaddr *) &storage;

    const int rc = ::getsockname(fd, generic, &length);

    if (rc != 0) {
        throw NetworkSocketErrno(*this, "getsockname");
    }

    return NetworkAddress(generic, length);
}
/**
 * Query the peer address of the socket via getpeername().
 *
 * @return a NetworkAddress built from the returned sockaddr
 * @throws NetworkSocketErrno if getpeername() fails
 */
NetworkAddress NetworkSocket::get_remote_address (void) {
    sockaddr_storage storage;
    socklen_t length = sizeof(storage);
    sockaddr *generic = (sockaddr *) &storage;

    const int rc = ::getpeername(fd, generic, &length);

    if (rc != 0) {
        throw NetworkSocketErrno(*this, "getpeername");
    }

    return NetworkAddress(generic, length);
}
/**
 * Enable or disable non-blocking mode on the descriptor.
 *
 * F_SETFL takes exactly one argument: the complete set of file status
 * flags. The current flags are therefore read with F_GETFL, O_NONBLOCK is
 * set or cleared in them, and the result is written back, which leaves all
 * other status flags untouched.
 *
 * @param nonblocking true to set O_NONBLOCK, false to clear it
 * @throws NetworkSocketErrno if either fcntl() call fails
 */
void NetworkSocket::set_nonblocking (bool nonblocking) {
    // fetch the current status flags so that we only touch O_NONBLOCK
    int flags = fcntl(fd, F_GETFL, 0);

    if (flags == -1)
        throw NetworkSocketErrno(*this, "fcntl(F_GETFL)");

    if (nonblocking)
        flags |= O_NONBLOCK;
    else
        flags &= ~O_NONBLOCK;

    if (fcntl(fd, F_SETFL, flags) == -1)
        throw NetworkSocketErrno(*this, "fcntl(F_SETFL, O_NONBLOCK)");
}
/**
 * Accept a pending connection on this socket.
 *
 * @param src if non-NULL, updated with the address returned by ::accept()
 * @return a newly allocated NetworkSocket wrapping the accepted descriptor,
 *         using this socket's type and reactor; the caller owns it
 * @throws NetworkSocketErrno if ::accept() fails
 *
 * If updating src or allocating the new object throws, the accepted
 * descriptor is closed before the exception propagates, so it is not leaked.
 */
NetworkSocket* NetworkSocket::accept (NetworkAddress *src) {
    sockaddr_storage addr;
    socklen_t addrlen = sizeof(addr);

    // try and get the FD
    int new_fd = ::accept(fd, (sockaddr *) &addr, &addrlen);

    if (new_fd < 0)
        throw NetworkSocketErrno(*this, "accept");

    try {
        // update src first, so that nothing can fail once the wrapper object owns new_fd
        if (src)
            src->set_sockaddr((sockaddr *) &addr, addrlen);

        // allocate new NetworkSocket for new_fd
        return new NetworkSocket(new_fd, type, reactor);

    } catch (...) {
        // nobody owns new_fd yet, so release it ourselves
        ::closesocket(new_fd);
        throw;
    }
}
/**
 * Connect the socket to the first reachable address resolved from addr.
 *
 * Each addrinfo candidate is tried in turn: a socket is created if needed
 * and ::connect() is attempted. Failures are logged as warnings and the next
 * candidate is tried.
 *
 * @throws NetworkSocketError if no candidate could be connected to
 */
void NetworkSocket::connect (const NetworkAddress &addr) {
    addrinfo *candidates = addr.get_addrinfo(type.family, type.socktype, type.protocol);
    bool success = false;

    for (addrinfo *ai = candidates; ai; ai = ai->ai_next) {
        // make sure we have a descriptor suitable for this candidate
        try {
            lazy_socket(ai->ai_family, ai->ai_socktype, ai->ai_protocol);

        } catch (NetworkSocketError &e) {
            Engine::log(WARN, "socket.connect") << "unable to create socket for " << dump_addrinfo(ai) << ": " << e.what();
            continue;
        }

        if (::connect(fd, ai->ai_addr, ai->ai_addrlen) == 0) {
            // got one, stop looking
            success = true;
            break;
        }

        // log first, while errno still describes the connect() failure
        Engine::log(WARN, "socket.connect") << "unable to connect to " << dump_addrinfo(ai) << ": " << dump_errno();

        // a bound socket is kept for the next attempt; anything else is discarded
        if (!bound) {
            force_close();
        }
    }

    // the candidate list is no longer needed
    freeaddrinfo(candidates);

    if (!success) {
        throw NetworkSocketError(*this, "connect", "unable to connect to any addresses");
    }
}
/**
 * Send data on the socket, using sendto() when a destination is given and
 * send() otherwise.
 *
 * @param buf  data to send
 * @param size number of bytes in buf
 * @param dest optional destination address; NULL to use send()
 * @return number of bytes sent, or 0 if nothing was sent (the call would
 *         have blocked, or the system call itself returned zero)
 * @throws NetworkSocketErrno on any error other than EAGAIN/EWOULDBLOCK
 *
 * When the call would block, want_write is set so that the caller gets
 * notified once the socket becomes writable again.
 */
size_t NetworkSocket::send (const char *buf, size_t size, const NetworkAddress *dest) {
    ssize_t ret;
    const char *op;

    // use send or sendto?
    if (dest) {
        socklen_t addr_len;

        // get destination address
        const sockaddr *addr = dest->get_sockaddr(addr_len);

        ret = ::sendto(fd, buf, size, 0, addr, addr_len);
        op = "sendto";

    } else {
        ret = ::send(fd, buf, size, 0);
        op = "send";
    }

    if (ret < 0) {
        // POSIX allows either value for "would block", and they may differ
        if (errno != EAGAIN && errno != EWOULDBLOCK)
            throw NetworkSocketErrno(*this, op);

        // set want_write so we get a sig_write
        want_write = true;

        return 0;
    }

    // sanity-check
    if (ret == 0) {
        // XXX: not sure what this means...
        Engine::log(ERROR, "socket.send") << "send[to] returned zero, trying again...";

        return 0;
    }

    // return number of bytes sent
    return ret;
}
/**
 * Receive data from the socket, using recvfrom() when a source address is
 * wanted and recv() otherwise.
 *
 * @param buf  buffer to receive into
 * @param size capacity of buf in bytes
 * @param src  if non-NULL, updated with the sender's address when data was received
 * @return number of bytes received, or 0 if the call would have blocked
 * @throws NetworkSocketEOFError if the system call returned zero
 * @throws NetworkSocketErrno on any error other than EAGAIN/EWOULDBLOCK
 */
size_t NetworkSocket::recv (char *buf, size_t size, NetworkAddress *src) {
    ssize_t ret;

    // use recv or recvfrom?
    if (src) {
        sockaddr_storage addr;
        socklen_t addr_len = sizeof(addr);

        // POSIX allows either value for "would block", and they may differ
        if ((ret = ::recvfrom(fd, buf, size, 0, (sockaddr *) &addr, &addr_len)) < 0 && errno != EAGAIN && errno != EWOULDBLOCK)
            throw NetworkSocketErrno(*this, "recvfrom");

        // update source address only if recvfrom actually delivered data
        if (ret > 0)
            src->set_sockaddr((sockaddr *) &addr, addr_len);

    } else {
        if ((ret = ::recv(fd, buf, size, 0)) < 0 && errno != EAGAIN && errno != EWOULDBLOCK)
            throw NetworkSocketErrno(*this, "recv");
    }

    // EOF?
    if (ret == 0)
        throw NetworkSocketEOFError(*this, "recv");

    // would block: nothing received this time
    if (ret < 0)
        return 0;

    // return number of bytes received
    return ret;
}
/**
 * Close the descriptor, reporting failures to the caller.
 *
 * On success fd becomes -1 and the state flags are reset. On failure the
 * object is left untouched.
 *
 * @throws NetworkSocketErrno if closesocket() fails
 */
void NetworkSocket::close (void) {
    const int rc = ::closesocket(fd);

    if (rc != 0) {
        throw NetworkSocketErrno(*this, "close");
    }

    // forget the descriptor
    fd = -1;
    reset();
}
/**
 * Add this socket to its reactor, unless it has already been registered.
 */
void NetworkSocket::register_poll (void) {
    if (!registered) {
        reactor->add_socket(this);
        registered = true;
    }
}
/*
* NetworkSocketError
*/
/**
 * Build the error message "socket #<id> <op>: <err>", where <id> is
 * whatever socket.get_socket() returns.
 */
std::string NetworkSocketError::build_str (const NetworkSocket &socket, const char *op, const char *err) {
    std::ostringstream message;

    message << "socket #" << socket.get_socket();
    message << " " << op;
    message << ": " << err;

    return message.str();
}
/**
 * Construct an error for the given socket and operation; the message is
 * produced by build_str() and handed to the Error base class.
 */
NetworkSocketError::NetworkSocketError (const NetworkSocket &socket, const char *op, const char *err)
    : Error(build_str(socket, op, err))
{
    // all work is done by the base-class initializer
}
/**
 * Construct an error for the given socket and operation, using the
 * strerror() text for the current errno as the error description.
 */
NetworkSocketErrno::NetworkSocketErrno (const NetworkSocket &socket, const char *op)
    : NetworkSocketError(socket, op, strerror(errno))
{
    // all work is done by the base-class initializer
}