pvl.hosts: improve HostExtension support enough to move boot= into pvl.hosts.dhcp
authorTero Marttila <terom@paivola.fi>
Tue, 10 Mar 2015 00:11:43 +0200
changeset 739 5149c39f3dfc
parent 738 3104fdf7ea26
child 740 74352351d6f5
pvl.hosts: improve HostExtension support enough to move boot= into pvl.hosts.dhcp
pvl/hosts/__init__.py
pvl/hosts/dhcp.py
pvl/hosts/host.py
pvl/hosts/interface.py
pvl/hosts/tests.py
--- a/pvl/hosts/__init__.py	Mon Mar 09 23:31:13 2015 +0200
+++ b/pvl/hosts/__init__.py	Tue Mar 10 00:11:43 2015 +0200
@@ -7,5 +7,4 @@
 from pvl.hosts.host import (
         HostError,
         Host,
-        extension,
 )
--- a/pvl/hosts/dhcp.py	Mon Mar 09 23:31:13 2015 +0200
+++ b/pvl/hosts/dhcp.py	Tue Mar 10 00:11:43 2015 +0200
@@ -1,6 +1,93 @@
 import pvl.dhcp.config
 import pvl.hosts.host
 
+def parse_dhcp_boot(boot):
+    """
+        Parse the dhcp boot=... option
+
+        >>> print parse_dhcp_boot(None)
+        {}
+        >>> print parse_dhcp_boot({'filename': '/foo'})
+        {'filename': '/foo'}
+        >>> print parse_dhcp_boot({'filename': '/foo', 'next-server': 'bar'})
+        {'next-server': 'bar', 'filename': '/foo'}
+        >>> print parse_dhcp_boot('/foo')
+        {'filename': '/foo'}
+        >>> print parse_dhcp_boot('bar:/foo')
+        {'next-server': 'bar', 'filename': '/foo'}
+        >>> print parse_dhcp_boot('bar:')
+        {'next-server': 'bar'}
+        >>> print parse_dhcp_boot('foo')
+        Traceback (most recent call last):
+            ...
+        ValueError: invalid boot=foo
+    """
+    
+    # unpack dict, or str
+    if not boot:
+        filename = next_server = None
+        boot_str = None
+
+    elif isinstance(boot, dict):
+        filename = boot.pop('filename', None)
+        next_server = boot.pop('next-server', None)
+        boot_str = boot.pop(None, None)
+
+    else:
+        filename = next_server = None
+        boot_str = boot
+        boot = None
+        
+    if boot:
+        raise ValueError("Invalid boot.*: {instances}".format(instances=' '.join(boot)))
+    
+    # any boot= given overrides boot.* fields
+    if not boot_str:
+        pass
+
+    elif boot_str.startswith('/'):
+        filename = boot_str
+
+    elif boot_str.endswith(':'):
+        next_server = boot_str[:-1]
+
+    elif ':' in boot_str:
+        next_server, filename = boot_str.split(':', 1)
+
+    else :
+        raise ValueError("invalid boot={boot}".format(boot=boot_str))
+    
+    return next_server, filename
+
+@pvl.hosts.host.register_extension
+class HostDHCP(pvl.hosts.host.HostExtension):
+    EXTENSION = 'dhcp'
+    EXTENSION_FIELDS = (
+            'boot',
+    )
+
+    @classmethod
+    def build (cls,
+            boot        = None,
+            subclass    = None,
+    ):
+        next_server, filename = parse_dhcp_boot(boot)
+
+        return cls(
+                filename    = filename,
+                subclass    = subclass,
+                next_server = next_server,
+        )
+
+    def __init__(self,
+            filename    = None,
+            next_server = None,
+            subclass    = None,
+    ):
+        self.filename = filename
+        self.next_server = next_server
+        self.subclass = subclass
+
 def dhcp_host_subclass (host, subclass, ethernet):
     """
         Build a DHCP Item for declaring a subclass for a host.
@@ -16,7 +103,7 @@
 class HostDHCPError(pvl.hosts.host.HostError):
     pass
 
-def dhcp_host_options (host, ethernet, subclass=None):
+def dhcp_host_options (host, ethernet, dhcp=None):
     """
         Yield specific dhcp.conf host { ... } items.
     """
@@ -26,20 +113,17 @@
 
     if host.ip4:
         yield 'fixed-address', pvl.dhcp.config.Field(str(host.ip4))
-      
-    for bootopt in ('next-server', 'filename'):
-        if bootopt in host.boot:
-            yield bootopt, host.boot[bootopt]
+    
+    if dhcp:
+        if dhcp.next_server:
+            yield 'next-server', dhcp.next_server
 
-def dhcp_host (host,
-        subclass    = None,
-):
+        if dhcp.filename:
+            yield 'filename', dhcp.filename
+
+def dhcp_host (host):
     """
