src/lib/msg_proto.c
author Tero Marttila <terom@fixme.fi>
Thu, 28 May 2009 01:17:36 +0300
branch new-lib-errors
changeset 219 cefec18b8268
parent 196 src/msg_proto.c@873796250c60
permissions -rw-r--r--
some of the lib/transport stuff compiles
#include "msg_proto.h"

#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#include <arpa/inet.h>

/**
 * Growable I/O buffer.
 *
 * Holds \a off bytes of data at the start of a \a size byte allocation at \a base.
 */
struct msg_buf {
    /** Buffer base pointer (allocated with malloc/realloc; NULL after msg_buf_deinit) */
    char *base;

    /** Allocated size of the buffer, in bytes */
    size_t size;

    /** Current read/write offset: the number of data bytes held at the start of base */
    size_t off;
};

/**
 * Buffer tuning constants.
 *
 * MSG_BUF_MIN_SIZE is the floor, in bytes, applied to every msg_buf::size related operation (initial allocation and
 * read request size). MSG_BUF_GROW_RATE is the factor the capacity is multiplied by each time it has to grow.
 */
enum {
    MSG_BUF_MIN_SIZE    = 1024,
    MSG_BUF_GROW_RATE   = 2,
};

/**
 * Set up \a buf with a freshly allocated, empty buffer of \a hint bytes, or MSG_BUF_MIN_SIZE if that is larger.
 *
 * Returns SUCCESS, or ERR_MEM if the allocation fails.
 */
err_t msg_buf_init (struct msg_buf *buf, size_t hint)
{
    size_t initial = hint;

    // never start out smaller than the floor
    if (initial < MSG_BUF_MIN_SIZE)
        initial = MSG_BUF_MIN_SIZE;

    buf->base = malloc(initial);

    if (!buf->base)
        return ERR_MEM;

    // empty, with the whole allocation available
    buf->off = 0;
    buf->size = initial;

    return SUCCESS;
}

/**
 * Grow the buffer if needed to fit the given capacity.
 */
err_t msg_buf_grow (struct msg_buf *buf, size_t size)
{
    char *tmp = buf->base;

    if (buf->size >= size)
        // nothing to do
        return SUCCESS;

    // calculate new size
    while (buf->size < size)
        buf->size *= MSG_BUF_GROW_RATE;

    // resize
    if ((buf->base = realloc(buf->base, buf->size)) == NULL) {
        buf->base = tmp;

        return ERR_MEM;
    }

    // ok
    return SUCCESS;
}

/**
 * Drain \a len bytes off the head of the buffer, shifting any remaining data down to the start.
 *
 * A request to drain more than is currently buffered is clamped, which empties the buffer.
 *
 * Always returns SUCCESS.
 */
err_t msg_buf_drain (struct msg_buf *buf, size_t len)
{
    // clamp, so that off - len below can never wrap around to a huge memmove length
    if (len > buf->off)
        len = buf->off;

    // nothing to remove; also avoids handing a possibly-NULL base to memmove
    if (!len)
        return SUCCESS;

    // source and destination overlap, hence memmove rather than memcpy
    memmove(buf->base, buf->base + len, buf->off - len);

    // update offset
    buf->off -= len;

    return SUCCESS;
}

/**
 * Append data read from a transport_t onto the end of the buffer.
 *
 * Asks the transport for up to max(\a len, MSG_BUF_MIN_SIZE) bytes, growing the buffer first so that they fit. The
 * call may therefore return more data than \a len. Use msg_buf_drain() to remove the data from the buffer once it has
 * been used.
 *
 * Returns the number of new bytes read, zero if the transport read buffer was empty, or -err_t on error.
 */
ssize_t msg_buf_read (struct msg_buf *buf, transport_t *transport, size_t len, error_t *err)
{
    ssize_t nread;
    size_t want = (len < MSG_BUF_MIN_SIZE) ? MSG_BUF_MIN_SIZE : len;

    // make room for `want` more bytes past the end of the current data
    if ((ERROR_CODE(err) = msg_buf_grow(buf, buf->off + want)))
        return -ERROR_CODE(err);

    nread = transport_read(transport, buf->base + buf->off, want, err);

    if (nread < 0)
        return -ERROR_CODE(err);

    // a zero-byte read leaves the offset where it was
    buf->off += nread;

    return nread;
}

