move address-family from tcp/socket interface to endpoint interface. The address family of a socket is strictly a property of the address passed to it
authorTero Marttila <terom@fixme.fi>
Sun, 23 Aug 2009 22:59:40 +0300
changeset 37 14db3fe42b6c
parent 36 4d5c02fe9c27
child 38 f0fc793a3754
move address-family from tcp/socket interface to endpoint interface. The address family of a socket is strictly a property of the address passed to it
qmsk/net/transport/endpoint.py
qmsk/net/transport/socket.py
qmsk/net/transport/tcp.py
test/transport_tcp.py
--- a/qmsk/net/transport/endpoint.py	Sun Aug 23 22:43:39 2009 +0300
+++ b/qmsk/net/transport/endpoint.py	Sun Aug 23 22:59:40 2009 +0300
@@ -11,23 +11,40 @@
         Abstract network address interface, 
     """
 
-    def resolve (self, family, socktype, protocol = 0, passive = True) :
+    def resolve (self, socktype, protocol = 0, passive = True) :
         """
-            Translate this Endpoint into a sequence of AddrInfo objects for the given family/socktype.
+            Translate this Endpoint into a sequence of AddrInfo objects for the given socktype.
         """
         
         raise NotImplemented()
 
 
-class InetAddr (Endpoint, address.endpoint) :
+class InetAddr (Endpoint) :
     """
         An internet address, either IPv4 or IPv6.
 
            hostname     - [str] literal address or DNS hostname
            service      - [str] port number or service name
+           family       - AF_* associated with this address
     """
+
+    def __init__ (self, hostname=None, service=None, family=constants.AF_UNSPEC) :
+        """
+            Initialize with given parameters, but doesn't perform any lookups yet.
+        """
+
+        self.endpoint = address.endpoint(hostname, service)
+        self.family = family
     
-    def resolve (self, family, socktype, protocol = 0, passive = True) :
+    @property
+    def hostname (self) :
+        return self.endpoint.hostname
+
+    @property
+    def service (self) :
+        return self.endpoint.service
+
+    def resolve (self, socktype, protocol = 0, passive = True) :
         """
             Resolve using getaddrinfo
         """
@@ -37,7 +54,7 @@
         if passive :
             flags |= constants.AI_PASSIVE
 
-        return self.getaddrinfo(family, socktype, protocol, flags)
+        return self.endpoint.getaddrinfo(self.family, socktype, protocol, flags)
 
 class UnixAddr (Endpoint) :
     """
@@ -48,10 +65,7 @@
         self.path = path
         self.addr = af_unix.sockaddr_un(path)
 
-    def resolve (self, family, socktype, protocol = 0, passive = True) :
-        if family != constants.AF_UNIX :
-            raise ValueError("Address family mismatch: %s" % (family, ))
-
+    def resolve (self, socktype, protocol = 0, passive = True) :
         if not socktype :
             raise ValueError("Unknown socktype: %s" % (socktype, ))
 
@@ -74,18 +88,11 @@
         self.family = family
         self.socktype = socktype
 
-    def resolve (self, family, socktype, protocol = 0, passive = True) :
+    def resolve (self, socktype, protocol = 0, passive = True) :
         """
             Returns a single AddrInfo object representing this address 
         """
     
-        if family and family != self.family :
-            raise ValueError("Address family mismatch: %s should be %s" % (family, self.family))
-        
-        elif not family :
-            family = self.family
-
-
         if socktype and self.socktype and socktype != self.socktype :
             raise ValueError("Socket type mismatch: %s should be %s" % (socktype, self.socktype))
 
@@ -97,5 +104,5 @@
                 raise ValueError("Socket type unknown")
 
 
-        return [AddrInfo(0, family, socktype, protocol, self.addr, None)]
+        return [AddrInfo(0, self.family, socktype, protocol, self.addr, None)]
 
--- a/qmsk/net/transport/socket.py	Sun Aug 23 22:43:39 2009 +0300
+++ b/qmsk/net/transport/socket.py	Sun Aug 23 22:59:40 2009 +0300
@@ -69,12 +69,18 @@
         Common operations for Client/Service
     """
 
+    # default socktype
+    _SOCKTYPE = 0
+
     @classmethod
-    def _socket (self, family, socktype, protocol = 0) :
+    def _socket (cls, family, socktype = None, protocol = 0) :
         """
             Construct and return a new socket object using the given parameters.
         """
 
+        if socktype is None :
+            socktype = cls._SOCKTYPE
+
         return socket.socket(family, socktype, protocol)
 
     @classmethod
@@ -100,13 +106,12 @@
             return sock
 
     @classmethod
-    def _bind_endpoint (cls, endpoint, family, socktype, protocol=0) :
+    def _bind_endpoint (cls, endpoint, socktype = None, protocol=0) :
         """
             This will resolve the given endpoint, and attempt to create and bind a suitable socket and return it.
 
                 endpoint        - local Endpoint to bind() to.
