src/Network/Socket.cc
changeset 378 5589abf5e61b
parent 187 f41f894213ca
child 380 d193dd1d8a7e
equal deleted inserted replaced
377:01d3c340b372 378:5589abf5e61b
     1 
     1 
     2 #include "Socket.hh"
     2 #include "Socket.hh"
       
     3 #include "../Engine.hh"
     3 
     4 
     4 #include <sstream>
     5 #include <sstream>
       
     6 
       
     7 NetworkSocket::NetworkSocket (int family, int socktype, int protocol) :
       
     8     fd(-1), family(family), socktype(socktype), protocol(protocol), bound(false)
       
     9 {
       
    10 
       
    11 }
       
    12         
       
    13 NetworkSocket::NetworkSocket (int fd) :
       
    14     fd(fd), family(0), socktype(0), protocol(0), bound(false)
       
    15 {
       
    16     
       
    17 }
       
    18         
       
    19 NetworkSocket::~NetworkSocket (void) {
       
    20     // close any remaining socket
       
    21     if (fd >= 0)
       
    22         force_close();
       
    23 }
       
    24         
       
    25 void NetworkSocket::lazy_socket (int family, int socktype, int protocol) {
       
    26     // if we already have a socket, exit
       
    27     if (fd >= 0)
       
    28         return;
       
    29     
       
    30     // check that we don't have conflicting family/type/protocol
       
    31     if (
       
    32         (this->family && family != this->family) || 
       
    33         (this->socktype && socktype != this->socktype) || 
       
    34         (this->protocol && protocol != this->protocol)
       
    35     )
       
    36         throw NetworkSocketError(*this, "socket.create", "family/socktype/protocol mismatch");
       
    37     
       
    38     // create the socket or fail
       
    39     if ((fd = ::socket(family, socktype, protocol)) < 0)
       
    40         throw NetworkSocketOSError(*this, "socket");
       
    41 
       
    42     // update our family/type/protocol
       
    43     this->family = family;
       
    44     this->socktype = socktype;
       
    45     this->protocol = protocol;
       
    46 }
       
    47         
       
    48 void NetworkSocket::force_close (void) {
       
    49     // use closesocket
       
    50     if (::closesocket(fd))
       
    51         Engine::log(WARN, "socket.force_close") << "error closing socket: " /* XXX: errno */;
       
    52     
       
    53     // invalidate fd
       
    54     fd = -1;
       
    55 }
       
    56 
       
    57 void NetworkSocket::bind (const NetworkAddress &addr) {
       
    58     // get our addrinfo
       
    59     addrinfo *r, *results = addr.get_addrinfo(family, socktype, protocol, AI_PASSIVE);
       
    60     
       
    61     // find the right address to bind to
       
    62     for (r = results; r; r = r->ai_next) {
       
    63         // create socket if needed, warn on errors
       
    64         try {
       
    65             lazy_socket(r->ai_family, r->ai_socktype, r->ai_protocol);
       
    66 
       
    67         } catch (NetworkSocketError &e) {
       
    68             Engine::log(WARN, "socket.bind") << "unable to create socket for " << r << ": " << e.what();
       
    69             continue;
       
    70         }
       
    71 
       
    72         // bind it, warn on errors
       
    73         if (::bind(fd, r->ai_addr, r->ai_addrlen)) {
       
    74             Engine::log(WARN, "socket.bind") << "unable to bind on " << r /* XXX: errno */ ;
       
    75             
       
    76             // close the bad socket
       
    77             force_close();
       
    78 
       
    79             continue;
       
    80 
       
    81         } else {
       
    82             // we have a bound socket, break
       
    83             break;
       
    84         }
       
    85     }
       
    86    
       
    87     // release our addrinfo
       
    88     freeaddrinfo(results);
       
    89 
       
    90     // if we failed to bind, r is a NULL pointer
       
    91     if (r == NULL)
       
    92         throw NetworkSocketError(*this, "bind", "unable to bind on any addresses");
       
    93 
       
    94     // mark ourselves as bound
       
    95     bound = true;
       
    96 }
       
    97         
       
    98 void NetworkSocket::listen (int backlog) {
       
    99     // just call listen
       
   100     if (::listen(fd, backlog))
       
   101         throw NetworkSocketOSError(*this, "listen");
       
   102 }
       
   103         
       
   104 NetworkAddress NetworkSocket::get_local_address (void) {
       
   105     sockaddr_storage addr;
       
   106     socklen_t addrlen = sizeof(addr);
       
   107 
       
   108     // do getsockname()
       
   109     if (::getsockname(fd, (sockaddr *) &addr, &addrlen))
       
   110         throw NetworkSocketOSError(*this, "getsockname");
       
   111 
       
   112     // return addr
       
   113     return NetworkAddress((sockaddr *) &addr, addrlen);
       
   114 }
       
   115 
       
   116 NetworkAddress NetworkSocket::get_remote_address (void) {
       
   117     sockaddr_storage addr;
       
   118     socklen_t addrlen = sizeof(addr);
       
   119 
       
   120     // do getpeername()
       
   121     if (::getpeername(fd, (sockaddr *) &addr, &addrlen))
       
   122         throw NetworkSocketOSError(*this, "getpeername");
       
   123 
       
   124     // return addr
       
   125     return NetworkAddress((sockaddr *) &addr, addrlen);
       
   126 }
       
   127         
       
   128 void NetworkSocket::set_nonblocking (bool nonblocking) {
       
   129     // XXX: linux-specific
       
   130     if (fcntl(fd, F_SETFL, O_NONBLOCK, nonblocking ? 1 : 0) == -1)
       
   131         throw NetworkSocketOSError(*this, "fcntl(F_SETFL, O_NONBLOCK)");
       
   132 }
       
   133  
       
   134 NetworkSocket* NetworkSocket::accept (NetworkAddress *src) {
       
   135     int new_fd;
       
   136     sockaddr_storage addr;
       
   137     socklen_t addrlen = sizeof(addr);
       
   138 
       
   139     // try and get the FD
       
   140     if ((new_fd = ::accept(fd, (sockaddr *) &addr, &addrlen)))
       
   141         throw NetworkSocketOSError(*this, "accept");
       
   142     
       
   143     // allocate new NetworkSocket for new_fd
       
   144     NetworkSocket *socket = new NetworkSocket(new_fd);
       
   145     
       
   146     // update src
       
   147     if (src)
       
   148         src->set_sockaddr((sockaddr *) &addr, addrlen);    
       
   149 
       
   150     // done
       
   151     return socket;
       
   152 }
       
   153         
       
   154 void NetworkSocket::connect (const NetworkAddress &addr) {
       
   155     // get our addrinfo
       
   156     addrinfo *r, *results = addr.get_addrinfo(family, socktype, protocol);
       
   157     
       
   158     // find the right address to bind to
       
   159     for (r = results; r; r = r->ai_next) {
       
   160         // create socket if needed, warn on errors
       
   161         try {
       
   162             lazy_socket(r->ai_family, r->ai_socktype, r->ai_protocol);
       
   163 
       
   164         } catch (NetworkSocketError &e) {
       
   165             Engine::log(WARN, "socket.connect") << "unable to create socket for " << r << ": " << e.what();
       
   166             continue;
       
   167         }
       
   168 
       
   169         // connect it, warn on errors
       
   170         if (::connect(fd, r->ai_addr, r->ai_addrlen)) {
       
   171             Engine::log(WARN, "socket.connect") << "unable to connect to " << r /* XXX: errno */ ;
       
   172             
       
   173             // close unless bound, to not keep invalid sockets hanging around
       
   174             if (!bound)
       
   175                 force_close();
       
   176 
       
   177             continue;
       
   178 
       
   179         } else {
       
   180             // we have a connected socket, break
       
   181             break;
       
   182         }
       
   183     }
       
   184    
       
   185     // release our addrinfo
       
   186     freeaddrinfo(results);
       
   187 
       
   188     // if we failed to connect, r is a NULL pointer
       
   189     if (r == NULL)
       
   190         throw NetworkSocketError(*this, "connect", "unable to connect to any addresses");
       
   191 }
       
   192         
       
   193 size_t NetworkSocket::send (const char *buf, size_t size, const NetworkAddress *dest) {
       
   194     ssize_t ret;
       
   195     
       
   196     // use send or sendto?
       
   197     if (dest) {
       
   198         const sockaddr *addr;
       
   199         socklen_t addr_len;
       
   200         
       
   201         // get destination address
       
   202         addr = dest->get_sockaddr(addr_len);
       
   203         
       
   204         // sendto()
       
   205         if ((ret = ::sendto(fd, buf, size, 0, addr, addr_len)) < 0 && errno != EAGAIN)
       
   206             throw NetworkSocketOSError(*this, "sendto");
       
   207 
       
   208     } else {
       
   209         // send()
       
   210         if ((ret = ::send(fd, buf, size, 0)) < 0 && errno != EAGAIN)
       
   211             throw NetworkSocketOSError(*this, "send");
       
   212 
       
   213     }
       
   214     
       
   215     // sanity-check
       
   216     if (ret == 0) {
       
   217         // XXX: not sure what this means...
       
   218         Engine::log(ERROR, "socket.send") << "send[to] returned zero, trying again...";
       
   219         return 0;
       
   220     }
       
   221     
       
   222     // EAGAIN?
       
   223     if (ret < 0)
       
   224         return 0;
       
   225     
       
   226     // return number of bytes sent
       
   227     return ret;
       
   228 }
       
   229 
       
   230 size_t NetworkSocket::recv (char *buf, size_t size, NetworkAddress *src) {
       
   231     ssize_t ret;
       
   232 
       
   233     // use recv or recvfrom?
       
   234     if (src) {
       
   235         sockaddr_storage addr;
       
   236         socklen_t addr_len = sizeof(addr);
       
   237 
       
   238         // recvfrom()
       
   239         if ((ret = ::recvfrom(fd, buf, size, 0, (sockaddr *) &addr, &addr_len)) < 0 && errno != EAGAIN)
       
   240             throw NetworkSocketOSError(*this, "recvfrom");
       
   241 
       
   242         // modify src...
       
   243         src->set_sockaddr((sockaddr *) &addr, addr_len);
       
   244 
       
   245     } else {
       
   246         // recv
       
   247         if ((ret = ::recv(fd, buf, size, 0)) < 0 && errno != EAGAIN)
       
   248             throw NetworkSocketOSError(*this, "recv");
       
   249 
       
   250     }
       
   251 
       
   252     // EOF?
       
   253     if (ret == 0)
       
   254         throw NetworkSocketEOFError(*this, "recv");
       
   255 
       
   256     // EAGAIN?
       
   257     if (ret < 0)
       
   258         return 0;
       
   259     
       
   260     // return number of bytes received
       
   261     return ret;
       
   262 }
       
   263         
       
   264 void NetworkSocket::close (void) {
       
   265     // use closesocket
       
   266     if (::closesocket(fd))
       
   267         throw NetworkSocketOSError(*this, "close");
       
   268     
       
   269     // invalidate fd
       
   270     fd = -1;
       
   271 }
     5 
   272 
     6 std::string NetworkSocketError::build_str (const NetworkSocket &socket, const char *op, const char *err) {
   273 std::string NetworkSocketError::build_str (const NetworkSocket &socket, const char *op, const char *err) {
     7     std::stringstream ss;
   274     std::stringstream ss;
     8 
   275 
     9     ss << "socket #" << socket.get_socket() << " " << op << ": " << err;
   276     ss << "socket #" << socket.get_socket() << " " << op << ": " << err;