implement __cinit__, shutdown, close, __dealloc__ for socket, and also add a try-except to not leak client sock from accept()
authorTero Marttila <terom@fixme.fi>
Sun, 16 Aug 2009 21:54:46 +0300
changeset 13 a1091632a8a7
parent 12 314d47bdd4d9
child 14 c44754cc1ffe
implement __cinit__, shutdown, close, __dealloc__ for socket, and also add a try-except to not leak client sock from accept()
qmsk/net/libc.pxd
qmsk/net/socket/socket.pxd
qmsk/net/socket/socket.pyx
--- a/qmsk/net/libc.pxd	Sun Aug 16 21:13:36 2009 +0300
+++ b/qmsk/net/libc.pxd	Sun Aug 16 21:54:46 2009 +0300
@@ -39,6 +39,8 @@
     ssize_t read (int fd, void *buf, size_t count)
     ssize_t write (int fd, void *buf, size_t count)
 
+    int close (int fd)
+
 cdef extern from "alloca.h" :
     void* alloca (size_t size)
 
--- a/qmsk/net/socket/socket.pxd	Sun Aug 16 21:13:36 2009 +0300
+++ b/qmsk/net/socket/socket.pxd	Sun Aug 16 21:54:46 2009 +0300
@@ -9,16 +9,25 @@
         Represents a single OS-level socket
         
         >>> from qmsk.net.socket import addr
-        >>> s = socket()
+        >>> from qmsk.net.socket.constants import *
+
+        >>> s = socket(1337)
+        Traceback (most recent call last):
+          ...
+        OSError: [Errno 97] Address family not supported by protocol
+
+        >>> s = socket(fd=1337)
         >>> s.send('foo')
         Traceback (most recent call last):
           ...
         OSError: [Errno 9] Bad file descriptor
-        >>> s.socket()
+
+        >>> s = socket(AF_INET, SOCK_STREAM)
         >>> s.bind(addr.sockaddr_in('127.0.0.1', 1337))
         >>> s.listen(1)
         >>> s.listen(0)
-        >>> s = socket(); s.socket()
+
+        >>> s = socket()
         >>> s.connect(addr.sockaddr_in('127.0.0.1', 1338))
         Traceback (most recent call last):
           ...
--- a/qmsk/net/socket/socket.pyx	Sun Aug 16 21:13:36 2009 +0300
+++ b/qmsk/net/socket/socket.pyx	Sun Aug 16 21:54:46 2009 +0300
@@ -37,27 +37,30 @@
 # XXX: do some GIL-releasin'
 cdef class socket :
 
-    def __init__ (self, int fd = -1) :
+    def __cinit__ (self) :
         """
-            Construct this socket with the given fd, or -1 to mark it as fd-less
+            Initialize the socket to set fd to -1, so that we dont't try and close stdin too often :)
         """
 
-        self.fd = fd
+        self.fd = -1
 
-    def socket (self, int family = platform.AF_INET, int socktype = platform.SOCK_STREAM, int protocol = 0) :
+    def __init__ (self, int family = platform.AF_INET, int socktype = platform.SOCK_STREAM, int protocol = 0, int fd = -1) :
         """
-            Create a new socket endpoint with the given family/domain, socktype and optionally, specific protocol.
+            Create a new socket endpoint with the given family/domain, socktype and optionally, specific protocol,
+            unless the fd argument is given as >= 0, in which case it used directly.
 
                 family      - one of AF_*
                 socktype    - one of SOCK_*
                 protocol    - one of IPPROTO_* or zero to select default
         """
 
-        if self.fd >= 0 :
-            raise Exception("Socket fd already exists")
-        
-        # socket()
-        self.fd = platform.socket(family, socktype, protocol)
+        if fd >= 0 :
+            # given fd
+            self.fd = fd
+
+        else :
+            # socket()
+            self.fd = platform.socket(family, socktype, protocol)
         
         # trap
         if self.fd < 0 :
@@ -141,9 +144,16 @@
 
         if sock_fd < 0 :
             raise_errno('accept')
+        
+        try :
+            # prep the new socket
+            sock_obj = socket(sock_fd)
 
-        # prep the new socket
-        sock_obj = socket(sock_fd)
+        except :
+            # XXX: don't leak the socket fd? How does socket.__init__ handle this?
+            platform.close(sock_fd)
+
+            raise
 
         # prep the new addr
         sock_addr = build_sockaddr(<platform.sockaddr *> &ss, ss_len)
@@ -303,7 +313,6 @@
         for i, buf in enumerate(iov) :
             iovec = &iov_list[i]
                 
-                
             parse_buf(&iovec.iov_base, &iovec.iov_len, buf, 1)
             
         # sendmsg()
@@ -314,4 +323,58 @@
 
         else :
             return ret
- 
+
+
+    def shutdown (self, how) :
+        """
+            Shutdown part of a full-duplex connection. 
+
+                how         - one of SHUT_*
+
+            This does not affect this socket's fd.
+        """
+        
+        # shutdown()
+        if platform.shutdown(self.fd, how) :
+            raise_errno('shutdown')
+
+    def close (self) :
+        """
+            Close the socket fd if we have one, invalidating it if succesful.
+
+            Note that this will raise an error and keep the fd if the system close() returns an error.
+
+            Calling this again after a succesfull close() does nothing.
+
+            XXX: SO_LINGER/blocking?
+
+            >>> s = socket()
+            >>> s.fd >= 0
+            True
+            >>> s.close()
+            >>> s.fd >= 0
+            False
+            >>> s.close()
+        """
+        
+        # ignore if already closed
+        if self.fd < 0 :
+            return
+        
+        # close()
+        if libc.close(self.fd) :
+            raise_errno('close')
+        
+        # invalidate
+        self.fd = -1
+    
+    def __dealloc__ (self) :
+        """
+            Close the socket fd if one is set, ignoring any errors from close
+        """
+
+        if self.fd >= 0 :
+            if libc.close(self.fd) :
+                # XXX: at least warn... ?
+                pass
+