src/Network/Socket.cc
author Tero Marttila <terom@fixme.fi>
Fri, 11 Sep 2009 15:34:48 +0300
changeset 446 e411c0799fcc
parent 399 c7295b72731a
permissions -rw-r--r--
fix NetworkSocket to not forget want_* state if resetting between connect()'s

#include "Socket.hh"
#include "../Engine.hh"

#include <cerrno>
#include <cstring>
#include <sstream>

/*
 * Describe an addrinfo entry for log messages: its address family and canonical name.
 *
 * ai_canonname is only filled in when AI_CANONNAME was requested (and then only on the first result), so it is
 * frequently NULL; streaming a NULL char* into an ostream is undefined behaviour, hence the explicit placeholder.
 *
 * Does not append any errno description: every caller appends its own error text after this.
 */
static std::string dump_addrinfo (addrinfo *r) {
    std::stringstream ss;

    ss << "ai_family=" << r->ai_family << ", ai_canonname=" << (r->ai_canonname ? r->ai_canonname : "(none)");

    return ss.str();
}

/*
 * Human-readable description of the current errno value, for log messages.
 */
static std::string dump_errno (void) {
    const char *description = strerror(errno);

    return description;
}

/*
 * Construct a socket object without an fd yet; the actual socket is created lazily by bind()/connect().
 *
 * family/socktype/protocol are remembered as constraints for that lazily created socket (a zero field places no
 * constraint). If no reactor is given, NetworkReactor::current is used.
 */
NetworkSocket::NetworkSocket (int family, int socktype, int protocol, NetworkReactor *reactor) :
    sock_type(family, socktype, protocol),
    fd(-1),
    type(family, socktype, protocol),
    bound(false),
    registered(false),
    reactor(reactor != NULL ? reactor : NetworkReactor::current),
    want_read(false),
    want_write(false)
{
    // no socket is opened here
}
        
/*
 * Wrap an already-open socket fd (e.g. one returned by accept()) of the given type.
 *
 * The object takes over the fd: it will be closed by close()/the destructor. If no reactor is given,
 * NetworkReactor::current is used.
 */
NetworkSocket::NetworkSocket (int fd, socket_type type, NetworkReactor *reactor) :
    sock_type(type),
    fd(fd),
    type(type),
    bound(false),
    registered(false),
    reactor(reactor != NULL ? reactor : NetworkReactor::current),
    want_read(false),
    want_write(false)
{
    // nothing else to set up
}
        
/*
 * Close any socket still held (errors are only logged by force_close) and unregister from the reactor if needed.
 */
NetworkSocket::~NetworkSocket (void) {
    const bool has_open_fd = (fd >= 0);

    if (has_open_fd)
        force_close();

    // the reactor must not keep a pointer to a dead object
    if (registered)
        reactor->remove_socket(this);
}
        
void NetworkSocket::reset (void) {
    fd = -1;
    bound = false;
}

/*
 * Create the underlying socket with the given family/socktype/protocol, unless we already have one.
 *
 * Every non-zero field of sock_type must match the requested value. Otherwise NetworkSocketError is thrown.
 * If ::socket() fails, NetworkSocketErrno is thrown. On success, `type` records what was actually created.
 */
void NetworkSocket::lazy_socket (int family, int socktype, int protocol) {
    // XXX: should we check family/socktype/protocol against sock_type?
    if (fd >= 0)
        return;

    // a zero field in sock_type means "don't care"
    const bool family_ok    = !sock_type.family   || family   == sock_type.family;
    const bool socktype_ok  = !sock_type.socktype || socktype == sock_type.socktype;
    const bool protocol_ok  = !sock_type.protocol || protocol == sock_type.protocol;

    if (!(family_ok && socktype_ok && protocol_ok))
        throw NetworkSocketError(*this, "socket.create", "family/socktype/protocol mismatch");

    fd = ::socket(family, socktype, protocol);

    if (fd < 0)
        throw NetworkSocketErrno(*this, "socket");

    // remember what we actually got
    type.family = family;
    type.socktype = socktype;
    type.protocol = protocol;
}
        
/*
 * Close the socket and forget the fd. This never throws: a failing closesocket() is only logged as a warning.
 */
void NetworkSocket::force_close (void) {
    const int status = ::closesocket(fd);

    if (status != 0)
        Engine::log(WARN, "socket.force_close") << "error closing socket: " << dump_errno();

    reset();
}

/*
 * Bind to the first usable local address that `addr` resolves to (resolved with AI_PASSIVE, filtered by `type`).
 *
 * Each candidate is tried in order:
 * - a socket is lazily created for it, and candidates where that fails are skipped;
 * - if bind() then fails, the socket is closed again before the next candidate.
 * Every failure is logged as a warning.
 *
 * Throws NetworkSocketError if no candidate could be bound; on success the socket is marked as bound.
 */
