src/sock_gnutls.c
author Tero Marttila <terom@fixme.fi>
Tue, 28 Apr 2009 23:10:30 +0300
branchnew-transport
changeset 159 d3e253d7281a
parent 156 6534a4ac957b
child 163 27a112d89a73
permissions -rw-r--r--
implement hierarchical type-checking for transport_check

#include "sock_gnutls.h"

// XXX: remove
#include "log.h"

#include <gnutls/x509.h>

#include <stdlib.h>
#include <string.h>
#include <time.h>

#include <assert.h>

/**
 * Register interest in the TCP fd event that GnuTLS is currently blocked on.
 *
 * gnutls_record_get_direction() reports whether the last interrupted GnuTLS call was waiting to read (0) or to
 * write (1); the matching TRANSPORT_READ / TRANSPORT_WRITE event is then enabled on the underlying fd.
 *
 * @param sock the TLS socket
 * @param err error info, set on failure
 * @return SUCCESS, or the error code also stored in err
 */
static err_t sock_gnutls_ev_enable (struct sock_gnutls *sock, error_t *err)
{
    int direction = gnutls_record_get_direction(sock->session);
    short events;

    if (direction == 1) {
        // blocked on a full write buffer
        events = TRANSPORT_WRITE;

    } else if (direction == 0) {
        // blocked waiting for more incoming data
        events = TRANSPORT_READ;

    } else {
        // not one of the documented direction values
        RETURN_SET_ERROR_EXTRA(err, ERR_GNUTLS_RECORD_GET_DIRECTION, direction);
    }

    // enable the chosen event on the underlying fd
    ERROR_CODE(err) = transport_fd_enable(SOCK_GNUTLS_FD(sock), events);

    return ERROR_CODE(err) ? ERROR_CODE(err) : SUCCESS;
}

/**
 * Map a gnutls_certificate_status_t bitmask (as produced by gnutls_certificate_verify_peers2) to a constant,
 * human-readable description of the failure.
 *
 * GNUTLS_CERT_REVOKED takes precedence. Otherwise, if GNUTLS_CERT_INVALID is set, the more specific signer/algorithm
 * flags are consulted in order. The returned string is static and must not be freed.
 */
static const char* sock_gnutls_verify_error (unsigned int status)
{
    if (status & GNUTLS_CERT_REVOKED)
        return "certificate was revoked";

    // anything that is neither revoked nor flagged invalid is not something we can describe
    if (!(status & GNUTLS_CERT_INVALID))
        return "unknown error";

    if (status & GNUTLS_CERT_SIGNER_NOT_FOUND)
        return "certificate signer was not found";

    if (status & GNUTLS_CERT_SIGNER_NOT_CA)
        return "certificate signer is not a Certificate Authority";

    if (status & GNUTLS_CERT_INSECURE_ALGORITHM)
        return "certificate signed using an insecure algorithm";

    return "certificate could not be verified";
}

/**
 * Verify the peer's certificate for this session against the trusted CAs and the expected hostname.
 *
 * Adapted from the GnuTLS examples/ex-rfc2818.c example:
 *  - the peer chain is checked with gnutls_certificate_verify_peers2();
 *  - the first (peer's own) X.509 certificate is imported;
 *  - its activation/expiration times are checked against the current time;
 *  - its hostname is checked against sock->hostname.
 *
 * @param sock the TLS socket; sock->hostname is the name the certificate must match
 * @param err reset on entry, set on the first failing check
 * @return SUCCESS if the certificate is acceptable, otherwise the error code stored in err
 */