-        Yield pvl.dhcp.config.Block's
-
-        Takes dhcp:* extensions as keyword arguments
-
-            subclass: name      - generate a subclass name $ethernet for this host
+        Yield pvl.dhcp.config.Block's for given Host, with possible HostDHCP extensions.
     """
 
     if not host.ethernet:
@@ -50,6 +134,8 @@
         comment = u"Owner: {host.owner}".format(host=host)
     else:
         comment = None
+    
+    dhcp = host.extensions.get('dhcp')
 
     for index, ethernet in host.ethernet.iteritems() :
         if index:
@@ -57,12 +143,12 @@
         else:
             name = '{host.name}'.format(host=host)
 
-        items = list(dhcp_host_options(host, ethernet))
+        items = list(dhcp_host_options(host, ethernet, dhcp=dhcp))
 
         yield pvl.dhcp.config.Block(('host', name), items, comment=comment)
 
-        if subclass:
-            yield dhcp_host_subclass(host, subclass, ethernet)
+        if dhcp and dhcp.subclass:
+            yield dhcp_host_subclass(host, dhcp.subclass, ethernet)
     
 def dhcp_hosts (hosts):
     """
@@ -73,9 +159,7 @@
     blocks = { }
 
     for host in hosts:
-        extensions = host.extensions.get('dhcp', {})
-
-        for block in dhcp_host(host, **extensions):
+        for block in dhcp_host(host):
             if not block.key:
                 # TODO: check for unique Item-Blocks
                 pass
--- a/pvl/hosts/host.py	Mon Mar 09 23:31:13 2015 +0200
+++ b/pvl/hosts/host.py	Tue Mar 10 00:11:43 2015 +0200
@@ -69,59 +69,6 @@
 
     return ':'.join('%02x' % int(x, 16) for x in value.split(':'))
 
-def parse_dhcp_boot(boot):
-    """
-        Parse the dhcp boot=... option
-
-        >>> print parse_dhcp_boot(None)
-        {}
-        >>> print parse_dhcp_boot({'filename': '/foo'})
-        {'filename': '/foo'}
-        >>> print parse_dhcp_boot({'filename': '/foo', 'next-server': 'bar'})
-        {'next-server': 'bar', 'filename': '/foo'}
-        >>> print parse_dhcp_boot('/foo')
-        {'filename': '/foo'}
-        >>> print parse_dhcp_boot('bar:/foo')
-        {'next-server': 'bar', 'filename': '/foo'}
-        >>> print parse_dhcp_boot('bar:')
-        {'next-server': 'bar'}
-        >>> print parse_dhcp_boot('foo')
-        Traceback (most recent call last):
-            ...
-        ValueError: invalid boot=foo
-    """
-    
-    # normalize to dict
-    if not boot:
-        boot = { }
-    elif not isinstance(boot, dict):
-        boot = { None: boot }
-    else:
-        boot = dict(boot)
-    
-    # support either an instanced dict or a plain str or a mixed instanced-with-plain-str
-    boot_str = boot.pop(None, None)
-
-    if not (set(boot) <= set(('filename', 'next-server', None))):
-        raise ValueError("Invalid boot.*: {instances}".format(instances=' '.join(boot)))
-
-    # any boot= given overrides boot.* fields
-    if not boot_str:
-        pass
-    elif boot_str.startswith('/'):
-        boot['filename'] = boot_str
-
-    elif boot_str.endswith(':'):
-        boot['next-server'] = boot_str[:-1]
-
-    elif ':' in boot_str:
-        boot['next-server'], boot['filename'] = boot_str.split(':', 1)
-
-    else :
-        raise ValueError("invalid boot={boot}".format(boot=boot_str))
-    
-    return boot
-
 def parse_str(value):
     """
         Normalize optional string value.
@@ -155,16 +102,32 @@
     """
 
     EXTENSIONS = { }
+    EXTENSION_FIELDS = { }
 
     @classmethod
-    def build_extensions(cls, extensions):
-        for extension, value in extensions.iteritems():
+    def parse_extensions(cls, extensions, extra):
+        """
+            Parse extensions and extension fields to yield
+                (HostExtension, field, value)
+        """
+
+        for extension, values in extensions.iteritems():
             extension_cls = cls.EXTENSIONS.get(extension)
 
-            if extension_cls:
-                yield extension, extension_cls.build(**value)
-            else:
+            if not extension_cls:
                 log.warning("skip unknown extension: %s", extension)
+                continue
+            
+            for field, value in values.iteritems():
+                yield extension_cls, field, value
+
+        for field, value in extra.iteritems():
+            extension_cls = cls.EXTENSION_FIELDS.get(field)
+
+            if not extension_cls:
+                raise ValueError("unknown field: {field}".format(field=field))
+            
+            yield extension_cls, field, value
 
     @classmethod
     def build (cls, name, domain,
@@ -175,8 +138,8 @@
             alias=None, alias4=None, alias6=None,
             forward=None, reverse=None,
             down=None,
-            boot=None,
             extensions={ },
+            **extra
     ) :
         """
             Return a Host initialized from data attributes.
@@ -184,6 +147,13 @@
             This handles all string parsing to our data types.
         """
 