-                family          - socket address family to use.
-                socktype        - socket type to use
+                socktype        - (optiona) socket type to use, defaults to _SOCKTYPE
                 protocol        - (optional) specific protocol
 
             Raises a ServiceBindError if this is unable to create a bound socket.
@@ -116,8 +121,11 @@
         
         errors = []
         
+        if socktype is None :
+            socktype = cls._SOCKTYPE
+        
         # resolve the endpoint and try socket+bind
-        for ai in endpoint.resolve(family, socktype, protocol, AI_PASSIVE) :
+        for ai in endpoint.resolve(socktype, protocol, AI_PASSIVE) :
             try :
                 # try to socket+bind this addrinfo
                 sock = cls._bind_addrinfo(ai)
@@ -137,23 +145,23 @@
             # no suitable address found :(
             raise SocketBindEndpointError(endpoint, errors)
     
-    def _init_endpoint (self, endpoint, family, socktype, protocol = 0) :
+    def _init_endpoint (self, endpoint, socktype = None, protocol = 0, family = None) :
         """
             Initialize this socket by constructing a new socket with the given parameters, bound to the given endpoint,
             if given. If no endpoint is given, this simply creates a socket with the given settings and does not bind
             it anywhere.
         """
-
-        # create local socket
-        if endpoint :
+        
+        if endpoint is not None :
             # create a suitable socket bound to a the given endpoint
-            self.sock = self._bind_endpoint(endpoint, family, socktype, protocol)
-
+            self.sock = self._bind_endpoint(endpoint, socktype, protocol)
+        
         else :
-            # create a suitable socket not bound to anything
+            assert family
+
+            # simply create a socket
             self.sock = self._socket(family, socktype, protocol)
 
-
 class Service (Common) :
     """
         Listener socket
@@ -252,15 +260,18 @@
         return sock
 
     @classmethod
-    def _connect_sock_endpoint (cls, sock, endpoint, family, socktype, protocol = 0) :
+    def _connect_sock_endpoint (cls, sock, endpoint, socktype = None, protocol = 0) :
         """
             Connect this socket to the given remote endpoint, using the given parameters to resolve the endpoint.
         """
         
         errors = []
         
+        if socktype is None :
+            socktype = cls._SOCKTYPE
+
         # resolve the endpoint and try socket+bind
-        for ai in endpoint.resolve(family, socktype, protocol) :
+        for ai in endpoint.resolve(socktype, protocol) :
             try :
                 # try to connect the socket to this addrinfo
                 cls._connect_sock_addrinfo(sock, ai)
@@ -281,7 +292,7 @@
             raise SocketConnectEndpointError(endpoint, errors)
     
     @classmethod
-    def _connect_endpoint (cls, endpoint, family, socktype, protocol = 0) :
+    def _connect_endpoint (cls, endpoint, socktype = None, protocol = 0) :
         """
             Create a new socket and connect it to the given remote endpoint, using the given parameters to resolve the
             endpoint.
@@ -289,8 +300,11 @@
 
         errors = []
         
+        if socktype is None :
+            socktype = cls._SOCKTYPE
+
         # resolve the endpoint and try socket+bind
-        for ai in endpoint.resolve(family, socktype, protocol) :
+        for ai in endpoint.resolve(socktype, protocol) :
             try :
                 # try to socket+connect this addrinfo
                 sock = cls._connect_addrinfo(ai)
@@ -310,7 +324,7 @@
             # no suitable address found :(
             raise SocketConnectEndpointError(endpoint, errors)
 
-    def _init_connect_endpoint (self, endpoint, family, socktype, protocol = 0):
+    def _init_connect_endpoint (self, endpoint, socktype = None, protocol = 0):
         """
             If we already have an existing socket, connect it to the given endpoint, otherwise try and connect to the
             given endpoint with a new socket.
@@ -325,11 +339,11 @@
 
         if self.socket :
             # connect with existing socket
-            self._connect_sock_endpoint(self.socket, endpoint, family, socktype, protocol)
+            self._connect_sock_endpoint(self.socket, endpoint, socktype, protocol)
 
         else :
             # connect with new socket
-            self._connect_endpoint(endpoint, family, socktype, protocol)
+            self._connect_endpoint(endpoint, socktype, protocol)
 
 class Stream (Base) :
     """
--- a/qmsk/net/transport/tcp.py	Sun Aug 23 22:43:39 2009 +0300
+++ b/qmsk/net/transport/tcp.py	Sun Aug 23 22:59:40 2009 +0300
@@ -33,22 +33,27 @@
     """
         An implementation of Service for TCP sockets.
     """