static err_t sock_gnutls_verify (struct sock_gnutls *sock, error_t *err)
{
    gnutls_x509_crt_t leaf = NULL;
    const gnutls_datum_t *chain;
    unsigned int chain_len;
    unsigned int verify_status;
    time_t current, not_after, not_before;

    RESET_ERROR(err);
    current = time(NULL);

    // validate the whole peer chain using the trusted CAs configured on the credentials
    if ((ERROR_EXTRA(err) = gnutls_certificate_verify_peers2(sock->session, &verify_status)))
        JUMP_SET_ERROR(err, ERR_GNUTLS_CERT_VERIFY_PEERS2);

    if (verify_status != 0)
        JUMP_SET_ERROR_STR(err, ERR_GNUTLS_CERT_VERIFY, sock_gnutls_verify_error(verify_status));

    // only X.509 certificates are handled here
    assert(gnutls_certificate_type_get(sock->session) == GNUTLS_CRT_X509);

    if ((ERROR_EXTRA(err) = gnutls_x509_crt_init(&leaf)))
        JUMP_SET_ERROR_STR(err, ERR_GNUTLS_CERT_VERIFY, "gnutls_x509_crt_init");

    // the peer's own certificate is the first entry of the chain
    chain = gnutls_certificate_get_peers(sock->session, &chain_len);

    if (chain == NULL)
        JUMP_SET_ERROR_STR(err, ERR_GNUTLS_CERT_VERIFY, "gnutls_certificate_get_peers");

    if (chain_len == 0)
        JUMP_SET_ERROR_STR(err, ERR_GNUTLS_CERT_VERIFY, "cert_list_size");

    if ((ERROR_EXTRA(err) = gnutls_x509_crt_import(leaf, &chain[0], GNUTLS_X509_FMT_DER)))
        JUMP_SET_ERROR_STR(err, ERR_GNUTLS_CERT_VERIFY, "gnutls_x509_crt_import");

    // validity period (possibly redundant with verify_peers2, but cheap)
    not_after = gnutls_x509_crt_get_expiration_time(leaf);

    if (not_after == (time_t) -1 || not_after < current)
        JUMP_SET_ERROR_STR(err, ERR_GNUTLS_CERT_VERIFY, "gnutls_x509_crt_get_expiration_time");

    not_before = gnutls_x509_crt_get_activation_time(leaf);

    if (not_before == (time_t) -1 || not_before > current)
        JUMP_SET_ERROR_STR(err, ERR_GNUTLS_CERT_VERIFY, "gnutls_x509_crt_get_activation_time");

    // the certificate must have been issued for the name we connected to
    if (gnutls_x509_crt_check_hostname(leaf, sock->hostname) == 0)
        JUMP_SET_ERROR_STR(err, ERR_GNUTLS_CERT_VERIFY, "gnutls_x509_crt_check_hostname");

error:
    if (leaf != NULL)
        gnutls_x509_crt_deinit(leaf);

    // SUCCESS unless one of the checks above failed
    return ERROR_CODE(err);
}


/**
 * Our handshake driver. This executes the next gnutls_handshake step, treating GNUTLS_E_AGAIN and
 * GNUTLS_E_INTERRUPTED as "not done yet".
 *
 * This updates the sock_gnutls::handshake state internally, as used by sock_gnutls_on_event.
 *
 * If the sock is marked as verify, the peer certificate is verified once the handshake completes, returning on any
 * errors, and then the verify flag is unset - this ensures that the peer cert is only verified once per connection.
 *
 * @param sock the TLS socket
 * @param err error info; set on failure
 * @return >0 for finished handshake, 0 for handshake-in-progress (fd events re-enabled), -err_t for errors.
 */
static int sock_gnutls_handshake (struct sock_gnutls *sock, error_t *err)
{
    int ret;

    // perform the next handshake step; E_AGAIN / E_INTERRUPTED just mean "call again later"
    ret = gnutls_handshake(sock->session);

    if (ret < 0 && ret != GNUTLS_E_AGAIN && ret != GNUTLS_E_INTERRUPTED)
        JUMP_SET_ERROR_EXTRA(err, ERR_GNUTLS_HANDSHAKE, ret);

    if (ret == 0) {
        // handshake complete
        sock->handshake = false;

        if (sock->verify) {
            // perform the validation
            if (sock_gnutls_verify(sock, err))
                goto error;

            // only verify once per connection
            sock->verify = false;
        }

        return 1;
    }

    // still in progress; setting this every time isn't really needed, but it's simpler this way
    sock->handshake = true;

    // re-enable the event GnuTLS is waiting on for the next iteration; failures must use the -err_t convention
    if (sock_gnutls_ev_enable(sock, err))
        goto error;

    return 0;

error:
    return -ERROR_CODE(err);
}