void NetworkSocket::bind (const NetworkEndpoint &addr) {
    addrinfo *results = addr.get_addrinfo(type.family, type.socktype, type.protocol, AI_PASSIVE);
    addrinfo *cand = results;

    while (cand) {
        bool have_socket = true;

        try {
            lazy_socket(cand->ai_family, cand->ai_socktype, cand->ai_protocol);

        } catch (NetworkSocketError &e) {
            Engine::log(WARN, "socket.bind") << "unable to create socket for " << dump_addrinfo(cand) << ": " << e.what();
            have_socket = false;
        }

        if (have_socket) {
            // success leaves `cand` pointing at the bound address
            if (::bind(fd, cand->ai_addr, cand->ai_addrlen) == 0)
                break;

            Engine::log(WARN, "socket.bind") << "unable to bind on " << addr << " (" << dump_addrinfo(cand) << "): " << dump_errno();

            // don't keep a socket around that failed to bind
            force_close();
        }

        cand = cand->ai_next;
    }

    addr.free_addrinfo(results);

    // ran off the end of the list: nothing worked
    if (!cand)
        throw NetworkSocketError(*this, "bind", "unable to bind on any addresses");

    bound = true;
}
        
/*
 * Start listening with the given backlog; throws NetworkSocketErrno on failure.
 */
void NetworkSocket::listen (int backlog) {
    const int status = ::listen(fd, backlog);

    if (status != 0)
        throw NetworkSocketErrno(*this, "listen");
}
        
/*
 * Return the local address this socket is bound to (getsockname()); throws NetworkSocketErrno on failure.
 */
NetworkAddress NetworkSocket::get_local_address (void) {
    NetworkAddress local;

    if (::getsockname(fd, local.get_sockaddr(), local.get_socklen_ptr()) != 0)
        throw NetworkSocketErrno(*this, "getsockname");

    // let the address object refresh itself from the sockaddr just written
    local.update();

    return local;
}

/*
 * Return the address of the connected peer (getpeername()); throws NetworkSocketErrno on failure.
 */
NetworkAddress NetworkSocket::get_remote_address (void) {
    NetworkAddress peer;

    if (::getpeername(fd, peer.get_sockaddr(), peer.get_socklen_ptr()) != 0)
        throw NetworkSocketErrno(*this, "getpeername");

    // let the address object refresh itself from the sockaddr just written
    peer.update();

    return peer;
}
        
/*
 * Enable or disable non-blocking mode on the socket; throws NetworkSocketErrno on failure.
 *
 * F_SETFL takes a single argument, the complete new set of file status flags. The old call
 * fcntl(fd, F_SETFL, O_NONBLOCK, x) ignored `x`, so it always set O_NONBLOCK (disabling never worked) and wiped
 * every other status flag. Instead, read the current flags and toggle only O_NONBLOCK.
 *
 * POSIX fcntl(); not portable to Winsock.
 */
void NetworkSocket::set_nonblocking (bool nonblocking) {
    int flags = fcntl(fd, F_GETFL);

    if (flags == -1)
        throw NetworkSocketErrno(*this, "fcntl(F_GETFL)");

    if (nonblocking)
        flags |= O_NONBLOCK;
    else
        flags &= ~O_NONBLOCK;

    if (fcntl(fd, F_SETFL, flags) == -1)
        throw NetworkSocketErrno(*this, "fcntl(F_SETFL, O_NONBLOCK)");
}
 
/*
 * Accept a pending connection and return a newly allocated NetworkSocket for it, with the same type and reactor.
 * The caller owns the returned object.
 *
 * If `src` is given, the peer address is written into it. Throws NetworkSocketErrno if accept() fails.
 */
NetworkSocket* NetworkSocket::accept (NetworkAddress *src) {
    const int new_fd = src
        ? ::accept(fd, src->get_sockaddr(), src->get_socklen_ptr())
        : ::accept(fd, NULL, NULL);

    if (new_fd < 0)
        throw NetworkSocketErrno(*this, "accept");

    NetworkSocket *client = new NetworkSocket(new_fd, type, reactor);

    // refresh the caller's address object from the sockaddr accept() filled in
    if (src)
        src->update();

    return client;
}
        
/*
 * Connect to the first reachable address that `addr` resolves to (filtered by `type`).
 *
 * Each candidate is tried in order:
 * - a socket is lazily created for it, and candidates where that fails are skipped;
 * - if connect() then fails, an unbound socket is closed again, so the next candidate gets a fresh socket.
 *   A bound socket is kept, since its local address must be preserved.
 * Every failure is logged as a warning.
 *
 * Throws NetworkSocketError if no candidate could be connected.
 */