+    
+    _SOCKTYPE = socket.SOCK_STREAM
 
-    def __init__ (self, endpoint, af=socket.AF_UNSPEC, listen_backlog=LISTEN_BACKLOG) :
+    def __init__ (self, endpoint, listen_backlog=LISTEN_BACKLOG, family=None) :
         """
             Construct a service, bound to the given local endpoint and listening for incoming connections using the
             given backlog.
 
                 endpoint        - local Endpoint to bind() to. Usually, it is enough to just specify the port.
-                af              - the socket address family to use (one of qmsk.net.socket.constants.AF_*)
                 listen_backlog  - backlog length argument to use for socket.listen()
+                family          - (optional) address family to use if no endpoint is given
             
+            Note that as a special case, it is possible to construct a service without an Endpoint (i.e. None).
+            In this case, there will be no socket.bind() call, instead, a socket is created with the given address
+            family (which *MUST* be given), and .listen() causes the OS to pick a local address to use.
 
             This will raise an error if the bind() or listen() operations fail.
         """
         
         # construct a suitable socket bound to the given endpoint
-        self._init_endpoint(endpoint, af, socket.SOCK_STREAM)
+        self._init_endpoint(endpoint, family=family)
 
         # make us listen
         self._listen(listen_backlog)
@@ -82,19 +87,19 @@
         An implementation of Client for TCP sockets.
     """
 
-    def __init__ (self, connect_endpoint, af=socket.AF_UNSPEC, bind_endpoint=None) :
+    _SOCKTYPE = socket.SOCK_STREAM
+
+    def __init__ (self, connect_endpoint, bind_endpoint=None) :
         """
             Construct a client, connecting to the given remote endpoint.
 
                 connect_endpoint    - remote Endpoint to connect() to.
-                af                  - socket address family to use (one of qmsk.net.socket.constants.AF_*)
                 bind_endpoint       - (optional) local Endpoint to bind() to before connecting.
 
         """
 
         # store
         self.connect_endpoint = connect_endpoint
-        self.af = af
         self.bind_endpoint = bind_endpoint
 
     def connect (self, cls=Connection) :
@@ -104,14 +109,14 @@
 
         if self.bind_endpoint :
             # construct a suitable local socket, bound to a specific endpoint
-            sock = self._bind_endpoint(self.bind_endpoint, self.af, socket.SOCK_STREAM)
+            sock = self._bind_endpoint(self.bind_endpoint)
 
             # connect it to the remote endpoint
-            self._connect_sock_endpoint(sock, self.connect_endpoint, self.af, socket.SOCK_STREAM)
+            self._connect_sock_endpoint(sock, self.connect_endpoint)
 
         else :
             # let _init_connect_endpoint pick a socket to use
-            sock = self._connect_endpoint(self.connect_endpoint, self.af, socket.SOCK_STREAM)
+            sock = self._connect_endpoint(self.connect_endpoint)
         
         # construct
         return cls(sock)
--- a/test/transport_tcp.py	Sun Aug 23 22:43:39 2009 +0300
+++ b/test/transport_tcp.py	Sun Aug 23 22:59:40 2009 +0300
@@ -6,7 +6,7 @@
 class TestService (unittest.TestCase) :
     def setUp (self) :
         # create service on random port
-        self.ss = tcp.Service(None, af=socket.AF_INET)
+        self.ss = tcp.Service(None, family=socket.AF_INET)
 
         self.addr = self.ss.sock.getsockname()
     
@@ -35,18 +35,12 @@
         self.sockaddr = af_inet.sockaddr_in('127.0.0.1', self.ls.getsockname().port)
         self.addr = endpoint.SockAddr(self.sockaddr)
    
-    def test_connect_unspec (self) :
+    def test_connect (self) :
         cc = tcp.Client(self.addr)
         cs = cc.connect()
 
         self.assertEquals(cs.sock.getpeername(), self.sockaddr)
 
-    def test_connect_inet4 (self) :
-        cc = tcp.Client(self.addr, socket.AF_INET)
-        cs = cc.connect()
-
-        self.assertEquals(cs.sock.getpeername(), self.sockaddr)
-
     def test_connect_bind (self) :
         sockaddr = af_inet.sockaddr_in('127.0.0.1', self.sockaddr.port + 1)
         
@@ -67,7 +61,7 @@
         self.assertRaises(socket.SocketConnectEndpointError, cc.connect)
     
     def test_connect_inet (self) :
-        cc = tcp.Client(endpoint.InetAddr('localhost', self.sockaddr.port), socket.AF_INET)
+        cc = tcp.Client(endpoint.InetAddr('localhost', self.sockaddr.port))
         cs = cc.connect()
 
         self.assertEquals(cs.sock.getpeername(), self.sockaddr)