src/Network/Socket.cc
changeset 380 d193dd1d8a7e
parent 378 5589abf5e61b
child 381 9b35bc329d23
--- a/src/Network/Socket.cc	Tue Dec 16 20:30:35 2008 +0000
+++ b/src/Network/Socket.cc	Tue Dec 16 23:21:26 2008 +0000
@@ -2,61 +2,86 @@
 #include "Socket.hh"
 #include "../Engine.hh"
 
+#include <cerrno>
+#include <cstring>
 #include <sstream>
 
-NetworkSocket::NetworkSocket (int family, int socktype, int protocol) :
-    fd(-1), family(family), socktype(socktype), protocol(protocol), bound(false)
+static std::string dump_addrinfo (addrinfo *r) {
+    std::stringstream ss;
+
+    ss << "ai_family=" << r->ai_family << ", ai_canonname=" << r->ai_canonname << ": " << strerror(errno);
+
+    return ss.str();
+}
+
+static std::string dump_errno (void) {
+    return std::string(strerror(errno));
+}
+
+NetworkSocket::NetworkSocket (int family, int socktype, int protocol, NetworkReactor *reactor) :
+    sock_type(family, socktype, protocol), fd(-1), type(family, socktype, protocol), registered(0), reactor(reactor ? reactor : NetworkReactor::current)
 {
-
+    reset();
 }
         
-NetworkSocket::NetworkSocket (int fd) :
-    fd(fd), family(0), socktype(0), protocol(0), bound(false)
+NetworkSocket::NetworkSocket (int fd, socket_type type, NetworkReactor *reactor) :
+    sock_type(type), fd(fd), type(type), registered(0), reactor(reactor ? reactor : NetworkReactor::current)
 {
-    
+    reset();
 }
         
 NetworkSocket::~NetworkSocket (void) {
     // close any remaining socket
     if (fd >= 0)
         force_close();
+
+    // unregister from reactor?
+    if (registered)
+        reactor->remove_socket(this);
 }
         
