src/sock_tcp.c
changeset 3 cc94ae754e2a
parent 2 a834f0559939
child 4 a3ca0f97a075
--- a/src/sock_tcp.c	Sun Feb 22 05:27:29 2009 +0200
+++ b/src/sock_tcp.c	Sun Feb 22 06:44:16 2009 +0200
@@ -7,66 +7,82 @@
 #include <netdb.h>
 #include <unistd.h>
 #include <string.h>
-#include <err.h>
+#include <assert.h>
 
 /*
- * Our sock_stream_type.methods.read implementation
+ * Our sock_stream_methods.read method
  */
-static int sock_tcp_read (struct sock_stream *base_sock, void *buf, size_t len)
+static err_t sock_tcp_read (struct sock_stream *base_sock, void *buf, size_t len)
 {
     struct sock_tcp *sock = SOCK_FROM_BASE(base_sock, struct sock_tcp);
+    int ret;
+    
+    // map directly to read(2)
+    if ((ret = read(sock->fd, buf, len)) < 0)
+        // errno
+        RETURN_SET_ERROR_ERRNO(SOCK_TCP_ERR(sock), ERR_READ);
 
-    return read(sock->fd, buf, len);
+    else 
+        // bytes read
+        return ret;
 }
 
 /*
- * Our sock_stream_type.methods.write implementation
+ * Our sock_stream_methods.write method
  */
-static int sock_tcp_write (struct sock_stream *base_sock, const void *buf, size_t len)
+static err_t sock_tcp_write (struct sock_stream *base_sock, const void *buf, size_t len)
 {
     struct sock_tcp *sock = SOCK_FROM_BASE(base_sock, struct sock_tcp);
+    int ret;
+    
+    // map directly to write(2)
+    if ((ret = write(sock->fd, buf, len)) < 0)
+        // errno
+        RETURN_SET_ERROR_ERRNO(SOCK_TCP_ERR(sock), ERR_WRITE);
 
-    return write(sock->fd, buf, len);
+    else
+        // bytes read
+        return ret;
 }
 
 /*
  * Our sock_stream_type
- *
- * XXX: move to sock_tcp.h
  */
 struct sock_stream_type sock_tcp_type = {
     .methods.read   = &sock_tcp_read,
     .methods.write  = &sock_tcp_write,
 };
 
-struct sock_tcp* sock_tcp_alloc (void)
+err_t sock_tcp_alloc (struct sock_tcp **sock_ptr)
 {
-    struct sock_tcp *sock;
-
     // alloc
-    if ((sock = calloc(1, sizeof(*sock))) == NULL)
-        errx(1, "calloc");
+    if ((*sock_ptr = calloc(1, sizeof(**sock_ptr))) == NULL)
+        return ERR_CALLOC;
     
-    // initialize base
-    sock->base.type = &sock_tcp_type;
+    // initialize base with sock_tcp_type
+    sock_stream_init(SOCK_TCP_BASE(*sock_ptr), &sock_tcp_type);
 
     // done
-    return sock;
+    return SUCCESS;
 }
 
-int sock_tcp_init_fd (struct sock_tcp *sock, int fd)
+err_t sock_tcp_init_fd (struct sock_tcp *sock, int fd)
 {
+    // valid fd -XXX: err instead?
+    assert(fd >= 0);
+
     // initialize
     sock->fd = fd;
 
     // done
-    return 0;
+    return SUCCESS;
 }
 
-int sock_tcp_init_connect (struct sock_tcp *sock, const char *hostname, const char *service)
+err_t sock_tcp_init_connect (struct sock_tcp *sock, const char *hostname, const char *service)
 {
     struct addrinfo hints, *res, *r;
-    int _err;
+    int err;
+    RESET_ERROR(SOCK_TCP_ERR(sock));
     
     // hints
     memset(&hints, 0, sizeof(hints));
@@ -74,40 +90,85 @@
     hints.ai_socktype = SOCK_STREAM;
 
     // resolve
-    if ((_err = getaddrinfo(hostname, service, &hints, &res)))
-        errx(1, "getaddrinfo: %s", gai_strerror(_err));
+    if ((err = getaddrinfo(hostname, service, &hints, &res)))
+        RETURN_SET_ERROR_EXTRA(SOCK_TCP_ERR(sock), ERR_GETADDRINFO, err);
 
-    // use
+    // try each result in turn
     for (r = res; r; r = r->ai_next) {
-        // XXX: wrong
-        if ((sock->fd = socket(r->ai_family, r->ai_socktype, r->ai_protocol)) < 0)
-            err(1, "socket");
+        // create the socket
+        if ((sock->fd = socket(r->ai_family, r->ai_socktype, r->ai_protocol)) < 0) {
+            // remember error
+            SET_ERROR_ERRNO(SOCK_TCP_ERR(sock), ERR_SOCKET);
 
-        if (connect(sock->fd, r->ai_addr, r->ai_addrlen))
-            err(1, "connect");
+            // skip to next one
+            continue;
+        }
+        
+        // connect to remote address
+        if (connect(sock->fd, r->ai_addr, r->ai_addrlen)) {
+            // remember error
+            SET_ERROR_ERRNO(SOCK_TCP_ERR(sock), ERR_CONNECT);
+            
+            // close/invalidate socket
+            close(sock->fd);
+            sock->fd = -1;
 
+            // skip to next one
+            continue;
+        }
+        
+        // valid socket, use this
         break;
     }
     
-    // ensure we got some valid socket
-    if (sock->fd < 0)
-        errx(1, "no valid socket");
+    // ensure we got some valid socket, else return last error code
+    if (sock->fd < 0) {
+        // did we hit some error?
+        if (IS_ERROR(SOCK_TCP_ERR(sock)))
+            // return last error
+            return ERROR_CODE(SOCK_TCP_ERR(sock));
+        
+        else
+            // no results
+            return SET_ERROR(SOCK_TCP_ERR(sock), ERR_GETADDRINFO_EMPTY);
+    }
     
     // ok, done
     return 0;    
 }
 
-// XXX: error handling
-struct sock_stream* sock_tcp_connect (const char *host, const char *service) 
+err_t sock_tcp_connect (struct sock_stream **sock_ptr, const char *host, const char *service, struct error_info *err_info)
 {
     struct sock_tcp *sock;
+    err_t err;
     
     // allocate
-    sock = sock_tcp_alloc();
+    if ((err = sock_tcp_alloc(&sock)))
+        return err;
 
     // connect
-    sock_tcp_init_connect(sock, host, service);
+    if ((err = sock_tcp_init_connect(sock, host, service))) {
+        // set *err_info
+        *err_info = SOCK_TCP_ERR(sock);
 
-    // done
-    return SOCK_TCP_BASE(sock);
+        // cleanup
+        sock_tcp_release(sock);
+        
+        // return error code
+        return err;
+    }
+
+    // good
+    *sock_ptr = SOCK_TCP_BASE(sock);
+
+    return 0;
 }
+
+void sock_tcp_release (struct sock_tcp *sock)
+{
+    // must not be connected
+    assert(sock->fd < 0);
+
+    // free
+    free(sock);
+}