#include "msg_proto.h"
#include <string.h>
#include <stdint.h>
#include <arpa/inet.h>
/**
 * Growable byte buffer, used both for buffered input and for buffered output.
 *
 * The bytes in [base, base + off) hold valid data; the allocation behind
 * base is size bytes long.
 */
struct msg_buf {
    /** Start of the allocation holding the data */
    char *base;

    /** Number of bytes allocated at base */
    size_t size;

    /** Number of valid bytes stored from base; also the next append position */
    size_t off;
};
/**
 * Lower bound applied to buffer allocation sizes (msg_buf_init) and to read
 * request sizes (msg_buf_read).
 */
#define MSG_BUF_MIN_SIZE (1024)

/**
 * Factor by which msg_buf::size is multiplied each time the buffer grows.
 */
#define MSG_BUF_GROW_RATE (2)
/**
 * Prepare an empty message buffer with room for at least \a hint bytes.
 *
 * The allocation is never smaller than MSG_BUF_MIN_SIZE. Any previous
 * buf->base is overwritten without being freed.
 *
 * @param buf  buffer to initialize
 * @param hint requested initial capacity in bytes
 * @return SUCCESS, or ERR_MEM if the allocation failed (buf->base is then NULL)
 */
err_t msg_buf_init (struct msg_buf *buf, size_t hint)
{
    size_t initial = (hint > MSG_BUF_MIN_SIZE) ? hint : MSG_BUF_MIN_SIZE;

    buf->base = malloc(initial);

    if (!buf->base)
        return ERR_MEM;

    buf->off = 0;
    buf->size = initial;

    return SUCCESS;
}
/**
* Grow the buffer if needed to fit the given capacity.
*/
err_t msg_buf_grow (struct msg_buf *buf, size_t size)
{
char *tmp = buf->base;
if (buf->size >= size)
// nothing to do
return SUCCESS;
// calculate new size
while (buf->size < size)
buf->size *= MSG_BUF_GROW_RATE;
// resize
if ((buf->base = realloc(buf->base, buf->size)) == NULL) {
buf->base = tmp;
return ERR_MEM;
}
// ok
return SUCCESS;
}
/**
 * Drain \a len bytes off the head of the buffer, shifting any remaining data
 * down to the start of the buffer.
 *
 * Asking to drain more bytes than are buffered simply empties the buffer.
 *
 * @param buf buffer to drain
 * @param len number of bytes to remove from the front
 * @return SUCCESS
 */
err_t msg_buf_drain (struct msg_buf *buf, size_t len)
{
    // never drain past the buffered data: buf->off - len would wrap around
    if (len > buf->off)
        len = buf->off;

    // shift the remaining data (if any) down; the regions may overlap
    if (len && len < buf->off)
        memmove(buf->base, buf->base + len, buf->off - len);

    // update offset
    buf->off -= len;

    return SUCCESS;
}
/**
 * Append data read from a transport_t onto the end of the buffer.
 *
 * The buffer is first grown so that \a len bytes (never fewer than
 * MSG_BUF_MIN_SIZE) fit after the current data. A single transport_read() of
 * that many bytes is then issued. Because of the minimum size, the amount read
 * may exceed the \a len that was asked for. Use msg_buf_drain() to remove the
 * data from the buffer once it has been used.
 *
 * @param buf       buffer to append to
 * @param transport transport to read from
 * @param len       number of bytes to make room for and request
 * @param err       returned error info
 * @return number of new bytes read, zero if the transport's read buffer is empty, or -err_t on error
 */
ssize_t msg_buf_read (struct msg_buf *buf, transport_t *transport, size_t len, error_t *err)
{
    size_t want = (len < MSG_BUF_MIN_SIZE) ? MSG_BUF_MIN_SIZE : len;
    ssize_t got;

    // make room for the read at the tail of the buffer
    ERROR_CODE(err) = msg_buf_grow(buf, buf->off + want);

    if (ERROR_CODE(err))
        return -ERROR_CODE(err);

    got = transport_read(transport, buf->base + buf->off, want, err);

    if (got < 0)
        return -ERROR_CODE(err);

    // a zero-byte read leaves the offset untouched
    buf->off += got;

    return got;
}
/**
 * Call transport_write() repeatedly until all of \a data has been written, or
 * the transport stops accepting data (transport_write() returns zero).
 *
 * @param transport transport to write to
 * @param data      input data
 * @param len       number of bytes to write from data
 * @param err       returned error info
 * @return number of bytes written (possibly zero, or fewer than len), or -err_t on error
 */