/**
 * Our transport_fd event handler. Drive the handshake if that's current, otherwise, invoke user callbacks.
 *
 * While sock->handshake is set, each event advances the handshake by one step. Once it finishes (or fails):
 *  - if the transport is not connected yet, this was the initial handshake, and the outcome is reported through the
 *    user's connect callback via transport_connected();
 *  - otherwise this was a re-handshake on an established connection: errors are reported via transport_error(), and
 *    on success the event is passed on to the user callbacks.
 */
static void sock_gnutls_on_event (struct transport_fd *fd, short what, void *arg)
{
    struct sock_gnutls *sock = arg;
    error_t err;

    (void) fd;

    // XXX: timeouts
    (void) what;

    // are we in the handshake cycle?
    if (sock->handshake) {
        RESET_ERROR(&err);

        // perform the next handshake step; zero means it is still in progress
        if (sock_gnutls_handshake(sock, &err) == 0) {
            // handshake continues

        } else if (!SOCK_GNUTLS_TRANSPORT(sock)->connected) {
            // the async connect process has now completed, either successfully or with an error
            // invoke the user connect callback directly with appropriate error
            transport_connected(SOCK_GNUTLS_TRANSPORT(sock), ERROR_CODE(&err) ? &err : NULL, true);

        } else if (ERROR_CODE(&err)) {
            // the re-handshake failed, so this transport is dead
            transport_error(SOCK_GNUTLS_TRANSPORT(sock), &err);

        } else {
            // re-handshake completed, so continue with the transport_callbacks
            transport_fd_invoke(SOCK_GNUTLS_FD(sock), what);
        }

    } else {
        // normal sock_stream operation
        // gnutls might be able to proceed now, so invoke user callbacks
        transport_fd_invoke(SOCK_GNUTLS_FD(sock), what);
    }
}

/**
 * transport read method: receive decrypted application data from the TLS session.
 *
 * The receive is retried on GNUTLS_E_INTERRUPTED. GNUTLS_E_AGAIN is not an error: *len is set to zero.
 *
 * @param transport the sock_gnutls transport
 * @param buf destination buffer
 * @param len in: size of buf, out: number of bytes received
 * @param err ERR_GNUTLS_RECORD_RECV on GnuTLS errors, ERR_READ_EOF when gnutls_record_recv returns zero
 * @return SUCCESS, or the error code stored in err
 */
static err_t sock_gnutls_read (transport_t *transport, void *buf, size_t *len, error_t *err)
{
    struct sock_gnutls *sock = transport_check(transport, &sock_gnutls_type);
    int got;

    // XXX: E_REHANDSHAKE?
    for (;;) {
        got = gnutls_record_recv(sock->session, buf, *len);

        if (got != GNUTLS_E_INTERRUPTED)
            break;
    }

    if (got == 0)
        return SET_ERROR(err, ERR_READ_EOF);

    if (got == GNUTLS_E_AGAIN) {
        // nothing available right now
        *len = 0;

        return SUCCESS;
    }

    if (got < 0)
        RETURN_SET_ERROR_EXTRA(err, ERR_GNUTLS_RECORD_RECV, got);

    // number of bytes actually received
    *len = got;

    return SUCCESS;
}

/**
 * transport write method: send application data over the TLS session.
 *
 * The send is retried on GNUTLS_E_INTERRUPTED. GNUTLS_E_AGAIN is not an error: *len is set to zero.
 * A zero return from gnutls_record_send() is not treated as EOF (that only makes sense for reads).
 *
 * @param transport the sock_gnutls transport
 * @param buf data to send
 * @param len in: number of bytes in buf, out: number of bytes actually sent (may be less than requested)
 * @param err set to ERR_GNUTLS_RECORD_SEND, with the GnuTLS error code as extra, on failure
 * @return SUCCESS, or the error code stored in err
 */
