qmsk/net/socket/address.pyx
author Tero Marttila <terom@fixme.fi>
Mon, 17 Aug 2009 20:24:12 +0300
changeset 22 f6e8d5e37998
parent 20 0e4933d5862e
child 23 15d8bb96b8d4
permissions -rw-r--r--
some __cmp__ and doctest love for socket.address
cimport qmsk.net.libc as libc
cimport qmsk.net.py as py

from qmsk.net.socket.address cimport *

cimport qmsk.net.socket.platform as platform

cdef class sockaddr :
    """
        Abstract base for the concrete sockaddr types.

        Subclasses own the actual C struct and expose it by implementing _get_sockaddr() / _set_sockaddr();
        the generic addr/port/repr support here is built on top of getnameinfo().
    """

    cdef void _init_family (self, platform.sa_family_t family=platform.AF_UNSPEC) :
        """
            Record the address family of this sockaddr (AF_UNSPEC when not given).
        """

        self.family = family

    # XXX: use size_t
    cdef int _get_sockaddr (self, platform.sockaddr **sa_ptr, platform.socklen_t *sa_len) except -1 :
        """
            Store a pointer to this address's sockaddr struct into *sa_ptr and its length into *sa_len.

            Abstract: the base class has no struct of its own, so subclasses must override this.
        """

        raise NotImplementedError()

    cdef platform.sockaddr* _get_sockaddr_ptr (self) except NULL :
        """
            Return only the sockaddr pointer reported by _get_sockaddr()
        """

        cdef platform.sockaddr *ptr
        cdef platform.socklen_t length

        self._get_sockaddr(&ptr, &length)

        return ptr

    cdef platform.socklen_t _get_sockaddr_len (self) except -1 :
        """
            Return only the sockaddr length reported by _get_sockaddr()
        """

        cdef platform.sockaddr *ptr
        cdef platform.socklen_t length

        self._get_sockaddr(&ptr, &length)

        return length

    cdef int _set_sockaddr (self, platform.sockaddr *sa, size_t sa_len) except -1 :
        """
            Overwrite this address from the given sockaddr struct; sa_len must match the subclass's struct size.

            Abstract: subclasses must override this.
        """

        raise NotImplementedError()

    def getnameinfo (self) :
        """
            Return a (host, serv) tuple for this address, as produced by getnameinfo() with numeric host and
            numeric service flags.
        """

        cdef platform.sockaddr *ptr
        cdef platform.socklen_t length
        cdef int ni_flags

        # fetch the concrete struct from the subclass
        self._get_sockaddr(&ptr, &length)

        # XXX: take the flags as args?
        ni_flags = platform.NI_NUMERICHOST | platform.NI_NUMERICSERV

        return platform.getnameinfo(ptr, length, ni_flags)

    property addr :
        """
            The ASCII literal network address
        """

        def __get__ (self) :
            """
                Default implementation: the host part of getnameinfo()
            """

            host, serv = self.getnameinfo()

            return host

    property port :
        """
            The integer port number
        """

        def __get__ (self) :
            """
                Default implementation: the service part of getnameinfo(), converted with int()
            """

            host, serv = self.getnameinfo()

            return int(serv)

    def __repr__ (self) :
        """
            Debug representation showing the numeric family, the address and the port.
        """

        family, addr, port = self.family, self.addr, self.port

        return "sockaddr(%d, %s, %d)" % (family, addr, port)