+void NetworkSocket::reset (void) {
+    bound = false;
+    want_read = false;
+    want_write = false;
+}
+
 void NetworkSocket::lazy_socket (int family, int socktype, int protocol) {
     // if we already have a socket, exit
     if (fd >= 0)
         return;
     
-    // check that we don't have conflicting family/type/protocol
+    // ignore if we've requested a specific sock_type
     if (
-        (this->family && family != this->family) || 
-        (this->socktype && socktype != this->socktype) || 
-        (this->protocol && protocol != this->protocol)
+        (sock_type.family && family != sock_type.family) || 
+        (sock_type.socktype && socktype != sock_type.socktype) || 
+        (sock_type.protocol && protocol != sock_type.protocol)
     )
         throw NetworkSocketError(*this, "socket.create", "family/socktype/protocol mismatch");
     
     // create the socket or fail
     if ((fd = ::socket(family, socktype, protocol)) < 0)
-        throw NetworkSocketOSError(*this, "socket");
+        throw NetworkSocketErrno(*this, "socket");
 
     // update our family/type/protocol
-    this->family = family;
-    this->socktype = socktype;
-    this->protocol = protocol;
+    type.family = family;
+    type.socktype = socktype;
+    type.protocol = protocol;
 }
         
 void NetworkSocket::force_close (void) {
     // use closesocket
     if (::closesocket(fd))
-        Engine::log(WARN, "socket.force_close") << "error closing socket: " /* XXX: errno */;
+        Engine::log(WARN, "socket.force_close") << "error closing socket: " << dump_errno();
     
-    // invalidate fd
+    // reset state
     fd = -1;
+    reset();
 }
 
 void NetworkSocket::bind (const NetworkAddress &addr) {
     // get our addrinfo
-    addrinfo *r, *results = addr.get_addrinfo(family, socktype, protocol, AI_PASSIVE);
+    addrinfo *r, *results = addr.get_addrinfo(type.family, type.socktype, type.protocol, AI_PASSIVE);
     
     // find the right address to bind to
     for (r = results; r; r = r->ai_next) {
@@ -65,13 +90,13 @@
             lazy_socket(r->ai_family, r->ai_socktype, r->ai_protocol);
 
         } catch (NetworkSocketError &e) {
-            Engine::log(WARN, "socket.bind") << "unable to create socket for " << r << ": " << e.what();
+            Engine::log(WARN, "socket.bind") << "unable to create socket for " << dump_addrinfo(r) << ": " << e.what();
             continue;
         }
 
         // bind it, warn on errors
         if (::bind(fd, r->ai_addr, r->ai_addrlen)) {
-            Engine::log(WARN, "socket.bind") << "unable to bind on " << r /* XXX: errno */ ;
+            Engine::log(WARN, "socket.bind") << "unable to bind on " << dump_addrinfo(r) << ": " << dump_errno();
             
             // close the bad socket
             force_close();
@@ -98,7 +123,7 @@
 void NetworkSocket::listen (int backlog) {
     // just call listen
     if (::listen(fd, backlog))
-        throw NetworkSocketOSError(*this, "listen");
+        throw NetworkSocketErrno(*this, "listen");
 }
         
 NetworkAddress NetworkSocket::get_local_address (void) {
@@ -107,7 +132,7 @@
 
     // do getsockname()
     if (::getsockname(fd, (sockaddr *) &addr, &addrlen))
-        throw NetworkSocketOSError(*this, "getsockname");
+        throw NetworkSocketErrno(*this, "getsockname");
 
     // return addr
     return NetworkAddress((sockaddr *) &addr, addrlen);
@@ -119,16 +144,16 @@
 
     // do getpeername()
     if (::getpeername(fd, (sockaddr *) &addr, &addrlen))
-        throw NetworkSocketOSError(*this, "getpeername");
+        throw NetworkSocketErrno(*this, "getpeername");
 
     // return addr
     return NetworkAddress((sockaddr *) &addr, addrlen);
 }
         
 void NetworkSocket::set_nonblocking (bool nonblocking) {
-    // XXX: linux-specific
+    // linux-specific
     if (fcntl(fd, F_SETFL, O_NONBLOCK, nonblocking ? 1 : 0) == -1)
-        throw NetworkSocketOSError(*this, "fcntl(F_SETFL, O_NONBLOCK)");
+        throw NetworkSocketErrno(*this, "fcntl(F_SETFL, O_NONBLOCK)");
 }
  
 NetworkSocket* NetworkSocket::accept (NetworkAddress *src) {
@@ -137,11 +162,11 @@
     socklen_t addrlen = sizeof(addr);
 
     // try and get the FD
-    if ((new_fd = ::accept(fd, (sockaddr *) &addr, &addrlen)))
-        throw NetworkSocketOSError(*this, "accept");
+    if ((new_fd = ::accept(fd, (sockaddr *) &addr, &addrlen)) < 0)
+        throw NetworkSocketErrno(*this, "accept");
     
     // allocate new NetworkSocket for new_fd
-    NetworkSocket *socket = new NetworkSocket(new_fd);
+    NetworkSocket *socket = new NetworkSocket(new_fd, type, reactor);
     
     // update src
     if (src)
@@ -153,7 +178,7 @@
         
 void NetworkSocket::connect (const NetworkAddress &addr) {
     // get our addrinfo
-    addrinfo *r, *results = addr.get_addrinfo(family, socktype, protocol);
+    addrinfo *r, *results = addr.get_addrinfo(type.family, type.socktype, type.protocol);
     
     // find the right address to bind to
     for (r = results; r; r = r->ai_next) {
@@ -162,13 +187,13 @@
             lazy_socket(r->ai_family, r->ai_socktype, r->ai_protocol);
 
         } catch (NetworkSocketError &e) {
-            Engine::log(WARN, "socket.connect") << "unable to create socket for " << r << ": " << e.what();
+            Engine::log(WARN, "socket.connect") << "unable to create socket for " << dump_addrinfo(r) << ": " << e.what();
             continue;
         }
 
         // connect it, warn on errors
         if (::connect(fd, r->ai_addr, r->ai_addrlen)) {
-            Engine::log(WARN, "socket.connect") << "unable to connect to " << r /* XXX: errno */ ;
+            Engine::log(WARN, "socket.connect") << "unable to connect to " << dump_addrinfo(r) << ": " << dump_errno();
             
             // close unless bound, to not keep invalid sockets hanging around
             if (!bound)
@@ -203,12 +228,12 @@
         
         // sendto()
         if ((ret = ::sendto(fd, buf, size, 0, addr, addr_len)) < 0 && errno != EAGAIN)
-            throw NetworkSocketOSError(*this, "sendto");
+            throw NetworkSocketErrno(*this, "sendto");
 
     } else {
         // send()
         if ((ret = ::send(fd, buf, size, 0)) < 0 && errno != EAGAIN)
-            throw NetworkSocketOSError(*this, "send");
+            throw NetworkSocketErrno(*this, "send");
 
     }
     
@@ -219,9 +244,13 @@
         return 0;
     }
     
-    // EAGAIN?
-    if (ret < 0)
+    // EAGAIN? 
+    if (ret < 0) {
+        // set want_write so we get a sig_write
+        want_write = true;
+
         return 0;
+    }
     
     // return number of bytes sent
     return ret;
@@ -237,15 +266,18 @@
 
         // recvfrom()
         if ((ret = ::recvfrom(fd, buf, size, 0, (sockaddr *) &addr, &addr_len)) < 0 && errno != EAGAIN)
-            throw NetworkSocketOSError(*this, "recvfrom");
+            throw NetworkSocketErrno(*this, "recvfrom");
 
-        // modify src...
-        src->set_sockaddr((sockaddr *) &addr, addr_len);
+        // update source address if recvfrom suceeded
+        if (ret > 0) {
+            // modify src...
+            src->set_sockaddr((sockaddr *) &addr, addr_len);
+        }
 
     } else {
         // recv
         if ((ret = ::recv(fd, buf, size, 0)) < 0 && errno != EAGAIN)
-            throw NetworkSocketOSError(*this, "recv");
+            throw NetworkSocketErrno(*this, "recv");
 
     }
 
@@ -264,12 +296,23 @@
 void NetworkSocket::close (void) {
     // use closesocket
     if (::closesocket(fd))
-        throw NetworkSocketOSError(*this, "close");
+        throw NetworkSocketErrno(*this, "close");
     
-    // invalidate fd
+    // reset
     fd = -1;
+    reset();
 }
 
+void NetworkSocket::register_poll (void) { 
+    if (registered) return; 
+
+    reactor->add_socket(this); 
+    registered = true; 
+}
+
+/*
+ * NetworkSocketError
+ */
 std::string NetworkSocketError::build_str (const NetworkSocket &socket, const char *op, const char *err) {
     std::stringstream ss;
 
@@ -279,8 +322,14 @@
 }
 
 NetworkSocketError::NetworkSocketError (const NetworkSocket &socket, const char *op, const char *err) :
-    Error(build_str(socket, op, err)) {
-    
+    Error(build_str(socket, op, err)) 
+{
     // nothing
 }
 
+NetworkSocketErrno::NetworkSocketErrno (const NetworkSocket &socket, const char *op) :
+    NetworkSocketError(socket, op, strerror(errno)) 
+{ 
+
+}
+