+        extension_classes = { }
+
+        for extension_cls, field, value in cls.parse_extensions(extensions, extra):
+            extension_classes.setdefault(extension_cls, dict())[field] = value
+
+        extensions = {extension_cls.EXTENSION: extension_cls.build(**params) for extension_cls, params in extension_classes.iteritems()}
+
         return cls(name,
                 domain      = domain,
                 ip4         = parse_ip(ip, ipaddr.IPv4Address),
@@ -197,8 +167,7 @@
                 forward     = parse_str(forward),
                 reverse     = parse_str(reverse),
                 down        = parse_bool(down),
-                boot        = parse_dhcp_boot(boot),
-                extensions  = dict(cls.build_extensions(extensions)),
+                extensions  = extensions,
         )
 
     def __init__ (self, name, domain,
@@ -209,7 +178,6 @@
             alias=(), alias4=(), alias6=(),
             forward=None, reverse=None,
             down=None,
-            boot=None,
             extensions={},
     ):
         """
@@ -242,7 +210,6 @@
         self.alias6 = alias6
         self.owner = owner
         self.location = location
-        self.boot = boot
         self.forward = forward
         self.reverse = reverse
         self.down = down
@@ -296,6 +263,9 @@
         Provides default no-op behaviours for extension hooks.
     """
 
+    EXTENSION = None
+    EXTENSION_FIELDS = ()
+
     def addresses (self):
         """
             Yield additional (sublabel, ipaddr) records.
@@ -303,10 +273,12 @@
 
         return ()
 
-def extension (cls):
+def register_extension (cls):
     """
         Register an extension class
     """
 
     Host.EXTENSIONS[cls.EXTENSION] = cls
 
+    for field in cls.EXTENSION_FIELDS:
+        Host.EXTENSION_FIELDS[field] = cls
--- a/pvl/hosts/interface.py	Mon Mar 09 23:31:13 2015 +0200
+++ b/pvl/hosts/interface.py	Tue Mar 10 00:11:43 2015 +0200
@@ -16,8 +16,8 @@
     def __str__(self):
         return self.name
     
-@pvl.hosts.extension
-class HostInterfaces(object):
+@pvl.hosts.host.register_extension
+class HostInterfaces(pvl.hosts.host.HostExtension):
     """
         A host with multiple sub-interfaces.
 
--- a/pvl/hosts/tests.py	Mon Mar 09 23:31:13 2015 +0200
+++ b/pvl/hosts/tests.py	Tue Mar 10 00:11:43 2015 +0200
@@ -21,11 +21,23 @@
                 hosts_include_trace = None,
         )
 
+    def assertHostExtensionEquals(self, host, extension, value):
+        host_extension = host.extensions.get(extension)
+
+        self.assertIsNotNone(host_extension, (host, extension))
+
+        for attr, value in value.iteritems():
+            self.assertEquals(getattr(host_extension, attr), value, (host, extension, value))
+
     def assertHostEqual(self, host, host_str, attrs):
         self.assertEquals(str(host), host_str)
 
         for attr, value in attrs.iteritems():
-            self.assertEquals(getattr(host, attr), value)
+            if attr == 'extensions':
+                for extension, value in value.iteritems():
+                    self.assertHostExtensionEquals(host, extension, value)
+            else:
+                self.assertEquals(getattr(host, attr), value)
 
     def assertHostsEqual(self, hosts, expected):
         hosts = list(hosts)
@@ -75,18 +87,28 @@
  
 
     def testApplyHostConfigExtensions(self):
+        @pvl.hosts.host.register_extension
+        class HostLinkTest(pvl.hosts.host.HostExtension):
+            EXTENSION = 'link'
+
+            @classmethod
+            def build(cls, uplink, downlink):
+                obj = cls()
+                obj.uplink = uplink
+                obj.downlink = downlink
+
+                return obj
+
         host = config.apply_host('foo', 'test', {
-            'link:50':          'foo@test',
+            'link:downlink.50': 'foo@test',
             'link:uplink.49':   'bar@test',
         })
 
         self.assertHostEqual(host, 'foo@test', dict(
-                extensions = {
-                    'link': {
-                        '50': 'foo@test',
+                extensions = dict(link={
+                        'downlink': { '50': 'foo@test' },
                         'uplink': { '49': 'bar@test' },
-                    },
-                },
+                }),
         ))
    
     def testApplyHostFqdn(self):
@@ -155,7 +177,9 @@
                 ('foo@test', dict(
                     ip4         = ipaddr.IPAddress('192.0.2.1'),
                     ethernet    = { 'eth0': '00:11:22:33:44:55' },
-                    boot        = { 'next-server': 'boot.lan', 'filename': '/pxelinux.0' },
+                    extensions  = dict(
+                        dhcp        = { 'next_server': 'boot.lan', 'filename': '/pxelinux.0' }
+                    ),
                 )),
         ])