qmsk/net/socket/socket.pyx
changeset 16 24ce1035b338
parent 15 8e7037cc05c7
child 17 5f4077a530b0
--- a/qmsk/net/socket/socket.pyx	Sun Aug 16 23:45:43 2009 +0300
+++ b/qmsk/net/socket/socket.pyx	Mon Aug 17 00:45:58 2009 +0300
@@ -35,6 +35,44 @@
     else :
         raise ValueError(buf)
 
+## helper for socket.recv*
+# XXX: make sure these don't leak the PyString in case of errors...
+cdef struct sockbuf :
+    py.PyObject *str
+
+cdef char* sockbuf_init (sockbuf *buf, size_t len) :
+    """
+        Initialize the sockbuf to contain a PyString that can hold `len` bytes, and return a pointer into its
+        contents as a char*.
+
+        Note that this requires use of a try-finally with sockbuf_deinit...
+    """
+
+    buf.str = py.PyString_FromStringAndSize(NULL, len)
+    
+    return py.PyString_AS_STRING(buf.str)
+
+cdef object sockbuf_truncate (sockbuf *buf, size_t len) :
+    """
+        Truncate the given sockbuf's PyString to the given length, and return the PyObject*
+    """
+    
+    # optimize for the no-need-to-resize case
+    # this also fixes behaviour for zero-length strings (heh), since they're interned and can't be resized
+    if len != py.PyString_GET_SIZE(buf.str) :
+        py._PyString_Resize(&buf.str, len)
+    
+    return <object> buf.str
+
+cdef void sockbuf_deinit (sockbuf *buf) :
+    """
+        Release the PyObject.
+
+        This is safe if the sockbuf was initialized to zero.
+    """
+
+    py.Py_XDECREF(buf.str)
+
 # XXX: do some GIL-releasin'
 cdef class socket :
 
@@ -339,25 +377,25 @@
 
             Returns the recieved data as a newly allocated string of the correct length.
         """
-        
-        # alloc a new return str
-        # XXX: len overflow...
-        cdef object str = py.PyString_FromStringAndSize(NULL, len)
-        cdef char *buf = py.PyString_AS_STRING(str)
-        
-        # recv()
-        cdef libc.ssize_t ret = platform.recv(self.fd, buf, len, flags)
 
-        if ret < 0 :
-            raise_errno('recv')
+        cdef sockbuf sb
+        cdef libc.ssize_t ret
 
-        # truncate to correct length
-        # XXX: refcounts?
-        cdef py.PyObject *str_obj = <py.PyObject *> str
+        # alloc the recv buffer
+        cdef char *buf = sockbuf_init(&sb, len)
+        
+        try :
+            # recv()
+            ret = platform.recv(self.fd, buf, len, flags)
 
-        py._PyString_Resize(&str_obj, ret)
-        
-        return <object> str_obj
+            if ret < 0 :
+                raise_errno('recv')
+
+            # truncate to correct length
+            return sockbuf_truncate(&sb, ret)
+
+        finally :
+            sockbuf_deinit(&sb)
     
     def recvfrom (self, size_t len, int flags = 0) :
         """
@@ -371,26 +409,35 @@
                 src_addr    - the source address the message was recieved from
         """
 
-        # alloc a new return str
-        # XXX: len overflow...
-        cdef object str = py.PyString_FromStringAndSize(NULL, len)
-        cdef char *buf = py.PyString_AS_STRING(str)
-        
+        cdef sockbuf sb
+        cdef libc.ssize_t ret
+        cdef object str
+        cdef sockaddr src_addr
+            
         # prep the sockaddr that we will return
         cdef platform.sockaddr_storage ss
         cdef platform.socklen_t ss_len = sizeof(ss)
-        
-        # recvfrom()
-        cdef libc.ssize_t ret = platform.recvfrom(self.fd, buf, len, flags, <platform.sockaddr *> &ss, &ss_len)
 
-        if ret < 0 :
-            raise_errno('recv')
-        
-        # prep the new addr
-        cdef sock_addr = build_sockaddr(<platform.sockaddr *> &ss, ss_len)
-        
-        # XXX: figure out how to call _PyString_Resize
-        return str[:ret], sock_addr
+        # alloc recv buf
+        cdef char *buf = sockbuf_init(&sb, len)
+
+        try :
+            # recvfrom()
+            ret = platform.recvfrom(self.fd, buf, len, flags, <platform.sockaddr *> &ss, &ss_len)
+
+            if ret < 0 :
+                raise_errno('recv')
+
+            # truncate
+            str = sockbuf_truncate(&sb, ret)
+            
+            # prep the new addr
+            src_addr = build_sockaddr(<platform.sockaddr *> &ss, ss_len)
+            
+            return str, src_addr
+
+        finally :
+            sockbuf_deinit(&sb)
     
     def recvmsg (self, bint recv_addr = True, object iov_lens = None, size_t control_len = 0, int flags = 0) :
         """