cdef class sockaddr_in (sockaddr) :
    """
        AF_INET struct sockaddr_in

        >>> sa = sockaddr_in("127.0.0.1", 80)
        >>> sa.addr
        '127.0.0.1'
        >>> sa.port
        80
        >>> str(sa)
        '127.0.0.1:80'

        >>> sockaddr_in('2001::5')
        Traceback (most recent call last):
          ...
        NameError: Invalid network address for specified address family: '2001::5'

    """

    # the struct sockaddr_in
    cdef platform.sockaddr_in sockaddr

    def __init__ (self, object addr=None, platform.in_port_t port=0) :
        """
            Construct using given literal IPv4 address and TCP/UDP port

                addr        - IPv4 address, defaults to INADDR_ANY (0.0.0.0)
                port        - TCP/UDP port, defaults to 0 (ephemeral)
        """

        # zero
        libc.memset(&self.sockaddr, 0, sizeof(self.sockaddr))

        # store our family
        # XXX: this should be a class attribute...
        self._init_family(platform.AF_INET)

        # constant af
        self.sockaddr.sin_family = self.family
        
        # set the sin_port, stored in network byte order
        self.sockaddr.sin_port = platform.htons(port)
        
        if addr :
            # set the sin_addr
            # this automatically converts the addr from str -> char *
            platform.inet_pton(self.family, addr, &self.sockaddr.sin_addr)

        else :
            # set as INADDR_ANY
            self.sockaddr.sin_addr.s_addr = platform.INADDR_ANY
    
    cdef int _get_sockaddr (self, platform.sockaddr **sa_ptr, platform.socklen_t *sa_len) except -1 :
        """
            Expose our struct sockaddr_in; either output pointer may be NULL to skip it.
        """

        if sa_ptr :
            sa_ptr[0] = <platform.sockaddr *> &self.sockaddr

        if sa_len :
            sa_len[0] = sizeof(self.sockaddr)

        return 0

    cdef int _set_sockaddr (self, platform.sockaddr *sa, size_t sa_len) except -1 :
        """
            Overwrite our struct sockaddr_in from the given sockaddr, which must be exactly
            sizeof(struct sockaddr_in) bytes long.

            Raises ValueError on a length mismatch.
        """

        # checked explicitly rather than via assert: assertions can be compiled out, and the memcpy below would
        # then write past the end of our struct for an oversized sa_len
        if sa_len != sizeof(self.sockaddr) :
            raise ValueError("sockaddr length mismatch for sockaddr_in: %d != %d" % (sa_len, sizeof(self.sockaddr)))

        libc.memcpy(&self.sockaddr, sa, sa_len)

        return 0

    property port :
        """
            The integer port number, in host byte order
        """

        def __get__ (self) :
            return platform.ntohs(self.sockaddr.sin_port)

    def __cmp__ (self, other_obj) :
        """
            A sockaddr_in is equal to any other sockaddr_in which has the same addr and port.

            Ordering compares the port first, then the address, bytewise as stored in the struct.
            
            >>> assert sockaddr_in() == sockaddr_in()
            >>> assert sockaddr_in('127.0.0.1', 80) == sockaddr_in('127.0.0.1', 80)
            >>> assert sockaddr_in('127.0.0.1', 80) != sockaddr_in('127.0.0.1', 81)
            >>> addr = sockaddr_in(); assert addr == addr
        """

        if not isinstance(other_obj, sockaddr_in) :
            return <object> py.Py_NotImplemented

        cdef sockaddr_in other = other_obj
        cdef platform.sockaddr_in *sa1 = &self.sockaddr, *sa2 = &other.sockaddr
        cdef int c

        if other is self :
            return 0

        c = libc.memcmp(<void *> &sa1.sin_port, <void *> &sa2.sin_port, sizeof(sa1.sin_port))

        if c == 0 :
            c = libc.memcmp(<void *> &sa1.sin_addr, <void *> &sa2.sin_addr, sizeof(sa1.sin_addr))

        # memcmp() only guarantees the sign of its result, while a __cmp__/tp_compare result outside -1/0/1
        # triggers a RuntimeWarning, so normalize it
        return <int> (c > 0) - <int> (c < 0)

    def __str__ (self) :
        """
            Return the literal ASCII representation for this sockaddr as an '<addr>:<port>' string
        
            >>> str(sockaddr_in())
            '0.0.0.0:0'
        """
        
        # format
        return "%s:%s" % self.getnameinfo()