/**
 * Keep calling transport_write() on \a data until either all \a len bytes are written, or the transport returns zero
 * (it will not accept any more for now).
 *
 * @param transport transport to write to
 * @param data input data
 * @param len number of bytes to write from data
 * @param err returned error info
 * @return number of bytes written (which may be zero or less than len), or -err_t on error.
 */
static ssize_t _transport_write_all (transport_t *transport, const char *data, size_t len, error_t *err)
{
    size_t done = 0;

    while (len > 0) {
        ssize_t chunk = transport_write(transport, data, len, err);

        if (chunk < 0)
            return -ERROR_CODE(err);

        // write buffer full: report what we managed so far
        if (chunk == 0)
            break;

        // advance past what the transport took
        data += chunk;
        len -= chunk;
        done += chunk;
    }

    return done;
}

/**
 * Write \a len bytes of \a data_ptr to the transport, buffering whatever cannot be written right away.
 *
 * When nothing is buffered yet, the data is first pushed directly with transport_write until either all of it is
 * written (nothing more to do) or the transport stops accepting writes, in which case the remainder is buffered.
 *
 * When data is already buffered, the new data is always appended to the end of the buffer, since writing it directly
 * would reorder the stream.
 *
 * Either way, transport_write semantics guarantee that afterwards our buffer is either empty, or an on_write is
 * pending on the transport. See msg_buf_flush() for how to handle transport_callbacks::on_write.
 *
 * Returns SUCCESS, or the error code stored in \a err.
 */
err_t msg_buf_write (struct msg_buf *buf, transport_t *transport, const void *data_ptr, size_t len, error_t *err)
{
    const char *src = data_ptr;

    // only bypass the buffer when it is empty, so the byte order is preserved
    if (buf->off == 0) {
        ssize_t sent = _transport_write_all(transport, src, len, err);

        if (sent < 0)
            return ERROR_CODE(err);

        src += sent;
        len -= sent;

        // the transport took everything
        if (!len)
            return SUCCESS;
    }

    // queue whatever the transport did not take
    if ((ERROR_CODE(err) = msg_buf_grow(buf, buf->off + len)))
        return ERROR_CODE(err);

    memcpy(buf->base + buf->off, src, len);
    buf->off += len;

    return SUCCESS;
}

/**
 * Push buffered write data out to the transport, driving transport_write() until either all of the buffered data has
 * been written or the transport will not accept any more.
 *
 * Either way, transport_write semantics guarantee that afterwards our buffer is either empty, or an on_write is
 * pending on the transport.
 *
 * Returns SUCCESS, or the error code stored in \a err.
 */
err_t msg_buf_flush (struct msg_buf *buf, transport_t *transport, error_t *err)
{
    ssize_t sent = _transport_write_all(transport, buf->base, buf->off, err);

    if (sent < 0)
        return ERROR_CODE(err);

    // discard what made it out; the rest waits for the next flush
    if (sent > 0)
        msg_buf_drain(buf, sent);

    return SUCCESS;
}

/**
 * Release the buffer's allocation and reset it to an empty, zero-sized state.
 */
void msg_buf_deinit (struct msg_buf *buf)
{
    char *mem = buf->base;

    // clear the fields first so nothing is left pointing at freed memory
    buf->base = NULL;
    buf->off = 0;
    buf->size = 0;

    free(mem);
}

/**
 * Decoded message header.
 *
 * On the wire the header is a single 16-bit length in network byte order (see htons/ntohs usage).
 */
struct msg_header {
    /** Message length in bytes, including the header itself */
    uint16_t len;
};

/**
 * Size of the encoded header on the wire, in bytes
 */
#define MSG_PROTO_HEADER_SIZE (sizeof(uint16_t))

/**
 * Our state struct: one message-framing session on top of a transport
 */
struct msg_proto {
    /** The transport messages are read from and written to */
    transport_t *transport;

    /** User callbacks (on_msg, on_error) */
    const struct msg_proto_callbacks *cb_tbl;

