src/Network/Socket.cc
changeset 378 5589abf5e61b
parent 187 f41f894213ca
child 380 d193dd1d8a7e
--- a/src/Network/Socket.cc	Mon Dec 15 16:41:00 2008 +0000
+++ b/src/Network/Socket.cc	Mon Dec 15 23:56:42 2008 +0000
@@ -1,8 +1,275 @@
 
 #include "Socket.hh"
+#include "../Engine.hh"
 
 #include <sstream>
 
+NetworkSocket::NetworkSocket (int family, int socktype, int protocol) :
+    fd(-1), family(family), socktype(socktype), protocol(protocol), bound(false)
+{
+
+}
+        
+NetworkSocket::NetworkSocket (int fd) :
+    fd(fd), family(0), socktype(0), protocol(0), bound(false)
+{
+    
+}
+        
+NetworkSocket::~NetworkSocket (void) {
+    // close any remaining socket
+    if (fd >= 0)
+        force_close();
+}
+        
+void NetworkSocket::lazy_socket (int family, int socktype, int protocol) {
+    // if we already have a socket, exit
+    if (fd >= 0)
+        return;
+    
+    // check that we don't have conflicting family/type/protocol
+    if (
+        (this->family && family != this->family) || 
+        (this->socktype && socktype != this->socktype) || 
+        (this->protocol && protocol != this->protocol)
+    )
+        throw NetworkSocketError(*this, "socket.create", "family/socktype/protocol mismatch");
+    
+    // create the socket or fail
+    if ((fd = ::socket(family, socktype, protocol)) < 0)
+        throw NetworkSocketOSError(*this, "socket");
+
+    // update our family/type/protocol
+    this->family = family;
+    this->socktype = socktype;
+    this->protocol = protocol;
+}
+        
+void NetworkSocket::force_close (void) {
+    // use closesocket
+    if (::closesocket(fd))
+        Engine::log(WARN, "socket.force_close") << "error closing socket: " /* XXX: errno */;
+    
+    // invalidate fd
+    fd = -1;
+}
+
+void NetworkSocket::bind (const NetworkAddress &addr) {
+    // get our addrinfo
+    addrinfo *r, *results = addr.get_addrinfo(family, socktype, protocol, AI_PASSIVE);
+    
+    // find the right address to bind to
+    for (r = results; r; r = r->ai_next) {
+        // create socket if needed, warn on errors
+        try {
+            lazy_socket(r->ai_family, r->ai_socktype, r->ai_protocol);
+
+        } catch (NetworkSocketError &e) {
+            Engine::log(WARN, "socket.bind") << "unable to create socket for " << r << ": " << e.what();
+            continue;
+        }
+
+        // bind it, warn on errors
+        if (::bind(fd, r->ai_addr, r->ai_addrlen)) {
+            Engine::log(WARN, "socket.bind") << "unable to bind on " << r /* XXX: errno */ ;
+            
+            // close the bad socket
+            force_close();
+
+            continue;
+
+        } else {
+            // we have a bound socket, break
+            break;
+        }
+    }
+   
+    // release our addrinfo
+    freeaddrinfo(results);
+
+    // if we failed to bind, r is a NULL pointer
+    if (r == NULL)
+        throw NetworkSocketError(*this, "bind", "unable to bind on any addresses");
+
+    // mark ourselves as bound
+    bound = true;
+}
+        
+void NetworkSocket::listen (int backlog) {
+    // just call listen
+    if (::listen(fd, backlog))
+        throw NetworkSocketOSError(*this, "listen");
+}
+        
+NetworkAddress NetworkSocket::get_local_address (void) {
+    sockaddr_storage addr;
+    socklen_t addrlen = sizeof(addr);
+
+    // do getsockname()
+    if (::getsockname(fd, (sockaddr *) &addr, &addrlen))
+        throw NetworkSocketOSError(*this, "getsockname");
+
+    // return addr
+    return NetworkAddress((sockaddr *) &addr, addrlen);
+}
+
+NetworkAddress NetworkSocket::get_remote_address (void) {
+    sockaddr_storage addr;
+    socklen_t addrlen = sizeof(addr);
+
+    // do getpeername()
+    if (::getpeername(fd, (sockaddr *) &addr, &addrlen))
+        throw NetworkSocketOSError(*this, "getpeername");
+
+    // return addr
+    return NetworkAddress((sockaddr *) &addr, addrlen);
+}
+        
+void NetworkSocket::set_nonblocking (bool nonblocking) {
+    // XXX: linux-specific
+    if (fcntl(fd, F_SETFL, O_NONBLOCK, nonblocking ? 1 : 0) == -1)
+        throw NetworkSocketOSError(*this, "fcntl(F_SETFL, O_NONBLOCK)");
+}
+ 
+NetworkSocket* NetworkSocket::accept (NetworkAddress *src) {
+    int new_fd;
+    sockaddr_storage addr;
+    socklen_t addrlen = sizeof(addr);
+
+    // try and get the FD
+    if ((new_fd = ::accept(fd, (sockaddr *) &addr, &addrlen)))
+        throw NetworkSocketOSError(*this, "accept");
+    
+    // allocate new NetworkSocket for new_fd
+    NetworkSocket *socket = new NetworkSocket(new_fd);
+    
+    // update src
+    if (src)
+        src->set_sockaddr((sockaddr *) &addr, addrlen);    
+
+    // done
+    return socket;
+}
+        
+void NetworkSocket::connect (const NetworkAddress &addr) {
+    // get our addrinfo
+    addrinfo *r, *results = addr.get_addrinfo(family, socktype, protocol);
+    
+    // find the right address to bind to
+    for (r = results; r; r = r->ai_next) {
+        // create socket if needed, warn on errors
+        try {
+            lazy_socket(r->ai_family, r->ai_socktype, r->ai_protocol);
+
+        } catch (NetworkSocketError &e) {
+            Engine::log(WARN, "socket.connect") << "unable to create socket for " << r << ": " << e.what();
+            continue;
+        }
+
+        // connect it, warn on errors
+        if (::connect(fd, r->ai_addr, r->ai_addrlen)) {
+            Engine::log(WARN, "socket.connect") << "unable to connect to " << r /* XXX: errno */ ;
+            
+            // close unless bound, to not keep invalid sockets hanging around
+            if (!bound)
+                force_close();
+
+            continue;
+
+        } else {
+            // we have a connected socket, break
+            break;
+        }
+    }
+   
+    // release our addrinfo
+    freeaddrinfo(results);
+
+    // if we failed to connect, r is a NULL pointer
+    if (r == NULL)
+        throw NetworkSocketError(*this, "connect", "unable to connect to any addresses");
+}
+        
+size_t NetworkSocket::send (const char *buf, size_t size, const NetworkAddress *dest) {
+    ssize_t ret;
+    
+    // use send or sendto?
+    if (dest) {
+        const sockaddr *addr;
+        socklen_t addr_len;
+        
+        // get destination address
+        addr = dest->get_sockaddr(addr_len);
+        
+        // sendto()
+        if ((ret = ::sendto(fd, buf, size, 0, addr, addr_len)) < 0 && errno != EAGAIN)
+            throw NetworkSocketOSError(*this, "sendto");
+
+    } else {
+        // send()
+        if ((ret = ::send(fd, buf, size, 0)) < 0 && errno != EAGAIN)
+            throw NetworkSocketOSError(*this, "send");
+
+    }
+    
+    // sanity-check
+    if (ret == 0) {
+        // XXX: not sure what this means...
+        Engine::log(ERROR, "socket.send") << "send[to] returned zero, trying again...";
+        return 0;
+    }
+    
+    // EAGAIN?
+    if (ret < 0)
+        return 0;
+    
+    // return number of bytes sent
+    return ret;
+}
+
+size_t NetworkSocket::recv (char *buf, size_t size, NetworkAddress *src) {
+    ssize_t ret;
+
+    // use recv or recvfrom?
+    if (src) {
+        sockaddr_storage addr;
+        socklen_t addr_len = sizeof(addr);
+
+        // recvfrom()
+        if ((ret = ::recvfrom(fd, buf, size, 0, (sockaddr *) &addr, &addr_len)) < 0 && errno != EAGAIN)
+            throw NetworkSocketOSError(*this, "recvfrom");
+
+        // modify src...
+        src->set_sockaddr((sockaddr *) &addr, addr_len);
+
+    } else {
+        // recv
+        if ((ret = ::recv(fd, buf, size, 0)) < 0 && errno != EAGAIN)
+            throw NetworkSocketOSError(*this, "recv");
+
+    }
+
+    // EOF?
+    if (ret == 0)
+        throw NetworkSocketEOFError(*this, "recv");
+
+    // EAGAIN?
+    if (ret < 0)
+        return 0;
+    
+    // return number of bytes received
+    return ret;
+}
+        
+void NetworkSocket::close (void) {
+    // use closesocket
+    if (::closesocket(fd))
+        throw NetworkSocketOSError(*this, "close");
+    
+    // invalidate fd
+    fd = -1;
+}
+
 std::string NetworkSocketError::build_str (const NetworkSocket &socket, const char *op, const char *err) {
     std::stringstream ss;