#include "NetworkTCP.hh"
#include "Engine.hh"
#include <cassert>
#include <cerrno>
#include <cstdlib>
#include <cstring>
// Create a buffer for the given socket with an initial heap allocation of size_hint bytes.
//
// Throws NetworkBufferError("malloc failed") if the allocation fails.
NetworkBuffer::NetworkBuffer (NetworkSocket &socket, size_t size_hint) :
    socket(socket), buf(0), size(0), offset(0) {
    // malloc(0) may legitimately return NULL (which would be misreported as a failure),
    // and a zero-sized buffer can never be grown by doubling in resize(); so always
    // allocate at least one byte.
    size_t alloc_size = (size_hint > 0) ? size_hint : 1;

    if ((buf = static_cast<char *>(malloc(alloc_size))) == NULL)
        throw NetworkBufferError("malloc failed");

    // remember the real allocation size
    size = alloc_size;
}
// Release the heap buffer allocated by the constructor (and possibly grown by resize()).
//
// NOTE(review): the buffer is a raw owning pointer; if NetworkBuffer is copyable (the
// header is not visible here), a copy would double-free it — confirm copying is disabled.
NetworkBuffer::~NetworkBuffer (void) {
    std::free(buf);
}
// Ensure the buffer can hold item_size more bytes beyond the current offset.
//
// Grows the allocation geometrically (doubling) until it fits. Shrinking is not
// implemented. On failure the existing buffer is left intact.
//
// Throws NetworkBufferError("resize overflow") if offset + item_size overflows, or
// NetworkBufferError("realloc failed") if the reallocation fails.
void NetworkBuffer::resize (size_t item_size) {
    // total number of bytes the buffer must be able to hold
    size_t required = offset + item_size;

    if (required < offset)
        throw NetworkBufferError("resize overflow");

    // already fits; nothing to do (XXX: shrinking is not implemented)
    if (required <= size)
        return;

    // start from 1 so that a zero-sized buffer can grow at all
    size_t new_size = (size > 0) ? size : 1;

    while (new_size < required) {
        if (new_size > static_cast<size_t>(-1) / 2) {
            // doubling again would overflow; allocate exactly what is needed
            new_size = required;
            break;
        }

        new_size *= 2;
    }

    // realloc into a temporary: on failure realloc returns NULL but does not free the
    // old block, so assigning straight to buf would leak it
    char *new_buf = static_cast<char *>(realloc(buf, new_size));

    if (new_buf == NULL)
        throw NetworkBufferError("realloc failed");

    buf = new_buf;
    size = new_size;
}
// Discard the first prefix_size bytes of buffered data, shifting the remaining bytes
// to the start of the buffer.
void NetworkBuffer::trim (size_t prefix_size) {
    // trimming more than is buffered would wrap the unsigned offset around
    assert(prefix_size <= offset);

    offset -= prefix_size;

    // source and destination overlap, hence memmove; moves the (old_offset - prefix_size)
    // bytes that remain after the discarded prefix
    memmove(buf, buf + prefix_size, offset);
}
// Perform a single recv() into the end of the buffer, requesting at least item_size
// bytes (and never fewer than NETWORK_TCP_CHUNK_SIZE).
//
// Returns true if that one recv() delivered at least item_size bytes, false if it
// delivered fewer or the socket reported EAGAIN.
//
// Throws NetworkSocketEOFError when recv() returns 0, and NetworkSocketOSError for any
// other socket error.
bool NetworkBuffer::try_read (size_t item_size) {
    // keep reads at NETWORK_TCP_CHUNK_SIZE bytes or larger
    size_t to_read = (item_size < NETWORK_TCP_CHUNK_SIZE) ? NETWORK_TCP_CHUNK_SIZE : item_size;

    // make room for the incoming bytes
    resize(to_read);

    int received;

    try {
        received = socket.recv(buf + offset, to_read);
    } catch (CL_Error &) {
        if (errno != EAGAIN)
            throw NetworkSocketOSError(socket, "recv");

        // nothing to read right now
        return false;
    }

    // peer closed the connection
    if (received == 0)
        throw NetworkSocketEOFError(socket, "recv");

    assert(received >= 0);

    offset += received;

    return static_cast<unsigned int>(received) >= item_size;
}
// Decode a 16-bit network-byte-order length header from the front of the buffer
// without consuming it.
//
// Returns false (leaving ref untouched) if fewer than 2 bytes are buffered.
bool NetworkBuffer::peek_prefix (uint16_t &ref) {
    if (offset < sizeof(uint16_t))
        return false;

    // copy the bytes out instead of dereferencing buf through a uint16_t*, which is
    // type punning (strict-aliasing UB) on a char buffer
    uint16_t nval;
    memcpy(&nval, buf, sizeof(nval));

    ref = ntohs(nval);

    return true;
}
// Decode a 32-bit network-byte-order length header from the front of the buffer
// without consuming it.
//
// Returns false (leaving ref untouched) if fewer than 4 bytes are buffered.
bool NetworkBuffer::peek_prefix (uint32_t &ref) {
    if (offset < sizeof(uint32_t))
        return false;

    // copy the bytes out instead of dereferencing buf through a uint32_t*, which is
    // type punning (strict-aliasing UB) on a char buffer
    uint32_t nval;
    memcpy(&nval, buf, sizeof(nval));

    ref = ntohl(nval);

    return true;
}
template <typename PrefixType> PrefixType NetworkBuffer::read_prefix (char *buf_ptr, size_t buf_max) {
PrefixType prefix = 0;
size_t missing = 0;
do {
// do we have the prefix?
if (peek_prefix(prefix)) {
// do we already have the payload?
if (offset >= sizeof(PrefixType) + prefix) {
break;
} else {
missing = (sizeof(PrefixType) + prefix) - offset;
}
} else {
missing = sizeof(PrefixType);
}
// sanity-check
assert(missing);
// try and read the missing data
if (try_read(missing) == false) {
// if unable to read what we need, return zero.
return 0;
}
// assess the situation again
} while (true);
// copy the data over, unless it's too large
if (prefix <= buf_max) {
memcpy(buf_ptr, buf, prefix);
// trim the bytes out
trim(sizeof(PrefixType) + prefix);
// return
return prefix;
} else {
// trim the bytes out
trim(sizeof(PrefixType) + prefix);
throw NetworkBufferError("recv prefix overflow");
}
}
// Send buf_size bytes from buf_ptr, queueing whatever cannot be sent immediately.
//
// If nothing is queued yet, a direct send() is attempted first. Any unsent remainder
// is appended to the internal buffer for a later flush_write(). When data is already
// queued, the new bytes go straight into the queue so the stream is not reordered.
//
// Throws NetworkSocketOSError on socket errors other than EAGAIN, and whatever
// resize() throws.
void NetworkBuffer::push_write (char *buf_ptr, size_t buf_size) {
    // nothing to send; avoid a pointless zero-length send()
    if (buf_size == 0)
        return;

    // short-circuit through the socket only if nothing is already queued
    if (offset == 0) {
        int ret;

        try {
            ret = socket.send(buf_ptr, buf_size);
        } catch (CL_Error &) {
            // EAGAIN just means "queue it"; flag that with ret = -1
            if (errno != EAGAIN)
                throw NetworkSocketOSError(socket, "send");

            ret = -1;
        }

        // skip past whatever the socket accepted
        if (ret > 0) {
            assert(buf_size >= static_cast<size_t>(ret));

            buf_ptr += ret;
            buf_size -= ret;

            // everything went out directly
            if (buf_size == 0)
                return;
        }
    }

    // queue the remainder at the end of the internal buffer
    resize(buf_size);
    memcpy(buf + offset, buf_ptr, buf_size);

    // account for the queued bytes; without this they would be overwritten by the next
    // push and never flushed
    offset += buf_size;
}
// Push as much queued output to the socket as it will accept, discarding the sent
// bytes from the front of the buffer.
//
// Returns quietly when nothing is queued or the socket reports EAGAIN. Throws
// NetworkSocketOSError on any other send error.
void NetworkBuffer::flush_write (void) {
    // nothing queued, nothing to flush
    if (!offset)
        return;

    int sent;

    try {
        sent = socket.send(buf, offset);
    } catch (CL_Error &) {
        if (errno != EAGAIN)
            throw NetworkSocketOSError(socket, "send");

        // socket is not writable right now; keep the data queued
        return;
    }

    // drop what the socket accepted
    trim(sent);
}
// Send a record: a 16-bit length header in network byte order, followed by `prefix`
// payload bytes taken from `buf` (the caller's data; it shadows the member buffer here).
void NetworkBuffer::write_prefix (char *buf, uint16_t prefix) {
    uint16_t wire_len = htons(prefix);

    push_write(reinterpret_cast<char *>(&wire_len), sizeof wire_len);
    push_write(buf, prefix);
}
// Send a record: a 32-bit length header in network byte order, followed by `prefix`
// payload bytes taken from `buf` (the caller's data; it shadows the member buffer here).
void NetworkBuffer::write_prefix (char *buf, uint32_t prefix) {
    uint32_t wire_len = htonl(prefix);

    push_write(reinterpret_cast<char *>(&wire_len), sizeof wire_len);
    push_write(buf, prefix);
}
// Wrap a socket in a buffered transport and route its read/write/disconnect events to
// on_read, on_write and on_disconnected.
//
// Inside a mem-initializer list, the name `socket` refers to the by-value constructor
// parameter, not the member. That parameter is destroyed when the constructor returns.
// NetworkBuffer's constructor takes its socket by reference, so if it keeps that
// reference (header not visible — presumably it does), binding it to the parameter would
// leave `in`/`out` dangling. Bind the buffers and the signal slots to the member instead.
NetworkTCPTransport::NetworkTCPTransport (NetworkSocket socket) :
    socket(socket), in(this->socket, NETWORK_TCP_INITIAL_IN_BUF), out(this->socket, NETWORK_TCP_INITIAL_OUT_BUF) {
    // connect signals
    slots.connect(this->socket.sig_read_triggered(), this, &NetworkTCPTransport::on_read);
    slots.connect(this->socket.sig_write_triggered(), this, &NetworkTCPTransport::on_write);
    slots.connect(this->socket.sig_disconnected(), this, &NetworkTCPTransport::on_disconnected);
}
// Read-ready callback: extract every complete 16-bit-length-prefixed packet available
// from the input buffer and emit each one through the packet signal. A single
// NetworkPacket object is reused for all packets in one call.
void NetworkTCPTransport::on_read (void) {
    NetworkPacket packet;

    for (;;) {
        uint16_t len = in.read_prefix<uint16_t>(packet.get_buf(), packet.get_buf_size());

        // zero: no complete packet is available (or an empty packet was consumed)
        if (len == 0)
            break;

        packet.set_data_size(len);
        _sig_packet(packet);
    }
}
void NetworkTCPTransport::on_write (void) {
// just flush the output buffer
out.flush_write();
}
void NetworkTCPTransport::on_disconnected (void) {
// pass right through
_sig_disconnect();
}
void NetworkTCPTransport::write_packet (const NetworkPacket &packet) {
uint16_t prefix = packet.get_data_size();
if (prefix != packet.get_data_size())
throw CL_Error("send prefix overflow");
try {
// just write to the output buffer
out.write_prefix((char *) packet.get_buf(), prefix);
} catch (Error &e) {
const char *err = e.what();
Engine::log(ERROR, "tcp.write_packet") << err;
throw;
}
}
// Create an IPv4 TCP socket bound to listen_addr and start listening. The socket's
// read-triggered signal is routed to on_accept, and the socket is then switched to
// non-blocking mode.
NetworkTCPServer::NetworkTCPServer (const NetworkAddress &listen_addr) :
    socket(CL_Socket::tcp, CL_Socket::ipv4) {
    this->socket.bind(listen_addr);

    // readability on a listening socket signals a pending connection
    slots.connect(this->socket.sig_read_triggered(), this, &NetworkTCPServer::on_accept);

    this->socket.listen(NETWORK_LISTEN_BACKLOG);
    this->socket.set_nonblocking(true);
}
void NetworkTCPServer::on_accept (void) {
// accept a new socket
NetworkSocket client_sock = socket.accept();
// create a new NetworkTCPTransport
NetworkTCPTransport *client = buildTransport(client_sock);
// let our user handle it
_sig_client(client);
}
// Construct the transport for an accepted socket. Returns a heap-allocated object that
// the caller is responsible for.
NetworkTCPTransport* NetworkTCPServer::buildTransport (CL_Socket &socket) {
    NetworkTCPTransport *transport = new NetworkTCPTransport(socket);

    return transport;
}
// Create an IPv4 TCP transport and connect it to connect_addr. The connect is issued
// before the socket is switched to non-blocking mode (presumably so that it completes
// synchronously — confirm against CL_Socket defaults).
NetworkTCPClient::NetworkTCPClient (const NetworkAddress &connect_addr) :
    NetworkTCPTransport(NetworkSocket(CL_Socket::tcp, CL_Socket::ipv4)) {
    this->socket.connect(connect_addr);
    this->socket.set_nonblocking(true);
}