cdef class sockaddr_in6 (sockaddr) :
    """
        AF_INET6 struct sockaddr_in6

        >>> sa6 = sockaddr_in6("::1", 80)
        >>> sa6.addr
        '::1'
        >>> sa6.port
        80
        >>> str(sa6)
        '[::1]:80'
        
    """

    cdef platform.sockaddr_in6 sockaddr

    def __init__ (self, object addr=None, platform.in_port_t port=0, unsigned int scope_id = 0) :
        """
            Construct using given literal IPv6 address and TCP/UDP port

                addr        - IPv6 address, defaults to platform.in6addr_any (::)
                port        - TCP/UDP port, defaults to 0 (ephemeral)
                scope_id    - (optional) scope ID representing interface index for link-local addresses
        """

        # zero
        libc.memset(&self.sockaddr, 0, sizeof(self.sockaddr))

        # store our family
        # XXX: this should be a class attribute...
        self._init_family(platform.AF_INET6)

        # constant af
        self.sockaddr.sin6_family = self.family
        
        # set the sin6_port, stored in network byte order
        self.sockaddr.sin6_port = platform.htons(port)
        
        if addr :
            # set the sin6_addr
            # this automatically converts the addr from str -> char *
            platform.inet_pton(self.family, addr, &self.sockaddr.sin6_addr)

        else :
            # set as the IPv6 unspecified address (::)
            self.sockaddr.sin6_addr = platform.in6addr_any

        # scope ID, stored as given (host byte order)
        self.sockaddr.sin6_scope_id = scope_id

    cdef int _get_sockaddr (self, platform.sockaddr **sa_ptr, platform.socklen_t *sa_len) except -1 :
        """
            Expose our struct sockaddr_in6; either output pointer may be NULL to skip it.
        """

        if sa_ptr :
            sa_ptr[0] = <platform.sockaddr *> &self.sockaddr

        if sa_len :
            sa_len[0] = sizeof(self.sockaddr)

        return 0

    cdef int _set_sockaddr (self, platform.sockaddr *sa, size_t sa_len) except -1 :
        """
            Overwrite our struct sockaddr_in6 from the given sockaddr, which must be exactly
            sizeof(struct sockaddr_in6) bytes long.

            Raises ValueError on a length mismatch.
        """

        # checked explicitly rather than via assert: assertions can be compiled out, and the memcpy below would
        # then write past the end of our struct for an oversized sa_len
        if sa_len != sizeof(self.sockaddr) :
            raise ValueError("sockaddr length mismatch for sockaddr_in6: %d != %d" % (sa_len, sizeof(self.sockaddr)))

        libc.memcpy(&self.sockaddr, sa, sa_len)

        return 0

    property port :
        """
            The integer port number.

            This will represent it correctly in host byte order.
        """

        def __get__ (self) :
            return platform.ntohs(self.sockaddr.sin6_port)


    property flowinfo :
        """
            The integer flowinfo, as stored in the struct (never set by __init__, so 0 unless loaded via _set_sockaddr)

            XXX: byteorder?
        """

        def __get__ (self) :
            return self.sockaddr.sin6_flowinfo


    property scope_id :
        """
            The scope ID - corresponds to an interface index for link-scope addresses.

            Stored as given to __init__, i.e. in host byte order.
        """

        def __get__ (self) :
            return self.sockaddr.sin6_scope_id

    def __cmp__ (self, other_obj) :
        """
            A sockaddr_in6 is equal to any other sockaddr_in6 which has the same addr, port and scope ID.

            Ordering compares the port, then the address (both bytewise), then the scope ID numerically.

            XXX: flowinfo?

            XXX: comparing a sockaddr_in6 holding a v4-mapped address (::ffff:a.b.c.d) against the equivalent
            sockaddr_in is not implemented; only other sockaddr_in6 instances are compared here.
            
            >>> assert sockaddr_in6() == sockaddr_in6()
            >>> assert sockaddr_in6('0:0:0::1', 80) == sockaddr_in6('::1', 80)
            >>> assert sockaddr_in6('::1', 80, scope_id=1) != sockaddr_in6('::1', 80)
        """

        if not isinstance(other_obj, sockaddr_in6) :
            return <object> py.Py_NotImplemented

        cdef sockaddr_in6 other = other_obj
        cdef platform.sockaddr_in6 *sa1 = &self.sockaddr, *sa2 = &other.sockaddr
        cdef int c

        if other is self :
            return 0

        c = libc.memcmp(<void *> &sa1.sin6_port, <void *> &sa2.sin6_port, sizeof(sa1.sin6_port))

        if c == 0 :
            c = libc.memcmp(<void *> &sa1.sin6_addr, <void *> &sa2.sin6_addr, sizeof(sa1.sin6_addr))

        if c == 0 :
            # the scope ID is kept in host byte order (see __init__), so a bytewise memcmp would not order it
            # numerically on little-endian hosts; compare the values instead
            c = <int> (sa1.sin6_scope_id > sa2.sin6_scope_id) - <int> (sa1.sin6_scope_id < sa2.sin6_scope_id)

        # memcmp() only guarantees the sign of its result, while a __cmp__/tp_compare result outside -1/0/1
        # triggers a RuntimeWarning, so normalize it
        return <int> (c > 0) - <int> (c < 0)

    def __str__ (self) :
        """
            Return the literal ASCII representation for this sockaddr as a '[<addr>]:<port>' string, with a
            '%<scope>' zone suffix inside the brackets for a nonzero scope ID.

            >>> str(sockaddr_in6())
            '[::]:0'

            >>> str(sockaddr_in6('2001:0::05:1'))
            '[2001::5:1]:0'

            >>> str(sockaddr_in6('fe80::abcd', scope_id=5)) # doctest: +ELLIPSIS
            '[fe80::abcd%...]:0'
        """

        addr, port = self.getnameinfo()
        scope_id = self.scope_id

        # some getnameinfo() implementations already append a '%<zone>' suffix to the host of a scoped address;
        # only add our own numeric one when it is missing, so the zone is not rendered twice
        if scope_id and '%' not in addr :
            zone = "%%%d" % scope_id
        else :
            zone = ""

        return "[%s%s]:%s" % (addr, zone, port)

