src/proto2/NetworkTCP.cc
changeset 89 825c4613e087
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/src/proto2/NetworkTCP.cc	Thu Nov 20 23:51:46 2008 +0000
@@ -0,0 +1,325 @@
+
+#include "NetworkTCP.hh"
+#include "Engine.hh"
+
+#include <cstdlib>
+#include <cassert>
+
+NetworkBuffer::NetworkBuffer (NetworkSocket &socket, size_t size_hint) :
+    socket(socket), buf(0), size(0), offset(0) {
+    
+    // allocate initial buffer
+    if ((buf = (char *) malloc(size_hint)) == NULL)
+       throw NetworkBufferError("malloc failed");
+    
+    // remember size
+    size = size_hint;
+}
+        
+NetworkBuffer::~NetworkBuffer (void) {
+    free(buf);
+}
+
+void NetworkBuffer::resize (size_t item_size) {
+    size_t new_size = size;
+
+    // grow new_size until item_size fits
+    while (offset + item_size > new_size)
+        new_size *= 2;
+    
+    // grow if needed
+    if (new_size != size) {
+        // realloc buffer
+        if ((buf = (char *) realloc((void *) buf, new_size)) == NULL)
+            throw NetworkBufferError("realloc failed");
+
+        // update size
+        size = new_size;
+
+    } else if (new_size > (offset + item_size) * 4) {
+        // XXX: shrink?
+    }
+}
+        
+void NetworkBuffer::trim (size_t prefix_size) {
+    // update offset
+    offset -= prefix_size;
+
+    // shift the buffer forwards from (buf + prefix) -> (buf), copying (old_offset - prefix) bytes
+    memmove(buf, buf + prefix_size, offset);
+}
+     
+bool NetworkBuffer::try_read (size_t item_size) {
+    int ret;
+    size_t to_read = item_size;
+
+    // keept reads at at least NETWORK_CHUNK_SIZE bytes
+    if (to_read < NETWORK_TCP_CHUNK_SIZE)
+        to_read = NETWORK_TCP_CHUNK_SIZE;
+
+    // resize buffer if needed
+    resize(to_read);
+
+    // read once
+    try {
+        ret = socket.recv(buf + offset, to_read);
+
+    } catch (CL_Error &e) {
+        if (errno == EAGAIN)
+            return false;
+
+        else
+            throw NetworkSocketOSError(socket, "recv");
+    }
+    
+    // handle EOF
+    if (ret == 0)
+        throw NetworkSocketEOFError(socket, "recv");
+
+    assert(ret >= 0);
+
+    // update offset
+    offset += ret;
+
+    // did we get enough?
+    if ((unsigned int) ret < item_size)
+        return false;
+    else
+        return true;
+} 
+        
+bool NetworkBuffer::peek_prefix (uint16_t &ref) {
+    if (offset < sizeof(uint16_t))
+        return false;
+
+    ref = ntohs(*((uint16_t *) (buf)));
+
+    return true;
+}
+    
+bool NetworkBuffer::peek_prefix (uint32_t &ref) {
+    if (offset < sizeof(uint32_t))
+        return false;
+
+    ref = ntohl(*((uint32_t *) (buf)));
+
+    return true;
+}
+
+template <typename PrefixType> PrefixType NetworkBuffer::read_prefix (char *buf_ptr, size_t buf_max) {
+    PrefixType prefix = 0;
+    size_t missing = 0;
+    
+    do {    
+        // do we have the prefix?
+        if (peek_prefix(prefix)) {
+            // do we already have the payload?
+            if (offset >= sizeof(PrefixType) + prefix) {
+                break;
+
+            } else {
+                missing = (sizeof(PrefixType) + prefix) - offset;
+            }
+
+        } else {
+            missing = sizeof(PrefixType);
+        }
+
+        // sanity-check
+        assert(missing);
+        
+        // try and read the missing data
+        if (try_read(missing) == false) {
+            // if unable to read what we need, return zero.
+            return 0;
+        }
+        
+        // assess the situation again
+    } while (true);
+    
+    // copy the data over, unless it's too large
+    if (prefix <= buf_max) {
+        // ...don't copy the prefix, though
+        memcpy(buf_ptr, buf + sizeof(PrefixType), prefix);
+    
+        // trim the bytes out
+        trim(sizeof(PrefixType) + prefix);
+        
+        // return
+        return prefix;
+
+    } else {
+        // trim the bytes out
+        trim(sizeof(PrefixType) + prefix);
+        
+        throw NetworkBufferError("recv prefix overflow");   
+    }
+}
+   
+void NetworkBuffer::push_write (char *buf_ptr, size_t buf_size) {
+    int ret;
+
+    // try and short-circuit writes unless we have already buffered data
+    if (offset == 0) {
+        try {
+            // attempt to send something
+            ret = socket.send(buf_ptr, buf_size);
+
+        } catch (CL_Error &e) {
+            // ignore EAGAIN, detect this by setting ret to -1
+            if (errno != EAGAIN)
+                throw NetworkSocketOSError(socket, "send");
+
+            ret = -1;
+        }
+        
+        // if we managed to send something, adjust buf/size and buffer
+        if (ret > 0) {
+            // sanity-check
+            assert(buf_size >= (unsigned int) ret);
+
+            buf_ptr += ret;
+            buf_size -= ret;
+
+            // if that was all, we're done
+            if (buf_size == 0)
+                return;
+        }
+    }
+    
+    // resize to fit buf_size more bytes
+    resize(buf_size);
+    
+    // copy into our internal buffer
+    memcpy(buf + offset, buf_ptr, buf_size);
+}
+        
+void NetworkBuffer::flush_write (void) {
+    int ret;
+
+    // ignore if we don't have any data buffered
+    if (offset == 0)
+        return;
+    
+    // attempt to write as much as possible
+    try {
+        ret = socket.send(buf, offset);
+
+    } catch (CL_Error &e) {
+        // ignore EAGAIN and just return
+        if (errno == EAGAIN)
+            return;
+
+        else
+            throw NetworkSocketOSError(socket, "send");
+    }
+
+    // trim the buffer
+    trim(ret);
+}
+        
+void NetworkBuffer::write_prefix (char *buf, uint16_t prefix) {
+    uint16_t nval = htons(prefix);
+
+    push_write((char*) &nval, sizeof(uint16_t)); 
+    push_write(buf, prefix);
+}
+
+void NetworkBuffer::write_prefix (char *buf, uint32_t prefix) {
+    uint32_t nval = htonl(prefix);
+
+    push_write((char*) &nval, sizeof(uint32_t)); 
+    push_write(buf, prefix);
+}
+
+NetworkTCPTransport::NetworkTCPTransport (NetworkSocket socket) :
+    socket(socket), in(socket, NETWORK_TCP_INITIAL_IN_BUF), out(socket, NETWORK_TCP_INITIAL_OUT_BUF) {
+    
+    // connect signals
+    slots.connect(socket.sig_read_triggered(), this, &NetworkTCPTransport::on_read);
+    slots.connect(socket.sig_write_triggered(), this, &NetworkTCPTransport::on_write);
+    slots.connect(socket.sig_disconnected(), this, &NetworkTCPTransport::on_disconnected);
+}
+
+
+void NetworkTCPTransport::on_read (void) {
+    uint16_t prefix;
+    NetworkPacket packet;
+    
+    // let the in stream read length-prefixed packets and pass them on to handle_packet
+    while ((prefix = in.read_prefix<uint16_t>(packet.get_buf(), packet.get_buf_size())) > 0) {
+        packet.set_data_size(prefix);
+        _sig_packet(packet);
+    }
+}
+
+void NetworkTCPTransport::on_write (void) {
+    // just flush the output buffer
+    out.flush_write();
+}
+
+void NetworkTCPTransport::on_disconnected (void) {
+    // pass right through
+    _sig_disconnect();
+}
+        
+void NetworkTCPTransport::write_packet (const NetworkPacket &packet) {
+    uint16_t prefix = packet.get_data_size();
+    
+    if (prefix != packet.get_data_size())
+        throw CL_Error("send prefix overflow");
+    
+    try {
+        // just write to the output buffer
+        out.write_prefix((char *) packet.get_buf(), prefix);
+
+    } catch (Error &e) {
+        const char *err = e.what();
+
+        Engine::log(ERROR, "tcp.write_packet") << err;
+        
+        throw;    
+    }
+}
+
+NetworkTCPServer::NetworkTCPServer (const NetworkAddress &listen_addr) :
+    socket(CL_Socket::tcp, CL_Socket::ipv4) {
+    
+    // bind
+    socket.bind(listen_addr);
+
+    // assign slots
+    slots.connect(socket.sig_read_triggered(), this, &NetworkTCPServer::on_accept);
+
+    // listen
+    socket.listen(NETWORK_LISTEN_BACKLOG);
+    
+    // use nonblocking sockets
+    socket.set_nonblocking(true);
+}
+
+
+void NetworkTCPServer::on_accept (void) {
+    // accept a new socket
+    NetworkSocket client_sock = socket.accept();
+
+    // create a new NetworkTCPTransport
+    NetworkTCPTransport *client = buildTransport(client_sock);
+        
+    // let our user handle it
+    _sig_client(client);
+}
+        
+NetworkTCPTransport* NetworkTCPServer::buildTransport (CL_Socket &socket) {
+    return new NetworkTCPTransport(socket);
+}
+        
+NetworkTCPClient::NetworkTCPClient (const NetworkAddress &connect_addr) :
+    NetworkTCPTransport(NetworkSocket(CL_Socket::tcp, CL_Socket::ipv4)) {
+
+    // connect
+    socket.connect(connect_addr);
+    
+    // use nonblocking sockets
+    socket.set_nonblocking(true);
+}