qmsk/net/socket/socket.pyx
author Tero Marttila <terom@fixme.fi>
Mon, 17 Aug 2009 01:29:31 +0300
changeset 20 0e4933d5862e
parent 19 e6b670dbfe3b
child 24 f18b5787c46c
permissions -rw-r--r--
rename qmsk.net.socket.addr -> qmsk.net.socket.address
from qmsk.net.socket.socket cimport *
from qmsk.net.socket.address cimport sockaddr, build_sockaddr

cimport qmsk.net.socket.platform as platform
cimport qmsk.net.libc as libc, qmsk.net.py as py

from qmsk.net.py cimport raise_errno

cdef parse_sockaddr (platform.sockaddr **sa_ptr, platform.socklen_t *sa_len, sockaddr addr, int optional = 0) :
    """
        Resolve a sockaddr object into a raw (struct sockaddr *, socklen_t) pair for passing to the OS.

            sa_ptr      - out: receives the address pointer (NULL if addr is None and optional)
            sa_len      - out: receives the address length (0 if addr is None and optional)
            addr        - the sockaddr object, or None
            optional    - if true, None is accepted and yields NULL/0; otherwise None raises ValueError
    """

    if addr is None :
        # missing address: only acceptable when the caller allows it
        if not optional :
            raise ValueError(addr)

        sa_ptr[0] = NULL
        sa_len[0] = 0

    else :
        # delegate to the sockaddr object itself
        # NOTE(review): the pointer presumably refers to storage owned by addr, so addr must outlive its use - confirm
        addr._get_sockaddr(sa_ptr, sa_len)

cdef parse_buf (void **buf_ptr, size_t *buf_len, object buf, int optional = 0) :
    """
        Expose the contents of a read-buffer object as a raw (void *, size_t) pair.

            buf_ptr     - out: receives a pointer to the buffer contents (NULL if buf is None and optional)
            buf_len     - out: receives the buffer length in bytes (0 if buf is None and optional)
            buf         - an object supporting the character-buffer interface, or None
            optional    - if true, None is accepted and yields NULL/0; otherwise None raises ValueError

        The pointer is only valid for as long as `buf` itself stays alive.
    """

    cdef libc.ssize_t length

    if buf is None :
        # missing buffer: only acceptable when the caller allows it
        if not optional :
            raise ValueError(buf)

        buf_ptr[0] = NULL
        buf_len[0] = 0
        return

    # NOTE(review): error propagation depends on how py.pxd declares PyObject_AsCharBuffer (except -1?) - confirm
    # NOTE(review): the <char **> cast drops the const of the C-API signature
    py.PyObject_AsCharBuffer(buf, <char **> buf_ptr, &length)

    # NOTE(review): the length is presumably >= 0 on success - confirm before relying on the size_t conversion
    buf_len[0] = length

## helper for socket.recv*
# Wraps a PyString that is used directly as the receive buffer, so received data needs no extra copy.
# Lifecycle: sockbuf_init -> (OS call fills the buffer) -> sockbuf_truncate -> sockbuf_deinit (in a finally).
# XXX: make sure these don't leak the PyString in case of errors...
cdef struct sockbuf :
    # reference to the PyString buffer, created by sockbuf_init and released by sockbuf_deinit; NULL if unset
    py.PyObject *str

cdef char* sockbuf_init (sockbuf *buf, size_t len) except NULL :
    """
        Initialize the sockbuf to contain a PyString that can hold `len` bytes, and return a pointer into its
        contents as a char*.

        Note that this requires use of a try-finally with sockbuf_deinit...

        If the string cannot be allocated (e.g. MemoryError, or a `len` too large for a Py_ssize_t), buf.str is left
        NULL and the Python exception set by the C-API is propagated to the caller. sockbuf_deinit remains safe to
        call in that case.
    """

    buf.str = py.PyString_FromStringAndSize(NULL, len)

    if buf.str == NULL :
        # the C-API has already set an exception; returning NULL under `except NULL` makes Cython propagate it,
        # instead of dereferencing a NULL string below
        return NULL

    return py.PyString_AS_STRING(buf.str)