# Mapping of address family -> sockaddr class, consulted by build_sockaddr().
# This is a plain module-level dict, so users may modify it (e.g. to register further families).
SOCKADDR_BY_FAMILY = {}
SOCKADDR_BY_FAMILY[platform.AF_INET] = sockaddr_in
SOCKADDR_BY_FAMILY[platform.AF_INET6] = sockaddr_in6

cdef sockaddr build_sockaddr (platform.sockaddr *sa, size_t sa_len) :
    """
        Build a new sockaddr object from the given C sockaddr struct.

        The class is chosen from SOCKADDR_BY_FAMILY by sa->sa_family (a KeyError propagates for an unregistered
        family), default-constructed, and then loaded with the struct contents via _set_sockaddr().
    """

    cdef sockaddr addr = SOCKADDR_BY_FAMILY[sa.sa_family]()

    # replace the default-constructed struct with the given one
    addr._set_sockaddr(sa, sa_len)

    return addr

cdef class addrinfo :
    """
        Python-side copy of a single C struct addrinfo result.
    """

    cdef _init_addrinfo (self, platform.addrinfo *ai) :
        """
            Copy the fields of the given C addrinfo into this object.

            ai_flags is currently not copied.
        """

        self.family = ai.ai_family
        self.socktype = ai.ai_socktype
        self.protocol = ai.ai_protocol
        self.addr = build_sockaddr(ai.ai_addr, ai.ai_addrlen)

        # ai_canonname is NULL unless the lookup provided one
        if ai.ai_canonname :
            self.canonname = ai.ai_canonname
        else :
            self.canonname = None

    def __str__ (self) :
        fields = (self.family, self.socktype, self.protocol, self.addr, self.canonname)

        return "family=%d, socktype=%d, protocol=%d, addr=%s, canonname=%s" % fields

cdef addrinfo build_addrinfo (platform.addrinfo *c_ai) :
    """
        Wrap the given C addrinfo struct (only this node, not the ai_next chain) in a new addrinfo object.
    """

    cdef addrinfo result = addrinfo()

    result._init_addrinfo(c_ai)

    return result

cdef class endpoint :
    """
        A hostname/service pair, resolvable into addrinfo objects via getaddrinfo().
    """

    def __init__ (self, hostname=None, service=None) :
        """
            Construct with the given hostname/service, either of which may be None.

            A hostname of None implies all valid local addresses (with AI_PASSIVE), and a service of None implies an
            ephemeral port.

                hostname        - the literal address or DNS hostname or anything else that GAI supports
                service         - the numeric port or service name
        """

        # None must stay None: str(None) would produce the literal name 'None', which getaddrinfo() would then
        # try to resolve instead of receiving NULL
        self.hostname = str(hostname) if hostname is not None else None
        self.service = str(service) if service is not None else None

    cpdef getaddrinfo (self, int family, int socktype, int protocol = 0, int flags = platform.AI_PASSIVE) :
        """
            Look up our hostname/service using the given socket parameters, and return a list of addrinfo objects.

                family          - passed as hints.ai_family
                socktype        - passed as hints.ai_socktype
                protocol        - passed as hints.ai_protocol, defaults to 0
                flags           - passed as hints.ai_flags, defaults to AI_PASSIVE

            A None hostname/service is passed to getaddrinfo() as NULL.

            Raises Exception carrying the gai_strerror() message if the lookup fails.
        """
        
        # XXX: Cython doesn't support proper compound value literals...
        cdef platform.addrinfo hints
        cdef platform.addrinfo *res = NULL
        cdef platform.addrinfo *r
        cdef int err
        cdef object ret = []

        cdef char *hostname = NULL
        cdef char *service = NULL

        libc.memset(&hints, 0, sizeof(hints))
        hints.ai_flags          = flags
        hints.ai_family         = family
        hints.ai_socktype       = socktype
        hints.ai_protocol       = protocol

        # these point into the str objects held by self, which outlive this call
        if self.hostname is not None :
            hostname = self.hostname
        
        if self.service is not None :
            service = self.service

        # operate!
        err = platform.c_getaddrinfo(hostname, service, &hints, &res)

        if err :
            # res is unspecified on failure, so it must not be passed to freeaddrinfo()
            # XXX: raise a GAIError
            raise Exception(platform.gai_strerror(err))

        try :
            # gather results
            r = res

            while r :
                ret.append(build_addrinfo(r))

                r = r.ai_next
            
            # ok
            return ret

        finally :
            # release the whole result list, even if building one of the wrappers failed
            platform.c_freeaddrinfo(res)

    def __str__ (self) :
        return "hostname=%s, service=%s" % (self.hostname, self.service)