static err_t sock_gnutls_write (transport_t *transport, const void *buf, size_t *len, error_t *err)
{
    struct sock_gnutls *sock = transport_check(transport, &sock_gnutls_type);
    int ret;

    // send gnutls record
    do {
        ret = gnutls_record_send(sock->session, buf, *len);

    } while (ret == GNUTLS_E_INTERRUPTED);

    // errors
    if (ret < 0 && ret != GNUTLS_E_AGAIN)
        RETURN_SET_ERROR_EXTRA(err, ERR_GNUTLS_RECORD_SEND, ret);

    // E_AGAIN means nothing was sent; otherwise ret is the number of bytes sent
    *len = (ret < 0) ? 0 : (size_t) ret;

    return SUCCESS;
}

/**
 * transport destroy method: type-check the transport and release the sock_gnutls behind it.
 */
static void _sock_gnutls_destroy (transport_t *transport)
{
    sock_gnutls_destroy(transport_check(transport, &sock_gnutls_type));
}

/**
 * Our sock_tcp-invoked connect handler.
 *
 * Called when the underlying TCP connect has completed, successfully or not. On success, this:
 *  - attaches the TCP fd to the GnuTLS session;
 *  - installs sock_gnutls_on_event as the fd event handler;
 *  - starts the TLS handshake.
 *
 * If the handshake is still in progress, the user's connect callback is invoked later from sock_gnutls_on_event.
 * If it completes immediately, the callback is invoked here with no error. Any error, including a TCP connect error,
 * is reported to the user directly via transport_connected().
 */
static void sock_gnutls__connected (transport_t *transport, const error_t *tcp_err)
{
    struct sock_gnutls *sock = transport_check(transport, &sock_gnutls_type);
    error_t err;
    int ret;

    // start from a clean state so no uninitialized error fields can reach the user
    RESET_ERROR(&err);

    // trap errors to let the user handle them directly
    if (tcp_err)
        JUMP_SET_ERROR_INFO(&err, tcp_err);

    // bind default transport functions (recv/send) to use the TCP fd
    gnutls_transport_set_ptr(sock->session, (gnutls_transport_ptr_t) (long int) SOCK_GNUTLS_FD(sock)->fd);

    // add ourselves as the event handler
    if ((ERROR_CODE(&err) = transport_fd_setup(SOCK_GNUTLS_FD(sock), sock_gnutls_on_event, sock)))
        goto error;

    // start handshake; failures are detected via the error code, like sock_gnutls_on_event does
    ret = sock_gnutls_handshake(sock, &err);

    if (ERROR_CODE(&err))
        goto error;

    if (ret > 0)
        // handshake completed right away: report a successful connect, just like sock_gnutls_on_event would
        transport_connected(transport, NULL, true);

    // otherwise, sock_gnutls_on_event drives the rest of the handshake
    return;

error:
    // tell the user
    transport_connected(transport, &err, true);
}

/**
 * Transport type for GnuTLS client sockets, with sock_tcp_type as its parent type.
 */
struct transport_type sock_gnutls_type = {
    .methods                = {
        ._connected         = sock_gnutls__connected,
        .destroy            = _sock_gnutls_destroy,
        .write              = sock_gnutls_write,
        .read               = sock_gnutls_read,
    },
    .parent                 = &sock_tcp_type,
};

/*
 * Shared anonymous client credentials, used by sock_ssl_connect() when no credentials are given.
 *
 * The x509 member is allocated by sock_gnutls_global_init() and peer verification is off. sock_ssl_connect() does not
 * take a reference on these, so the refcount stays at zero.
 */
static struct sock_ssl_client_cred sock_gnutls_client_cred_anon = {
    .x509       = NULL,
    .verify     = false,
    .refcount   = 0,
};

// XXX: GnuTLS log func
/**
 * Debug log callback with the signature used for gnutls_global_set_log_function(); its registration in
 * sock_gnutls_global_init() is currently commented out. Writes each message to stdout, prefixed with its level.
 */
void _log (int level, const char *msg)
{
    fprintf(stdout, "gnutls: %d: %s", level, msg);
}