static ssize_t _transport_write_all (transport_t *transport, const char *data, size_t len, error_t *err)
{
    size_t done = 0;

    while (done < len) {
        ssize_t n = transport_write(transport, data + done, len - done, err);

        if (n < 0)
            return -ERROR_CODE(err);

        if (n == 0)
            // transport write buffer is full
            break;

        done += n;
    }

    return done;
}
/**
 * Write data to the transport, buffering whatever the transport does not accept.
 *
 * When nothing is buffered yet, the data is first written directly with
 * transport_write(). This continues until either all of it is accepted (nothing
 * more needs to be done) or the transport won't accept any more writes, in which
 * case the remainder is buffered.
 *
 * When data is already buffered, the new data is appended to the end of the
 * buffer instead. Writing it directly would break the order of the data.
 *
 * In either case, transport_write semantics guarantee that our buffer will
 * either be empty, or an on_write will be pending on the transport. See
 * msg_buf_flush() for how to handle transport_callbacks::on_write.
 *
 * @return SUCCESS, or the err_t stored in \a err on failure
 */
err_t msg_buf_write (struct msg_buf *buf, transport_t *transport, const void *data_ptr, size_t len, error_t *err)
{
    const char *src = data_ptr;
    size_t remaining = len;

    if (buf->off == 0) {
        // nothing queued ahead of us: try the transport directly first
        ssize_t sent = _transport_write_all(transport, src, remaining, err);

        if (sent < 0)
            return ERROR_CODE(err);

        src += sent;
        remaining -= sent;

        if (!remaining)
            // everything went out directly
            return SUCCESS;
    }

    // queue the remainder at the end of the buffer
    ERROR_CODE(err) = msg_buf_grow(buf, buf->off + remaining);

    if (ERROR_CODE(err))
        return ERROR_CODE(err);

    memcpy(buf->base + buf->off, src, remaining);
    buf->off += remaining;

    return SUCCESS;
}
/**
 * Flush buffered write data to the transport.
 *
 * transport_write() is driven until either all of our buffered data has been
 * written, or the transport will not accept any more. Whatever was written is
 * removed from the head of the buffer.
 *
 * In either case, transport_write semantics guarantee that our buffer will
 * either be empty, or an on_write will be pending on the transport.
 *
 * @return SUCCESS, or the err_t stored in \a err on transport failure
 */
err_t msg_buf_flush (struct msg_buf *buf, transport_t *transport, error_t *err)
{
    ssize_t sent = _transport_write_all(transport, buf->base, buf->off, err);

    if (sent < 0)
        return ERROR_CODE(err);

    // drop whatever the transport accepted from the head of the buffer
    if (sent > 0)
        msg_buf_drain(buf, sent);

    return SUCCESS;
}
/**
 * Release the buffer's allocation and reset it to an empty, zero-sized state.
 */
void msg_buf_deinit (struct msg_buf *buf)
{
    char *old = buf->base;

    buf->base = NULL;
    buf->off = 0;
    buf->size = 0;

    free(old);
}
/**
 * Decoded message header.
 *
 * On the wire this is a single 16-bit length in network byte order.
 */
struct msg_header {
    /** Total message length in bytes, counting the header itself */
    uint16_t len;
};
/**
 * Number of bytes the encoded header occupies on the wire: one uint16_t length.
 */
#define MSG_PROTO_HEADER_SIZE (sizeof (uint16_t))
/**
 * Per-connection protocol state.
 *
 * Holds the underlying transport, the user's callbacks, and the buffers for
 * partially received and not-yet-sent data.
 */
struct msg_proto {
    /** Underlying transport; destroyed by msg_proto_destroy() */
    transport_t *transport;

    /** User callback table */
    const struct msg_proto_callbacks *cb_tbl;

    /** Opaque argument passed back to every user callback */
    void *cb_arg;

    /** Received data not yet delivered as complete messages */
    struct msg_buf in;