@@ -401,16 +448,24 @@
                 control_len - (optional) amount of auxiliary data to recieve
                 flags       - (optional) flags to pass to recvmsg()
 
-            Returns a (name, iov, control, flags) tuple :
+            Returns a (name, iovs, control, flags) tuple :
                 name        - the source address of the message, or None
-                iov         - sequence of strings containing the recieved data, each at most lenv[x] bytes long
+                iovs        - sequence of strings containing the recieved data, each at most lenv[x] bytes long
                 control     - string containing recieved control message, if any
                 flags       - recieved flags
         """
 
         cdef platform.msghdr msg
+        cdef sockbuf *sb_list, *sb, cmsg_sb
+        cdef libc.iovec *iovec
+        cdef size_t iov_len, i, msg_len
+        cdef libc.ssize_t ret
+        cdef sockaddr name = None 
+        cdef object iovs = None
+        cdef object control = None
 
         libc.memset(&msg, 0, sizeof(msg))
+        libc.memset(&cmsg_sb, 0, sizeof(cmsg_sb))
 
         # prep the sockaddr that we will return
         cdef platform.sockaddr_storage ss
@@ -419,46 +474,89 @@
         if recv_addr :
             msg.msg_name = <void *> &ss
             msg.msg_namelen = sizeof(ss)
-
-        # build iov?
-        if iov_lens :
-            # XXX: implement
-            pass
-
-        # build control buffer?
-        if control_len :
-            # XXX: implement
-            pass
-
-        # recvmsg()
-        cdef libc.ssize_t ret = platform.recvmsg(self.fd, &msg, flags)
-
-        if ret < 0 :
-            raise_errno('recvmsg')
         
-        # name?
-        cdef sockaddr name = None
+        try :
+            # build iov?
+            if iov_lens :
+                # stabilize
+                iov_lens = tuple(iov_lens)
+                
+                msg.msg_iovlen = len(iov_lens)
+                
+                # alloc each iov plus a sockbuf for storing the PyString
+                msg.msg_iov = <libc.iovec *> libc.alloca(msg.msg_iovlen * sizeof(libc.iovec))
+                sb_list = <sockbuf *> libc.alloca(msg.msg_iovlen * sizeof(sockbuf))
+                
+                # zero out so we can cleanup
+                libc.memset(sb_list, 0, msg.msg_iovlen * sizeof(sockbuf))
+                
+                # build each
+                for i, iov_len in enumerate(iov_lens) :
+                    # the associated iovec/sockbuf
+                    iovec = &msg.msg_iov[i]
+                    sb = &sb_list[i]
+                    
+                    # set up the sockbuf and iovec
+                    iovec.iov_base = sockbuf_init(&sb_list[i], iov_len)
+                    iovec.iov_len = iov_len
 
-        if msg.msg_name and msg.msg_namelen :
-            name = build_sockaddr(<platform.sockaddr *> msg.msg_name, msg.msg_namelen)
+            # build control buffer?
+            if control_len :
+                msg.msg_control = sockbuf_init(&cmsg_sb, control_len)
+                msg.msg_controllen = control_len
 
-        # iov?
-        cdef object iov = None
+            # recvmsg()
+            ret = platform.recvmsg(self.fd, &msg, flags)
 
-        if ret :
-            assert msg.msg_iov and msg.msg_iovlen
+            if ret < 0 :
+                raise_errno('recvmsg')
             
-            # XXX: implement
-            pass
+            # name?
+            if msg.msg_name and msg.msg_namelen :
+                # build a sockaddr for the name
+                name = build_sockaddr(<platform.sockaddr *> msg.msg_name, msg.msg_namelen)
 
-        # control?
-        cdef object control = None
+            # iov?
+            if ret :
+                assert msg.msg_iov and msg.msg_iovlen
+                
+                iovs = []
+                msg_len = ret
+                i = 0
+                
+                # consume iov's until we have all the data we need
+                while msg_len :
+                    # sanity-check
+                    assert i < msg.msg_iovlen
 
-        if msg.msg_control and msg.msg_controllen :
-            # XXX: implement
-            pass
+                    # get the associated iovec/sockbuf
+                    iovec = &msg.msg_iov[i]
+                    sb = &sb_list[i]
+                    
+                    # calc the size of this iov
+                    # XXX: cdef
+                    iov_len = min(msg_len, iovec.iov_len)
 
-        return name, iov, control, msg.msg_flags
+                    # add it as a string
+                    iovs.append(sockbuf_truncate(sb, iov_len))
+
+                    # advance
+                    msg_len -= iov_len
+                    i += 1
+
+            # control?
+            if msg.msg_control and msg.msg_controllen :
+                # build the PyString for the control message
+                control = sockbuf_truncate(&cmsg_sb, msg.msg_controllen)
+
+            return name, iovs, control, msg.msg_flags
+
+        finally :
+            # cleanup
+            sockbuf_deinit(&cmsg_sb)
+
+            for i in range(msg.msg_iovlen) :
+                sockbuf_deinit(&sb_list[i])
 
     def shutdown (self, how) :
         """