    /** User callback argument, passed back to every user callback */
    void *cb_arg;

    /** Input buffer: received bytes, possibly holding an incomplete message */
    struct msg_buf in;

    /** Output buffer: outgoing bytes the transport has not accepted yet */
    struct msg_buf out;
};

/**
 * Signal error to user
 */
static void msg_proto_error (struct msg_proto *proto, const error_t *err)
{
    // invoke user callback
    proto->cb_tbl->on_error(proto, err, proto->cb_arg);
}

/**
 * Attempt to decode the header of the message at the start of our input buffer, without consuming it.
 *
 * Returns >0 for full header (stored in \a header), 0 if not enough data is buffered yet, -err_t for error (a header
 * length smaller than the header itself).
 */
static int msg_proto_peek_header (struct msg_proto *proto, struct msg_header *header, error_t *err)
{
    uint16_t wire_len;

    if (proto->in.off < MSG_PROTO_HEADER_SIZE)
        // not enough data for header
        return 0;

    // copy the bytes out instead of dereferencing a cast char pointer: portable, no aliasing/alignment UB
    memcpy(&wire_len, proto->in.base, sizeof(wire_len));
    header->len = ntohs(wire_len);

    // bad header?
    if (header->len < MSG_PROTO_HEADER_SIZE) {
        JUMP_SET_ERROR_STR(err, ERR_MISC, "message_header::len");
    }

    // ok, got header
    return 1;

error:
    return -ERROR_CODE(err);
}

/**
 * Received a complete message: hand its payload to the user's on_msg callback.
 *
 * \a data points at the payload following the header; its length is header->len minus the header size.
 *
 * XXX: what to do if the user callback destroys the msg_proto?
 */
static err_t msg_proto_on_msg (struct msg_proto *proto, struct msg_header *header, char *data, error_t *err)
{
    const struct msg_proto_callbacks *cbs = proto->cb_tbl;
    size_t payload_len = header->len - MSG_PROTO_HEADER_SIZE;

    // err stays unused until user callbacks can report errors
    (void) err;

    cbs->on_msg(proto, data, payload_len, proto->cb_arg);

    // XXX: handle user errors
    return SUCCESS;
}

/**
 * Transport read callback: buffers incoming data and dispatches every complete message currently available.
 *
 * Loops until a read from the transport returns no new data. Any error is reported to the user via on_error.
 */
static void msg_proto_on_read (transport_t *transport, void *arg)
{
    struct msg_proto *proto = arg;
    struct msg_header hdr;
    error_t err;
    ssize_t progress;

    // a single read event may carry several messages
    do {
        int have_header = msg_proto_peek_header(proto, &hdr, &err);

        if (have_header < 0)
            goto error;

        if (have_header && hdr.len <= proto->in.off) {
            // a complete message is buffered: dispatch it, then drop it from the buffer
            if (msg_proto_on_msg(proto, &hdr, proto->in.base + MSG_PROTO_HEADER_SIZE, &err))
                goto error;

            msg_buf_drain(&proto->in, hdr.len);
            progress = 1;

        } else {
            // header or body still incomplete; msg_buf_read applies its own minimum size, so asking for 0 is fine
            progress = msg_buf_read(&proto->in, transport, have_header ? hdr.len : 0, &err);

            if (progress < 0)
                goto error;
        }
    } while (progress);

    return;

error:
    // notify user
    msg_proto_error(proto, &err);
}

/**
 * Transport write callback: the transport can accept more, so flush our output buffer into it.
 */
static void msg_proto_on_write (transport_t *transport, void *arg)
{
    struct msg_proto *proto = arg;
    error_t err;

    if (!msg_buf_flush(&proto->out, transport, &err))
        return;

    // the flush hit a transport error; let the user know
    msg_proto_error(proto, &err);
}

/**
 * Transport error callback: forward the error to the user unchanged.
 */
static void msg_proto_on_error (transport_t *transport, const error_t *err, void *arg)
{
    (void) transport;

    msg_proto_error((struct msg_proto *) arg, err);
}

/**
 * Transport event handlers installed by msg_proto_create()
 */