/**
 * Initialize the GnuTLS library and allocate the shared anonymous client credentials that sock_ssl_connect() falls
 * back on when no credentials are given.
 *
 * If the credential allocation fails, gnutls_global_deinit() is called to undo gnutls_global_init().
 *
 * @param err error info, set on failure
 * @return SUCCESS, or the error code stored in err
 */
err_t sock_gnutls_global_init (error_t *err)
{
    // global init
    if ((ERROR_EXTRA(err) = gnutls_global_init()) < 0)
        return SET_ERROR(err, ERR_GNUTLS_GLOBAL_INIT);

    // initialize the anon client credentials
    if ((ERROR_EXTRA(err) = gnutls_certificate_allocate_credentials(&sock_gnutls_client_cred_anon.x509)) < 0) {
        // don't leave the library half-initialized after a failed call
        gnutls_global_deinit();

        return SET_ERROR(err, ERR_GNUTLS_CERT_ALLOC_CRED);
    }

    // XXX: debug
//    gnutls_global_set_log_function(&_log);
//    gnutls_global_set_log_level(11);

    // done
    return SUCCESS;
}

/**
 * Release a set of client credentials: the GnuTLS certificate credentials (if allocated) and the struct itself.
 *
 * This is also used on the error path of sock_ssl_client_cred_create(). There, cred->x509 is still NULL (from calloc)
 * if gnutls_certificate_allocate_credentials() failed, so it is only freed when set.
 */
static void sock_ssl_client_cred_destroy (struct sock_ssl_client_cred *cred)
{
    // gnutls_certificate_free_credentials() is not documented to accept NULL
    if (cred->x509)
        gnutls_certificate_free_credentials(cred->x509);

    free(cred);
}

/**
 * Create a new set of SSL client credentials, returned with a refcount of one.
 *
 * @param ctx_cred receives the new credentials on success
 * @param cafile_path optional PEM file of trusted CA certificates
 * @param verify stored in the credentials; connections made with them verify the server certificate if set
 * @param cert_path optional PEM client certificate; must be given together with pkey_path
 * @param pkey_path optional PEM private key matching cert_path
 * @param err error info, set on failure
 * @return SUCCESS, or the error code stored in err
 */
err_t sock_ssl_client_cred_create (struct sock_ssl_client_cred **ctx_cred,
        const char *cafile_path, bool verify,
        const char *cert_path, const char *pkey_path,
        error_t *err
) {
    struct sock_ssl_client_cred *c = calloc(1, sizeof(*c));

    if (c == NULL)
        return SET_ERROR(err, ERR_CALLOC);

    // GnuTLS certificate credentials
    if ((ERROR_EXTRA(err) = gnutls_certificate_allocate_credentials(&c->x509)) < 0)
        JUMP_SET_ERROR(err, ERR_GNUTLS_CERT_ALLOC_CRED);

    // optional trusted CA list; the setter returns the number of certs processed, or a negative error
    if (cafile_path && (ERROR_EXTRA(err) = gnutls_certificate_set_x509_trust_file(c->x509, cafile_path, GNUTLS_X509_FMT_PEM)) < 0)
        JUMP_SET_ERROR(err, ERR_GNUTLS_CERT_SET_X509_TRUST_FILE);

    // remember whether to verify; default verification flags
    c->verify = verify;
    gnutls_certificate_set_verify_flags(c->x509, 0);

    // optional client certificate + private key, which only make sense as a pair
    if (cert_path || pkey_path) {
        assert(cert_path && pkey_path);

        if ((ERROR_EXTRA(err) = gnutls_certificate_set_x509_key_file(c->x509, cert_path, pkey_path, GNUTLS_X509_FMT_PEM)))
            JUMP_SET_ERROR(err, ERR_GNUTLS_CERT_SET_X509_KEY_FILE);
    }

    // hand out the initial reference
    c->refcount = 1;
    *ctx_cred = c;

    return SUCCESS;

error:
    sock_ssl_client_cred_destroy(c);

    return ERROR_CODE(err);
}

/**
 * Take an additional reference to the given client credentials.
 */
void sock_ssl_client_cred_get (struct sock_ssl_client_cred *cred)
{
    cred->refcount += 1;
}