    /** Outgoing data the transport has not yet accepted */
    struct msg_buf out;
};
/**
* Signal error to user
*/
static void msg_proto_error (struct msg_proto *proto, const error_t *err)
{
// invoke user callback
proto->cb_tbl->on_error(proto, err, proto->cb_arg);
}
/**
 * Attempt to decode the header at the head of our input buffer, without
 * removing it.
 *
 * @param proto  protocol state whose input buffer is inspected
 * @param header filled in with the decoded header when a full header is buffered
 * @param err    returned error info
 * @return >0 for a full header, 0 if not enough data is buffered yet, or
 *         -err_t if the header's length is smaller than the header itself
 */
static int msg_proto_peek_header (struct msg_proto *proto, struct msg_header *header, error_t *err)
{
    uint16_t wire_len;

    if (proto->in.off < MSG_PROTO_HEADER_SIZE)
        // not enough data for header
        return 0;

    // copy the raw length out byte-wise: reading the char buffer through a
    // uint16_t * cast is not portable under C's aliasing rules
    memcpy(&wire_len, proto->in.base, sizeof(wire_len));
    header->len = ntohs(wire_len);

    // bad header?
    if (header->len < MSG_PROTO_HEADER_SIZE)
        JUMP_SET_ERROR_STR(err, ERR_MISC, "message_header::len");

    // ok, got header
    return 1;

error:
    return -ERROR_CODE(err);
}
/**
 * Deliver a received message to the user's on_msg callback.
 *
 * \a data points at the message payload just past the header. The payload
 * length passed on is header->len minus the header size.
 *
 * XXX: what to do if the user callback destroys the msg_proto?
 *
 * @return currently always SUCCESS
 */
static err_t msg_proto_on_msg (struct msg_proto *proto, struct msg_header *header, char *data, error_t *err)
{
    size_t payload_len = header->len - MSG_PROTO_HEADER_SIZE;

    (void) err;

    proto->cb_tbl->on_msg(proto, data, payload_len, proto->cb_arg);

    // XXX: handle user errors
    return SUCCESS;
}
/**
 * transport_callbacks::on_read handler.
 *
 * Delivers every complete message held in the input buffer and reads from the
 * transport whenever the header or message at the head is incomplete. It stops
 * once a read returns no new data. Any error is reported to the user via
 * msg_proto_error().
 */
static void msg_proto_on_read (transport_t *transport, void *arg)
{
    struct msg_proto *proto = arg;
    struct msg_header header;
    error_t err;
    int have_header;
    ssize_t nread;

    for (;;) {
        // decode the length of the message at the head of the buffer, if possible
        if ((have_header = msg_proto_peek_header(proto, &header, &err)) < 0)
            goto error;

        if (have_header && header.len <= proto->in.off) {
            // a full message is buffered: deliver it, drop it, look for the next one
            if (msg_proto_on_msg(proto, &header, proto->in.base + MSG_PROTO_HEADER_SIZE, &err))
                goto error;

            msg_buf_drain(&proto->in, header.len);
            continue;
        }

        // incomplete header or message; msg_buf_read applies a minimum size, so zero is OK
        if ((nread = msg_buf_read(&proto->in, transport, have_header ? header.len : 0, &err)) < 0)
            goto error;

        if (nread == 0)
            // the transport has nothing more for us right now
            return;
    }

error:
    // notify user
    msg_proto_error(proto, &err);
}
/**
 * transport_callbacks::on_write handler.
 *
 * Flushes buffered output and reports any transport error to the user.
 */
static void msg_proto_on_write (transport_t *transport, void *arg)
{
    struct msg_proto *proto = arg;
    error_t err;
    err_t code = msg_buf_flush(&proto->out, transport, &err);

    if (code)
        msg_proto_error(proto, &err);
}
/**
 * transport_callbacks::on_error handler.
 *
 * Passes the transport's error straight on to the user.
 */
static void msg_proto_on_error (transport_t *transport, const error_t *err, void *arg)
{
    (void) transport;

    msg_proto_error((struct msg_proto *) arg, err);
}
/**
 * Transport callbacks installed by msg_proto_create().
 *
 * Each callback receives the owning struct msg_proto as its argument.
 */