cdef object sockbuf_truncate (sockbuf *buf, size_t len) :
    """
        Truncate the given sockbuf's PyString to the given length, and return it as a Python object.

        The returned object is a new reference; the sockbuf keeps its own reference, which sockbuf_deinit releases.

        The string is never grown. A `len` larger than the current size leaves the string untouched, so that no
        uninitialized buffer memory is ever exposed. This can happen e.g. with MSG_TRUNC, where the OS reports the
        full datagram length rather than the number of bytes stored.

        Raises MemoryError if the resize fails.
    """

    cdef Py_ssize_t cur_len = py.PyString_GET_SIZE(buf.str)

    # only ever shrink. Skipping the equal-length case also fixes behaviour for zero-length strings, which are
    # interned and can't be resized.
    if len < <size_t> cur_len :
        if py._PyString_Resize(&buf.str, len) < 0 :
            # on failure _PyString_Resize releases the string and sets buf.str to NULL, so there is nothing to return
            raise MemoryError()

    return <object> buf.str

cdef void sockbuf_deinit (sockbuf *buf) :
    """
        Release the sockbuf's reference to its PyString, if any, and clear the pointer.

        This is safe if the sockbuf was initialized to zero. Because the pointer is reset to NULL, calling it again
        on the same sockbuf is also harmless.
    """

    py.Py_XDECREF(buf.str)

    # don't leave a dangling pointer behind, so a repeated deinit can't double-release
    buf.str = NULL