void NetworkSocket::connect (const NetworkEndpoint &addr) {
    addrinfo *results = addr.get_addrinfo(type.family, type.socktype, type.protocol);
    addrinfo *cand = results;

    while (cand) {
        bool have_socket = true;

        try {
            lazy_socket(cand->ai_family, cand->ai_socktype, cand->ai_protocol);

        } catch (NetworkSocketError &e) {
            Engine::log(WARN, "socket.connect") << "unable to create socket for " << dump_addrinfo(cand) << ": " << e.what();
            have_socket = false;
        }

        if (have_socket) {
            // success leaves `cand` pointing at the connected address
            if (::connect(fd, cand->ai_addr, cand->ai_addrlen) == 0)
                break;

            Engine::log(WARN, "socket.connect") << "unable to connect to " << addr << " (" << dump_addrinfo(cand) << "): " << dump_errno();

            // don't keep invalid sockets hanging around, unless we must keep our bound local address
            if (!bound)
                force_close();
        }

        cand = cand->ai_next;
    }

    addr.free_addrinfo(results);

    // ran off the end of the list: nothing worked
    if (!cand)
        throw NetworkSocketError(*this, "connect", "unable to connect to any addresses");
}
        
/*
 * Send up to `size` bytes from `buf`. With `dest` set, sendto() is used; otherwise send().
 *
 * Returns the number of bytes sent, or 0 if nothing was sent (the call would block, or send returned zero).
 * Throws NetworkSocketErrno on any other error.
 *
 * The caller is responsible for requesting write notification (want_write) when 0 is returned.
 */
size_t NetworkSocket::send (const char *buf, size_t size, const NetworkAddress *dest) {
    ssize_t ret;

    // POSIX allows EAGAIN and EWOULDBLOCK to be distinct values; both mean "would block", not a hard error
    if (dest) {
        ret = ::sendto(fd, buf, size, 0, dest->get_sockaddr(), dest->get_socklen());

        if (ret < 0 && errno != EAGAIN && errno != EWOULDBLOCK)
            throw NetworkSocketErrno(*this, "sendto");

    } else {
        ret = ::send(fd, buf, size, 0);

        if (ret < 0 && errno != EAGAIN && errno != EWOULDBLOCK)
            throw NetworkSocketErrno(*this, "send");
    }

    if (ret == 0) {
        // a zero-length send (e.g. an empty datagram) legitimately returns zero; only complain otherwise
        // XXX: not sure what a zero return for a non-empty buffer means...
        if (size > 0)
            Engine::log(ERROR, "socket.send") << "send[to] returned zero, trying again...";

        return 0;
    }

    // would block
    if (ret < 0) {
        // XXX: setting want_write is the job of the user
        return 0;
    }

    return ret;
}

/*
 * Receive up to `size` bytes into `buf`. With `src` set, recvfrom() is used and the sender address is stored in it;
 * otherwise recv().
 *
 * Returns the number of bytes received, or 0 if the call would block.
 *
 * Throws:
 * - NetworkSocketEOFError when recv[from] returned zero;
 * - NetworkSocketErrno on any other error.
 */
size_t NetworkSocket::recv (char *buf, size_t size, NetworkAddress *src) {
    ssize_t ret;

    // POSIX allows EAGAIN and EWOULDBLOCK to be distinct values; both mean "would block", not a hard error
    if (src) {
        ret = ::recvfrom(fd, buf, size, 0, src->get_sockaddr(), src->get_socklen_ptr());

        if (ret < 0 && errno != EAGAIN && errno != EWOULDBLOCK)
            throw NetworkSocketErrno(*this, "recvfrom");

        // only refresh the address if something was actually received
        if (ret > 0)
            src->update();

    } else {
        ret = ::recv(fd, buf, size, 0);

        if (ret < 0 && errno != EAGAIN && errno != EWOULDBLOCK)
            throw NetworkSocketErrno(*this, "recv");
    }

    // EOF?
    // NOTE(review): on datagram sockets a zero return is an empty datagram, not EOF -- confirm callers rely on this
    if (ret == 0)
        throw NetworkSocketEOFError(*this, "recv");

    // would block
    if (ret < 0)
        return 0;

    return ret;
}
        
void NetworkSocket::close (void) {
    // use closesocket
    if (::closesocket(fd))
        throw NetworkSocketErrno(*this, "close");
    
    // forget fd
    reset();
}

/*
 * Add this socket to its reactor, once; calling it again has no effect.
 */
void NetworkSocket::register_poll (void) {
    if (!registered) {
        reactor->add_socket(this);
        registered = true;
    }
}

/*
 * NetworkSocketError
 */
/*
 * Format an error message of the form "socket #<fd> <op>: <err>".
 */
std::string NetworkSocketError::build_str (const NetworkSocket &socket, const char *op, const char *err) {
    std::ostringstream msg;

    msg << "socket #" << socket.get_socket()
        << " " << op
        << ": " << err;

    return msg.str();
}

/*
 * Error for operation `op` on `socket`, with the given error description.
 */
NetworkSocketError::NetworkSocketError (const NetworkSocket &socket, const char *op, const char *err) :
    Error(build_str(socket, op, err))
{ }

/*
 * NetworkSocketError whose description is the current errno value, as given by strerror().
 */
NetworkSocketErrno::NetworkSocketErrno (const NetworkSocket &socket, const char *op) :
    NetworkSocketError(socket, op, strerror(errno))
{ }