static const struct transport_callbacks msg_proto_transport_callbacks = {
    .on_error   = msg_proto_on_error,
    .on_read    = msg_proto_on_read,
    .on_write   = msg_proto_on_write,
};

/**
 * Create a msg_proto on top of \a transport: allocates the in/out buffers, enables read and write events on the
 * transport and installs our transport callbacks.
 *
 * On success, *proto_ptr is set and returns SUCCESS; from then on msg_proto_destroy() also destroys the transport.
 *
 * On failure, returns an error code (also stored in \a err). The transport is NOT destroyed and remains the caller's
 * responsibility, on every failure path alike.
 */
err_t msg_proto_create (struct msg_proto **proto_ptr, transport_t *transport, const struct msg_proto_callbacks *cb_tbl, void *cb_arg, error_t *err)
{
    struct msg_proto *proto;

    // alloc zeroed, so both buffers are safe to deinit even if their init never ran
    if ((proto = calloc(1, sizeof(*proto))) == NULL)
        return SET_ERROR(err, ERR_MEM);

    // store
    proto->transport = transport;
    proto->cb_tbl = cb_tbl;
    proto->cb_arg = cb_arg;

    // init
    if (
            (ERROR_CODE(err) = msg_buf_init(&proto->in, 0))
        ||  (ERROR_CODE(err) = msg_buf_init(&proto->out, 0))
    )
        goto error;

    // setup transport
    if ((ERROR_CODE(err) = transport_events(transport, TRANSPORT_READ | TRANSPORT_WRITE)))
        goto error;

    transport_set_callbacks(transport, &msg_proto_transport_callbacks, proto);

    // ok
    *proto_ptr = proto;

    return SUCCESS;

error:
    // release only what we allocated here; msg_proto_destroy() would also destroy the caller's transport
    msg_buf_deinit(&proto->in);
    msg_buf_deinit(&proto->out);
    free(proto);

    return ERROR_CODE(err);
}

/**
 * Encode the given header (a 16-bit length in network byte order) and write it out via our output buffer.
 *
 * Returns SUCCESS, ERR_MISC if header->len is smaller than the header itself, or the error from msg_buf_write().
 */
static err_t msg_proto_write_header (struct msg_proto *proto, const struct msg_header *header, error_t *err)
{
    uint16_t wire_len;

    // validate: a header can never describe a message shorter than itself
    if (header->len < MSG_PROTO_HEADER_SIZE)
        return SET_ERROR(err, ERR_MISC);

    // build in a properly typed object instead of storing through a cast char-array pointer, which was a
    // strict-aliasing violation and possibly a misaligned store
    wire_len = htons(header->len);

    // write
    return msg_buf_write(&proto->out, proto->transport, &wire_len, sizeof(wire_len), err);
}

/**
 * Send one message: a length header followed by \a len bytes of \a data.
 *
 * Whatever the transport does not accept immediately is buffered and flushed later from the on_write callback.
 *
 * Returns SUCCESS, or an error code (also stored in \a err). A payload too large for the 16-bit length field is
 * rejected with ERR_MISC before anything is written.
 */
err_t msg_proto_send (struct msg_proto *proto, const void *data, size_t len, error_t *err)
{
    struct msg_header header;

    // header.len is a uint16_t that also counts the header itself; a larger payload would silently wrap around
    // and desynchronize the framing of the whole stream
    if (len > UINT16_MAX - MSG_PROTO_HEADER_SIZE)
        return SET_ERROR(err, ERR_MISC);

    // build header
    header.len = (uint16_t) (MSG_PROTO_HEADER_SIZE + len);

    // write it
    // NOTE(review): if writing/buffering the payload fails after the header was accepted, the stream is left holding
    // a header without its payload
    if (
            msg_proto_write_header(proto, &header, err)
        ||  msg_buf_write(&proto->out, proto->transport, data, len, err)
    )
        return ERROR_CODE(err);

    // ok
    return SUCCESS;
}

void msg_proto_destroy (struct msg_proto *proto)
{
    // drop buffers
    msg_buf_deinit(&proto->in);
    msg_buf_deinit(&proto->out);

    // kill transport
    transport_destroy(proto->transport);

    // release ourself
    free(proto);
}