# XXX: do some GIL-releasin'
cdef class socket :
    """
        Thin wrapper around an OS-level socket file descriptor.

        The fd is owned by this object: it is closed by .close(), or when the object is deallocated.
        Errors from the underlying system calls are raised via raise_errno().
    """

    def __cinit__ (self) :
        """
            Initialize the socket to set fd to -1, so that we don't try and close stdin too often :)
        """

        self.fd = -1

    def __init__ (self, int family = platform.AF_INET, int socktype = platform.SOCK_STREAM, int protocol = 0, int fd = -1) :
        """
            Create a new socket endpoint with the given family/domain, socktype and optionally, specific protocol,
            unless the fd argument is given as >= 0. In that case the fd is used directly, and family/socktype/
            protocol are ignored.

                family      - one of AF_*
                socktype    - one of SOCK_*
                protocol    - one of IPPROTO_* or zero to select default
                fd          - (optional) an existing fd to wrap; ownership passes to this object
        """

        if fd >= 0 :
            # given fd
            self.fd = fd

        else :
            # socket()
            self.fd = platform.socket(family, socktype, protocol)

        # trap
        if self.fd < 0 :
            raise_errno('socket')

    def fileno (self) :
        """
            Returns the OS-level file descriptor for this socket as an integer
        """

        return self.fd

    def setblocking (self, bint blocking = True) :
        """
            Control the OS-level nonblocking-IO mode flag for this socket.

                blocking    - True for normal blocking operation, False to use non-blocking operation
        """

        # fcntl magic
        # NOTE(review): the result of fcntl_set_flag is not checked here; presumably it raises on error - confirm
        libc.fcntl_set_flag(self.fd, libc.O_NONBLOCK, not blocking)

    def bind (self, sockaddr addr) :
        """
            Bind this socket to the given local socket address. The given sockaddr should be of the same or a
            compatible address family.

                addr        - the local address to bind to. The port may be zero to let the system choose an unused
                              ephemeral port.
        """

        cdef platform.sockaddr *sa
        cdef platform.socklen_t sa_len

        # XXX: require non-NULL addr? (currently None is passed to the OS as a NULL address)
        parse_sockaddr(&sa, &sa_len, addr, 1)

        # bind()
        if platform.bind(self.fd, sa, sa_len) :
            raise_errno('bind')

    def listen (self, int backlog) :
        """
            Listen for connections, marking this socket as a passive socket, which can accept incoming connection
            requests using sock.accept().

            It is customary to call .bind() before .listen().

                backlog     - maximum number of pending connections (those not yet .accept()'d).
        """

        # listen()
        if platform.listen(self.fd, backlog) :
            raise_errno('listen')

    def connect (self, sockaddr addr) :
        """
            Initiate a connection, connecting this socket to the remote endpoint specified by `addr`. The given sockaddr
            should be of the same or a compatible address family.

            If the socket is in non-blocking mode, this will presumably fail with errno.EINPROGRESS.

            If the socket has not yet been bound (using .bind()), the system will pick an appropriate local address and
            ephemeral port.

                addr        - the remote address to connect to.
        """

        cdef platform.sockaddr *sa
        cdef platform.socklen_t sa_len

        # XXX: require non-NULL addr? (currently None is passed to the OS as a NULL address)
        parse_sockaddr(&sa, &sa_len, addr, 1)

        # connect()
        if platform.connect(self.fd, sa, sa_len) :
            raise_errno('connect')

    def accept (self) :
        """
            Accept a connection, dequeueing the first pending connection and returning a new sock object for it. This
            socket must be a connection-based socket (SOCK_STREAM/SOCK_SEQPACKET) and in the passive listening mode
            (.listen()).

            This returns a (sock, src_addr) tuple:
                sock        - the newly created sock, corresponding to the incoming connection
                src_addr    - the remote address of the incoming connection
        """

        # prep the sockaddr that we will return
        cdef platform.sockaddr_storage ss
        cdef platform.socklen_t ss_len = sizeof(ss)

        # accept()
        cdef socket_t sock_fd = platform.accept(self.fd, <platform.sockaddr *> &ss, &ss_len)

        if sock_fd < 0 :
            raise_errno('accept')

        try :
            # prep the new socket, which takes ownership of sock_fd
            sock_obj = socket(fd=sock_fd)

        except :
            # don't leak the fd if the wrapper could not be built
            # NOTE(review): if the failure came after __init__ stored the fd, __dealloc__ would close it as well -
            # confirm this can't double-close
            platform.close(sock_fd)

            raise

        # prep the new addr
        cdef sockaddr src_addr = build_sockaddr(<platform.sockaddr *> &ss, ss_len)

        return sock_obj, src_addr

    def getsockname (self) :
        """
            Get the local address this socket is currently bound to. This can be set using bind(), or automatically.

            Returns a sockaddr object.
        """

        # prep the sockaddr that we will return
        cdef platform.sockaddr_storage ss
        cdef platform.socklen_t ss_len = sizeof(ss)

        # getsockname()
        if platform.getsockname(self.fd, <platform.sockaddr *> &ss, &ss_len) :
            raise_errno('getsockname')

        # build the new sockaddr
        return build_sockaddr(<platform.sockaddr *> &ss, ss_len)

    def getpeername (self) :
        """
            Get the remote address this socket is currently connected to.

            Returns a sockaddr object.
        """

        # prep the sockaddr that we will return
        cdef platform.sockaddr_storage ss
        cdef platform.socklen_t ss_len = sizeof(ss)

        # getpeername()
        if platform.getpeername(self.fd, <platform.sockaddr *> &ss, &ss_len) :
            raise_errno('getpeername')

        # build the new sockaddr
        return build_sockaddr(<platform.sockaddr *> &ss, ss_len)

    def send (self, object buf, int flags = 0) :
        """
            Transmit a message to the connected remote endpoint.

                buf         - the data to send
                flags       - (optional) MSG_* flags to send with

            Returns the number of bytes sent, which may be less than the length of buf.
        """

        cdef void *buf_ptr
        cdef size_t buf_len
        cdef libc.ssize_t ret

        parse_buf(&buf_ptr, &buf_len, buf, 0)

        # send()
        ret = platform.send(self.fd, buf_ptr, buf_len, flags)

        if ret < 0 :
            raise_errno('send')

        else :
            return ret

    def sendto (self, object buf, int flags = 0, sockaddr addr = None) :
        """
            Transmit a message to the given remote endpoint. If this socket is connected, the addr must not be
            specified, and this acts like send()

                buf         - the data to send
                flags       - (optional) MSG_* flags to send with
                addr        - (optional) target address

            Returns the number of bytes sent, which may be less than the length of buf.
        """

        cdef void *buf_ptr
        cdef size_t buf_len
        cdef libc.ssize_t ret

        cdef platform.sockaddr *sa
        cdef platform.socklen_t sa_len

        parse_sockaddr(&sa, &sa_len, addr, 1)
        parse_buf(&buf_ptr, &buf_len, buf, 0)

        # sendto()
        ret = platform.sendto(self.fd, buf_ptr, buf_len, flags, sa, sa_len)

        if ret < 0 :
            raise_errno('sendto')

        else :
            return ret

    def sendmsg (self, sockaddr addr = None, iov = None, control = None, int flags = 0) :
        """
            Transmit an extended message to the given remote endpoint (or default for connected sockets) with the given
            extra parameters.

                addr        - (optional) destination address (struct msghdr::msg_name)
                iov         - (optional) sequence of read-buffers to transmit; None items are sent as empty buffers
                control     - (optional) control message to transmit
                flags       - (optional) MSG_* flags to send with

            Returns the number of bytes sent, which may be less than the total length of iov.
        """

        cdef libc.ssize_t ret
        cdef libc.iovec *iovec
        cdef platform.msghdr msg

        libc.memset(&msg, 0, sizeof(msg))

        parse_sockaddr(<platform.sockaddr **> &msg.msg_name, &msg.msg_namelen, addr, 1)
        parse_buf(&msg.msg_control, &msg.msg_controllen, control, 1)

        # iov
        if iov :
            # materialize the sequence; the tuple also keeps every buffer object alive while its pointer is in use
            iov = tuple(iov)

            # number of bufs = number of iovecs
            msg.msg_iovlen = len(iov)

            # alloca the required number of iovec's
            msg.msg_iov = <libc.iovec *> libc.alloca(msg.msg_iovlen * sizeof(libc.iovec))

            # fill in the iovecs
            for i, buf in enumerate(iov) :
                iovec = &msg.msg_iov[i]

                parse_buf(&iovec.iov_base, &iovec.iov_len, buf, 1)

        # sendmsg()
        ret = platform.sendmsg(self.fd, &msg, flags)

        if ret < 0 :
            raise_errno('sendmsg')

        else :
            return ret

    def write (self, object buf) :
        """
            Write data to socket, mostly equivalent to send() with flags=0.

                buf         - the data to send

            Returns the number of bytes sent, which may be less than the length of buf.
        """

        cdef void *buf_ptr
        cdef size_t buf_len
        cdef libc.ssize_t ret

        parse_buf(&buf_ptr, &buf_len, buf, 0)

        # write()
        ret = libc.write(self.fd, buf_ptr, buf_len)

        if ret < 0 :
            raise_errno('write')

        else :
            return ret

    def writev (self, iov) :
        """
            Write data to a socket from multiple read-buffers.

                iov         - sequence of read-buffers to transmit; None items are sent as empty buffers

            Returns the number of bytes sent, which may be less than the total length of iov.
        """

        # iov
        cdef libc.iovec *iov_list = NULL
        cdef size_t iov_count = 0
        cdef libc.iovec *iovec
        cdef libc.ssize_t ret

        # materialize the sequence; the tuple also keeps every buffer object alive while its pointer is in use
        iov = tuple(iov)

        # number of bufs = number of iovecs
        iov_count = len(iov)

        # alloca the required number of iovec's
        iov_list = <libc.iovec *> libc.alloca(iov_count * sizeof(libc.iovec))

        # fill in the iovecs
        for i, buf in enumerate(iov) :
            iovec = &iov_list[i]

            parse_buf(&iovec.iov_base, &iovec.iov_len, buf, 1)

        # writev()
        ret = libc.writev(self.fd, iov_list, iov_count)

        if ret < 0 :
            raise_errno('writev')

        else :
            return ret

    def recv (self, size_t len, int flags = 0) :
        """
            Receive a message, reading and returning at most `len` bytes.

                len         - size of buffer to use for recv
                flags       - (optional) MSG_* flags to use for recv()

            Returns the received data as a newly allocated string of the correct length.
        """

        cdef sockbuf sb
        cdef libc.ssize_t ret

        # alloc the recv buffer
        cdef char *buf = sockbuf_init(&sb, len)

        try :
            # recv()
            ret = platform.recv(self.fd, buf, len, flags)

            if ret < 0 :
                raise_errno('recv')

            # truncate to correct length
            return sockbuf_truncate(&sb, ret)

        finally :
            sockbuf_deinit(&sb)

    def recvfrom (self, size_t len, int flags = 0) :
        """
            Receive a message, reading at most `len` bytes, also returning the source address.

                len         - size of buffer to use for recv
                flags       - (optional) MSG_* flags to use for recvfrom()

            Returns the received data and the source address as a (buf, src_addr) tuple :
                buf         - a newly allocated string containing the received data, of the correct length
                src_addr    - the source address the message was received from
        """

        cdef sockbuf sb
        cdef libc.ssize_t ret
        cdef object data
        cdef sockaddr src_addr

        # prep the sockaddr that we will return
        cdef platform.sockaddr_storage ss
        cdef platform.socklen_t ss_len = sizeof(ss)

        # alloc recv buf
        cdef char *buf = sockbuf_init(&sb, len)

        try :
            # recvfrom()
            ret = platform.recvfrom(self.fd, buf, len, flags, <platform.sockaddr *> &ss, &ss_len)

            if ret < 0 :
                # report the call that actually failed
                raise_errno('recvfrom')

            # truncate
            data = sockbuf_truncate(&sb, ret)

            # prep the new addr
            src_addr = build_sockaddr(<platform.sockaddr *> &ss, ss_len)

            return data, src_addr

        finally :
            sockbuf_deinit(&sb)

    def recvmsg (self, bint recv_addr = True, object iov_lens = None, size_t control_len = 0, int flags = 0) :
        """
            Receive a message along with some extra data.

                recv_addr   - ask for and return a sockaddr for the source address of the message?
                iov_lens    - (optional) sequence of buffer sizes to use for the message data iov
                control_len - (optional) amount of auxiliary data to receive
                flags       - (optional) flags to pass to recvmsg()

            Returns a (name, iovs, control, flags) tuple :
                name        - the source address of the message, or None
                iovs        - list of strings containing the received data, one per iov_lens entry that received any
                              data (in order), or None if no data was received or no iov was given
                control     - string containing received control message, if any
                flags       - received flags
        """

        cdef platform.msghdr msg
        cdef sockbuf *sb_list = NULL
        cdef sockbuf *sb
        cdef sockbuf cmsg_sb
        cdef libc.iovec *iovec
        cdef size_t iov_len, i, msg_len
        cdef libc.ssize_t ret
        cdef sockaddr name = None
        cdef object iovs = None
        cdef object control = None

        # prep the sockaddr that we will return
        cdef platform.sockaddr_storage ss

        # zero everything up front, so that the cleanup in the finally-clause is safe wherever we fail
        libc.memset(&msg, 0, sizeof(msg))
        libc.memset(&cmsg_sb, 0, sizeof(cmsg_sb))

        # ask for a name?
        if recv_addr :
            msg.msg_name = <void *> &ss
            msg.msg_namelen = sizeof(ss)

        try :
            # build iov?
            if iov_lens :
                # stabilize
                iov_lens = tuple(iov_lens)

                msg.msg_iovlen = len(iov_lens)

                # alloc each iov plus a sockbuf for storing the PyString
                msg.msg_iov = <libc.iovec *> libc.alloca(msg.msg_iovlen * sizeof(libc.iovec))
                sb_list = <sockbuf *> libc.alloca(msg.msg_iovlen * sizeof(sockbuf))

                # zero out so we can cleanup
                libc.memset(sb_list, 0, msg.msg_iovlen * sizeof(sockbuf))

                # set up each sockbuf and its associated iovec
                for i, iov_len in enumerate(iov_lens) :
                    iovec = &msg.msg_iov[i]

                    iovec.iov_base = sockbuf_init(&sb_list[i], iov_len)
                    iovec.iov_len = iov_len

            # build control buffer?
            if control_len :
                msg.msg_control = sockbuf_init(&cmsg_sb, control_len)
                msg.msg_controllen = control_len

            # recvmsg()
            ret = platform.recvmsg(self.fd, &msg, flags)

            if ret < 0 :
                raise_errno('recvmsg')

            # name?
            if msg.msg_name and msg.msg_namelen :
                # build a sockaddr for the name
                name = build_sockaddr(<platform.sockaddr *> msg.msg_name, msg.msg_namelen)

            # iov?
            if ret > 0 and msg.msg_iovlen > 0 :
                iovs = []
                msg_len = ret
                i = 0

                # consume iov's until we have all the data. The loop is also bounded by the number of iovs, because
                # with MSG_TRUNC the returned length may exceed the total buffer space.
                while msg_len > 0 and i < msg.msg_iovlen :
                    # get the associated iovec/sockbuf
                    iovec = &msg.msg_iov[i]
                    sb = &sb_list[i]

                    # calc the size of this iov: the rest of the data, capped at this buffer's capacity
                    iov_len = msg_len if msg_len < iovec.iov_len else iovec.iov_len

                    # add it as a string
                    iovs.append(sockbuf_truncate(sb, iov_len))

                    # advance
                    msg_len -= iov_len
                    i += 1

            # control?
            if msg.msg_control and msg.msg_controllen :
                # build the PyString for the control message
                control = sockbuf_truncate(&cmsg_sb, msg.msg_controllen)

            return name, iovs, control, msg.msg_flags

        finally :
            # cleanup; the returned strings hold their own references
            sockbuf_deinit(&cmsg_sb)

            if sb_list != NULL :
                for i in range(msg.msg_iovlen) :
                    sockbuf_deinit(&sb_list[i])

    def read (self, size_t len) :
        """
            Read data from a socket, mostly equivalent to a recv() with flags=0.

                len         - size of buffer to use for recv

            Returns the received data as a newly allocated string of the correct length.
        """

        cdef sockbuf sb
        cdef libc.ssize_t ret

        # alloc the recv buffer
        cdef char *buf = sockbuf_init(&sb, len)

        try :
            # read()
            ret = libc.read(self.fd, buf, len)

            if ret < 0 :
                raise_errno('read')

            # truncate to correct length
            return sockbuf_truncate(&sb, ret)

        finally :
            sockbuf_deinit(&sb)

    def readv (self, object iov_lens) :
        """
            Read data from a socket into multiple buffers.

                iov_lens    - sequence of buffer sizes to use as iovs

            Returns the same iovs value as recvmsg(): a list of strings for the buffers that received data, or None.

            XXX: implement using real readv instead of faking it with recvmsg...
        """

        # fake using recvmsg
        _, iovs, _, _ = self.recvmsg(recv_addr=False, iov_lens=iov_lens)

        return iovs

    def shutdown (self, how) :
        """
            Shutdown part of a full-duplex connection.

                how         - one of SHUT_*

            This does not affect this socket's fd.
        """

        # shutdown()
        if platform.shutdown(self.fd, how) :
            raise_errno('shutdown')

    def close (self) :
        """
            Close the socket fd if we have one, invalidating it if successful.

            Note that this will raise an error and keep the fd if the system close() returns an error.

            Calling this again after a successful close() does nothing.

            XXX: SO_LINGER/blocking?

            >>> s = socket()
            >>> s.fd >= 0
            True
            >>> s.close()
            >>> s.fd >= 0
            False
            >>> s.close()
        """

        # ignore if already closed
        if self.fd < 0 :
            return

        # close()
        if libc.close(self.fd) :
            raise_errno('close')

        # invalidate
        self.fd = -1

    def __dealloc__ (self) :
        """
            Close the socket fd if one is set, ignoring any errors from close
        """

        if self.fd >= 0 :
            if libc.close(self.fd) :
                # XXX: at least warn... ?
                pass