static const struct transport_callbacks msg_proto_transport_callbacks = {
    .on_error   = msg_proto_on_error,
    .on_read    = msg_proto_on_read,
    .on_write   = msg_proto_on_write,
};
/**
 * Create a new msg_proto on top of the given transport.
 *
 * On success the msg_proto enables read and write events on \a transport and
 * installs its own callbacks on it. msg_proto_destroy() later destroys the
 * transport.
 *
 * NOTE(review): when buffer init or transport_events() fails, the error path
 * calls msg_proto_destroy(), which also destroys \a transport. Confirm that
 * callers expect to lose the transport when creation fails.
 *
 * @param proto_ptr returned msg_proto on success
 * @param transport transport to speak the protocol over
 * @param cb_tbl    user callbacks
 * @param cb_arg    argument passed back to the user callbacks
 * @param err       returned error info
 * @return SUCCESS, or the err_t that was also stored in \a err
 */
err_t msg_proto_create (struct msg_proto **proto_ptr, transport_t *transport, const struct msg_proto_callbacks *cb_tbl, void *cb_arg, error_t *err)
{
    struct msg_proto *proto;

    // alloc; record the failure in err, as every other error path does
    if ((proto = calloc(1, sizeof(*proto))) == NULL)
        return SET_ERROR(err, ERR_MEM);

    // store
    proto->transport = transport;
    proto->cb_tbl = cb_tbl;
    proto->cb_arg = cb_arg;

    // init; calloc left both buffers zeroed, so deinit is safe if either fails
    if (
            (ERROR_CODE(err) = msg_buf_init(&proto->in, 0))
        ||  (ERROR_CODE(err) = msg_buf_init(&proto->out, 0))
    )
        goto error;

    // setup transport
    if ((ERROR_CODE(err) = transport_events(transport, TRANSPORT_READ | TRANSPORT_WRITE)))
        goto error;

    transport_set_callbacks(transport, &msg_proto_transport_callbacks, proto);

    // ok
    *proto_ptr = proto;

    return SUCCESS;

error:
    // release
    msg_proto_destroy(proto);

    return ERROR_CODE(err);
}
/**
 * Encode the given header and write it out via the output buffer.
 *
 * @param proto  protocol state whose output buffer/transport is used
 * @param header header to encode; its len must be at least MSG_PROTO_HEADER_SIZE
 * @param err    returned error info
 * @return SUCCESS, ERR_MISC for an invalid header length, or the err_t from msg_buf_write()
 */
static err_t msg_proto_write_header (struct msg_proto *proto, const struct msg_header *header, error_t *err)
{
    char buf[MSG_PROTO_HEADER_SIZE];
    uint16_t wire_len;

    // validate
    if (header->len < MSG_PROTO_HEADER_SIZE)
        return SET_ERROR(err, ERR_MISC);

    // build; a stack char array is not guaranteed to be aligned for uint16_t,
    // and storing through a cast pointer breaks strict aliasing, so copy bytes
    wire_len = htons(header->len);
    memcpy(buf, &wire_len, sizeof(wire_len));

    // write
    return msg_buf_write(&proto->out, proto->transport, buf, sizeof(buf), err);
}
/**
 * Send one message: a header carrying the total length, followed by \a len
 * bytes of payload.
 *
 * @param proto protocol state
 * @param data  payload bytes
 * @param len   payload length; must not exceed UINT16_MAX - MSG_PROTO_HEADER_SIZE
 * @param err   returned error info
 * @return SUCCESS, ERR_MISC if the payload is too large for the 16-bit length
 *         field, or the err_t reported while writing
 */
err_t msg_proto_send (struct msg_proto *proto, const void *data, size_t len, error_t *err)
{
    struct msg_header header;

    // the 16-bit length field includes the header itself; a larger payload would
    // be silently truncated into header.len and desynchronize the stream
    if (len > UINT16_MAX - MSG_PROTO_HEADER_SIZE)
        return SET_ERROR(err, ERR_MISC);

    // build header
    header.len = (uint16_t) (MSG_PROTO_HEADER_SIZE + len);

    // write it, followed by the payload
    if (
            msg_proto_write_header(proto, &header, err)
        ||  msg_buf_write(&proto->out, proto->transport, data, len, err)
    )
        return ERROR_CODE(err);

    // ok
    return SUCCESS;
}
void msg_proto_destroy (struct msg_proto *proto)
{
// drop buffers
msg_buf_deinit(&proto->in);
msg_buf_deinit(&proto->out);
// kill transport
transport_destroy(proto->transport);
// release ourself
free(proto);
}