/**
 * Drop a reference to the given client credentials, destroying them when the last reference goes away.
 */
void sock_ssl_client_cred_put (struct sock_ssl_client_cred *cred)
{
    cred->refcount -= 1;

    if (cred->refcount == 0)
        sock_ssl_client_cred_destroy(cred);
}

/**
 * Start an asynchronous SSL client connection to the given host/service.
 *
 * The TCP connect is started here. Once it completes, sock_gnutls__connected() runs the TLS handshake and the result
 * is reported through the callbacks in info.
 *
 * @param info transport info (user callbacks etc.)
 * @param transport_ptr receives the new transport on success
 * @param hostname host to connect to; also the name the server certificate is checked against when verifying
 * @param service service/port to connect to
 * @param cred client credentials to use (a reference is taken), or NULL for the shared anonymous credentials
 * @param err error info, set on failure
 * @return SUCCESS, or the error code stored in err
 */
err_t sock_ssl_connect (const struct transport_info *info, transport_t **transport_ptr, 
        const char *hostname, const char *service,
        struct sock_ssl_client_cred *cred,
        error_t *err
    )
{
    struct sock_gnutls *sock = NULL;

    // alloc
    if ((sock = calloc(1, sizeof(*sock))) == NULL)
        return SET_ERROR(err, ERR_CALLOC);

    // initialize base
    transport_init(SOCK_GNUTLS_TRANSPORT(sock), &sock_gnutls_type, info);

    if (!cred) {
        // default credentials; shared, so no per-connection reference is taken
        cred = &sock_gnutls_client_cred_anon;

    } else {
        // take a ref, which sock_gnutls_destroy() drops again
        sock_ssl_client_cred_get(cred);
        sock->cred = cred;
    }

    // do verify?
    if (cred->verify)
        sock->verify = true;

    // init
    if ((sock->hostname = strdup(hostname)) == NULL)
        JUMP_SET_ERROR(err, ERR_STRDUP);

    // initialize TCP
    sock_tcp_init(SOCK_GNUTLS_TCP(sock));

    // initialize client session
    if ((ERROR_EXTRA(err) = gnutls_init(&sock->session, GNUTLS_CLIENT)) < 0)
        JUMP_SET_ERROR(err, ERR_GNUTLS_INIT);

    // ...default priority stuff
    if ((ERROR_EXTRA(err) = gnutls_set_default_priority(sock->session)))
        JUMP_SET_ERROR(err, ERR_GNUTLS_SET_DEFAULT_PRIORITY);

    // XXX: silly hack for OpenSSL interop
    gnutls_dh_set_prime_bits(sock->session, 512);

    // bind credentials
    if ((ERROR_EXTRA(err) = gnutls_credentials_set(sock->session, GNUTLS_CRD_CERTIFICATE, cred->x509)))
        JUMP_SET_ERROR(err, ERR_GNUTLS_CRED_SET);

    // TCP connect
    if (sock_tcp_connect_async(SOCK_GNUTLS_TCP(sock), hostname, service, err))
        goto error;

    // done, wait for the connect to complete
    *transport_ptr = SOCK_GNUTLS_TRANSPORT(sock);

    return SUCCESS;

error:
    // cleanup: sock_gnutls_destroy() releases the members, but not the struct allocated above
    sock_gnutls_destroy(sock);
    free(sock);

    return ERROR_CODE(err);
}

/**
 * Release the resources held by a sock_gnutls:
 *  - the GnuTLS session;
 *  - the underlying TCP transport;
 *  - the reference on any non-default credentials;
 *  - the hostname copy.
 *
 * The struct itself is not freed here.
 */
void sock_gnutls_destroy (struct sock_gnutls *sock)
{
    // tear down the session without a TLS shutdown exchange
    gnutls_deinit(sock->session);

    // then the TCP layer underneath
    sock_tcp_destroy(SOCK_GNUTLS_TCP(sock));

    // only non-default credentials carry a reference of ours
    if (sock->cred != NULL)
        sock_ssl_client_cred_put(sock->cred);

    free(sock->hostname);
}