author Tero Marttila <tero.marttila@aalto.fi>
Thu, 26 Feb 2015 16:36:46 +0200
changeset 501 41b362e6074b
parent 497 0082d2092d1f
child 502 ac4e0f2df80c
permissions -rw-r--r--
pvl.hosts.config: fix and test extensions
import ipaddr
import pvl.args
import unittest

from pvl.hosts import config, dhcp, zone
from pvl.hosts.host import Host
from StringIO import StringIO

class ConfFile(StringIO):
    """In-memory text file that also carries a ``name`` attribute.

    Lets tests hand config text to code that expects a named file object,
    without touching the filesystem.

    :param name: value exposed as ``self.name``
    :param buffer: initial contents of the file
    """

    def __init__(self, name, buffer):
        # record the name, then initialise the StringIO base with the contents
        self.name = name
        StringIO.__init__(self, buffer)

class TestConfig(unittest.TestCase):
    # Tests for pvl.hosts.config: loading host definitions from config files
    # and applying per-host settings (name expansion, dotted keys, extensions).
    #
    # NOTE(review): this copy of the file looks damaged -- many calls below
    # lack their closing brackets, the config text in testApplyHosts is cut
    # short, and IP address literals are empty strings. Restore from version
    # control before relying on these tests.
    def setUp(self):
        # Options object passed as the first argument to every config.* call
        # below (further keyword arguments may be missing from this copy).
        self.options = pvl.args.options(
                hosts_charset   = 'utf-8',
                hosts_domain    = None,
                hosts_include   = None,

    def assertHostEqual(self, host, host_str, attrs):
        # Assert that str(host) equals host_str and that every (attr, value)
        # pair in attrs matches getattr(host, attr).
        # dict.iteritems() is Python 2 only.
        self.assertEquals(str(host), host_str)

        for attr, value in attrs.iteritems():
            self.assertEquals(getattr(host, attr), value)

    def assertHostsEqual(self, hosts, expected):
        # Compare each host in `hosts` with the (host_str, attrs) tuple at the
        # same position in `expected`.
        # NOTE(review): zip() stops at the shorter input, so a missing or extra
        # host is never reported; consider list(hosts) plus a length assert.
        for host, expect in zip(hosts, expected):
            host_str, attrs = expect

            self.assertHostEqual(host, host_str, attrs)
    def testApplyHostsFileError(self):
        # A nonexistent hosts file must raise HostConfigError. list() consumes
        # the result, presumably because apply_hosts is lazy.
        with self.assertRaises(config.HostConfigError):
            list(config.apply_hosts(self.options, ['nonexistant']))

    def testApplyHosts(self):
        # Parse an in-memory config named 'test'. The expected hosts are named
        # '<host>@test', so the file name appears to supply the domain.
        conf_file = ConfFile('test', """
    ip =

    ip =
        expected = [
                ('foo@test', dict(ip=ipaddr.IPAddress(''))),
                ('bar@test', dict(ip=ipaddr.IPAddress(''))),

        self.assertHostsEqual(config.apply_hosts_file(self.options, conf_file), expected)

    def testApply(self):
        # Load the on-disk fixture 'etc/hosts/test'. It is a relative path, so
        # it is presumably resolved against the working directory.
        # An ethernet address without an interface name is keyed by None.
        self.assertHostsEqual(config.apply(self.options, ['etc/hosts/test']), [
                ('foo@test', dict(
                    ip          = ipaddr.IPAddress(''),
                    ethernet    = {None: '00:11:22:33:44:55'},
                ('bar@test', dict(
                    ip          = ipaddr.IPAddress(''),
                    ethernet    = {None: '01:23:45:67:89:ab'},

    def testApplyHostsExpand(self):
        # The host name pattern 'asdf{1-3}' expands to asdf1..asdf3. The '$'
        # in the ip template is presumably replaced by each index (the
        # expected IPs are blank in this copy).
        self.assertHostsEqual(config.apply_host_config(self.options, 'asdf', 'asdf{1-3}', ip='10.100.100.$'), [
                ('asdf1@asdf', dict(ip=ipaddr.IPAddress(''))),
                ('asdf2@asdf', dict(ip=ipaddr.IPAddress(''))),
                ('asdf3@asdf', dict(ip=ipaddr.IPAddress(''))),

    def testApplyHostConfigDict(self):
        # A dotted key 'ethernet.eth0' becomes a nested {'eth0': ...} mapping
        # on host.ethernet.
        host = config.apply_host(self.options, 'foo', 'test', {
            'ethernet.eth0': '00:11:22:33:44:55',

        self.assertHostEqual(host, 'foo@test', dict(
                ethernet    = { 'eth0': '00:11:22:33:44:55' }
    def testApplyHostsConfigError(self):
        # Giving both a plain 'ethernet' value and a dotted 'ethernet.eth0'
        # key for the same host must raise HostConfigError.
        with self.assertRaises(config.HostConfigError):
            config.apply_host(self.options, 'foo', 'test', {
                'ethernet': 'foo',
                'ethernet.eth0': 'bar',

    def testApplyHostConfigExtensions(self):
        # Keys of the form 'ext:key' are collected under
        # host.extensions['ext']; dotted sub-keys ('uplink.49') nest further.
        host = config.apply_host(self.options, 'foo', 'test', {
            'link:50':          'foo@test',
            'link:uplink.49':   'bar@test',

        self.assertHostEqual(host, 'foo@test', dict(
                extensions = {
                    'link': {
                        '50': 'foo@test',
                        'uplink': { '49': 'bar@test' },

class TestZoneMixin(object):
    """Assertion helper for tests that check generated DNS resource records.

    Meant to be mixed into a ``unittest.TestCase`` subclass, which supplies
    the ``assertNotIn`` and ``assertDictEqual`` methods used here.
    """

    def assertZoneEquals(self, rrs, expected):
        """Assert that the records in *rrs* match *expected*.

        Each record is reduced to ``(name.lower(), type.upper()) -> data``.
        A key that occurs more than once fails the test via ``assertNotIn``.
        The collected mapping must then equal *expected* exactly.

        :param rrs: iterable of records exposing ``name``, ``type`` and ``data``
        :param expected: dict mapping ``(name, TYPE)`` tuples to record data
        """
        records_by_key = dict()

        for record in rrs:
            # normalise case: lowercase owner names, uppercase record types
            rr_key = record.name.lower(), record.type.upper()
            self.assertNotIn(rr_key, records_by_key)
            records_by_key[rr_key] = record.data

        self.assertDictEqual(records_by_key, expected)

class TestForwardZone(TestZoneMixin, unittest.TestCase):
    # Tests for zone.host_forward / zone.apply_hosts_forward: the forward
    # records (A / AAAA / CNAME) generated for hosts relative to a zone origin.
    #
    # NOTE(review): in this copy many calls lack their closing brackets and
    # IP address literals are empty strings; restore from version control.
    def testHostOutOfOrigin(self):
        # A host whose domain lies outside the origin yields no records.
        h = Host('host', 'domain', ip=ipaddr.IPAddress(''))

        self.assertZoneEquals(zone.host_forward(h, 'test'), { })

    def testHostIP(self):
        # ip and ip6 produce A and AAAA records named after the host.
        h = Host.build('host', 'domain',
                ip  = '',
                ip6 = '2001:db8::',

        self.assertZoneEquals(zone.host_forward(h, 'domain'), {
            ('host', 'A'): [''],
            ('host', 'AAAA'): ['2001:db8::c000:201'],
    def testHostAlias(self):
        # alias is split on whitespace into a list; each alias becomes a
        # CNAME pointing at the host.
        h = Host.build('host', 'domain',
                ip      = '',
                alias   = 'test *.test',

        self.assertEquals(h.alias, ['test', '*.test'])

        self.assertZoneEquals(zone.host_forward(h, 'domain'), {
            ('host', 'A'): [''],
            ('test', 'CNAME'): ['host'],
            ('*.test', 'CNAME'): ['host'],

    def testHostAlias46(self):
        # alias4 / alias6 produce A / AAAA records carrying the host's v4 / v6
        # address, rather than CNAMEs.
        h = Host.build('host', 'domain',
                ip      = '',
                ip6     = '2001:db8::',
                alias4  = 'test4',
                alias6  = 'test6',

        self.assertZoneEquals(zone.host_forward(h, 'domain'), {
            ('host', 'A'): [''],
            ('host', 'AAAA'): ['2001:db8::c000:201'],
            ('test4', 'A'): [''],
            ('test6', 'AAAA'): ['2001:db8::c000:201'],

    def testHostAlias4Missing(self):
        # alias4 on a host without an ip must raise HostZoneError.
        h = Host.build('host', 'domain',
                ip6     = '2001:db8::',
                alias4  = 'test4',
                alias6  = 'test6',

        with self.assertRaises(zone.HostZoneError):
            self.assertZoneEquals(zone.host_forward(h, 'domain'), { })

    def testHostAlias6Missing(self):
        # alias6 on a host without an ip6 must raise HostZoneError.
        h = Host.build('host', 'domain',
                ip      = '',
                alias4  = 'test4',
                alias6  = 'test6',

        with self.assertRaises(zone.HostZoneError):
            self.assertZoneEquals(zone.host_forward(h, 'domain'), { })

    def testHostFQDN(self):
        # Host given as an FQDN (domain None) against an unrelated
        # 'example.com' origin; the expected records are missing from this copy.
        h = Host.build('host.example.net', None,
                ip          = '',

        self.assertZoneEquals(zone.host_forward(h, 'example.com'), {


    def testHostDelegate(self):
        # forward= emits a CNAME to the given target, written absolute with a
        # trailing dot, in place of address records.
        h = Host.build('host', 'example.com',
                forward = 'host.example.net',

        self.assertZoneEquals(zone.host_forward(h, 'example.com'), {
            ('host', 'CNAME'): ['host.example.net.'],

    def testHostForwardAlias(self):
        # forward= combined with alias: the host CNAMEs to the forward target
        # and the alias CNAMEs to the host.
        h = Host.build('host', 'domain',
                forward = 'host.example.net',
                alias   = 'test',

        self.assertZoneEquals(zone.host_forward(h, 'domain'), {
            ('host', 'CNAME'): ['host.example.net.'],
            ('test', 'CNAME'): ['host'],

    def testHostLocation(self):
        # A location without '@domain' takes the host's own domain and becomes
        # a CNAME to the host.
        h = Host.build('host', 'domain',
                ip          = '',
                location    = 'test',

        self.assertEquals(h.location, ('test', 'domain'))

        self.assertZoneEquals(zone.host_forward(h, 'domain'), {
            ('host', 'A'): [''],
            ('test', 'CNAME'): ['host'],

    def testHostLocationDomain(self):
        # 'test@bar.domain' parses into ('test', 'bar.domain'). In the parent
        # 'domain' zone, both the host record and the location CNAME get
        # origin-relative names.
        h = Host.build('host', 'foo.domain',
                ip          = '',
                location    = 'test@bar.domain',

        self.assertEquals(h.location, ('test', 'bar.domain'))

        self.assertZoneEquals(zone.host_forward(h, 'domain'), {
            ('host.foo', 'A'): [''],
            ('test.bar', 'CNAME'): ['host.foo'],

    def testHostLocationDomainOutOfOrigin(self):
        # A location whose domain lies outside the zone origin must raise
        # HostZoneError.
        h = Host.build('host', 'foo.domain',
                ip          = '',
                location    = 'test@bar.domain',

        self.assertEquals(h.location, ('test', 'bar.domain'))

        with self.assertRaises(zone.HostZoneError):
            self.assertZoneEquals(zone.host_forward(h, 'foo.domain'), {
                ('host', 'A'): [''],
        # TODO: also check the location CNAME when generating the location's
        # own zone (bar.domain).
        #self.assertZoneEquals(zone.host_forward(h, 'bar.domain'), {
        #    ('test', 'CNAME'): ['host.foo'],

    def testHostsForward(self):
        # apply_hosts_forward(..., add_origin=True) returns an iterator whose
        # first item renders as an $ORIGIN directive, followed by the records.
        hosts = [
                Host.build('foo', 'domain',
                    ip      = '',
                    ip6     = '2001:db8::',
                    alias   = 'test',
                Host.build('bar', 'domain',
                    ip      = '',
                Host.build('quux', 'example',
                    ip      = '',
        rrs = zone.apply_hosts_forward(hosts, 'domain', add_origin=True)
        # handle the $ORIGIN directive
        rd = next(rrs)

        self.assertEquals(unicode(rd), '$ORIGIN\tdomain.')

        self.assertZoneEquals(rrs, {
            ('foo', 'A'): [''],
            ('foo', 'AAAA'): ['2001:db8::c000:201'],
            ('test', 'CNAME'): ['foo'],
            ('bar', 'A'): [''],

    def testHostsConflict(self):
        # Two hosts with the same name in the same domain must raise
        # HostZoneError.
        hosts = [
                Host.build('foo', 'domain',
                    ip      = '',
                Host.build('foo', 'domain',
                    ip      = '',
        with self.assertRaises(zone.HostZoneError):
            self.assertZoneEquals(zone.apply_hosts_forward(hosts, 'domain'), { })

    def testHostsAliasConflict(self):
        # An alias equal to another host's name must raise HostZoneError,
        # whichever of the two is processed first.
        hosts = [
                Host.build('foo', 'domain',
                    ip          = '',
                Host.build('bar', 'domain',
                    ip          = '',
                    alias       = 'foo',
        # with A first
        with self.assertRaises(zone.HostZoneError):
            self.assertZoneEquals(zone.apply_hosts_forward(hosts, 'domain'), { })
        # also with CNAME first
        with self.assertRaises(zone.HostZoneError):
            self.assertZoneEquals(zone.apply_hosts_forward(reversed(hosts), 'domain'), { })

    def testHostsAlias4Conflict(self):
        # An alias4 equal to another host's name must raise HostZoneError.
        hosts = [
                Host.build('foo', 'domain',
                    ip          = '',
                Host.build('bar', 'domain',
                    ip          = '',
                    alias4      = 'foo',
        with self.assertRaises(zone.HostZoneError):
            self.assertZoneEquals(zone.apply_hosts_forward(hosts, 'domain'), { })

class TestReverseZone(TestZoneMixin, unittest.TestCase):
    # Tests for zone.host_reverse / zone.apply_hosts_reverse: PTR records, or
    # delegating CNAMEs, generated for addresses within a reverse-zone prefix.
    # host_reverse yields (ip, rr) pairs; the tests keep only the rr part.
    #
    # NOTE(review): in this copy many calls lack their closing brackets, and
    # IP / network literals and some expected names are empty strings;
    # restore from version control.
    def testHostIP(self):
        # A host with both ip and ip6 gets a PTR in each address family's
        # prefix.
        h = Host.build('host', 'domain',
                ip  = '',
                ip6 = '2001:db8::',

        self.assertZoneEquals((rr for ip, rr in zone.host_reverse(h, ipaddr.IPNetwork(''))), {
            ('1', 'PTR'): ['host.domain.'],

        self.assertZoneEquals((rr for ip, rr in zone.host_reverse(h, ipaddr.IPNetwork('2001:db8::/64'))), {
            ('', 'PTR'): ['host.domain.'],

    def testHostIP4(self):
        # The PTR owner name is relative to the prefix, so it gains labels
        # ('1', '1.2', '1.2.0') presumably as the prefix gets shorter. A v6
        # prefix is also checked; its expected records are missing here.
        h = Host.build('host', 'domain',
                ip  = '',

        self.assertZoneEquals((rr for ip, rr in zone.host_reverse(h, ipaddr.IPNetwork(''))), {
            ('1', 'PTR'): ['host.domain.'],
        self.assertZoneEquals((rr for ip, rr in zone.host_reverse(h, ipaddr.IPNetwork(''))), {
            ('1.2', 'PTR'): ['host.domain.'],
        self.assertZoneEquals((rr for ip, rr in zone.host_reverse(h, ipaddr.IPNetwork(''))), {
            ('1.2.0', 'PTR'): ['host.domain.'],

        self.assertZoneEquals((rr for ip, rr in zone.host_reverse(h, ipaddr.IPNetwork('2001:db8::/64'))), {


    def testHostIP6(self):
        # A host with only ip6 gets a PTR in the v6 prefix. The v4 prefix is
        # also checked; its expected records are missing here.
        h = Host.build('host', 'domain',
                ip6 = '2001:db8::',

        self.assertZoneEquals((rr for ip, rr in zone.host_reverse(h, ipaddr.IPNetwork(''))), {

        self.assertZoneEquals((rr for ip, rr in zone.host_reverse(h, ipaddr.IPNetwork('2001:db8::/64'))), {
            ('', 'PTR'): ['host.domain.'],

    def testHostIPOutOfPrefix(self):
        # Addresses outside the given prefixes (e.g. 2001:db8:1::/64 for a
        # 2001:db8:: host); the expected records are missing from this copy.
        h = Host.build('host', 'domain',
                ip  = '',
                ip6 = '2001:db8::',

        self.assertZoneEquals((rr for ip, rr in zone.host_reverse(h, ipaddr.IPNetwork(''))), {


        self.assertZoneEquals((rr for ip, rr in zone.host_reverse(h, ipaddr.IPNetwork('2001:db8:1::/64'))), {


    def testHostFQDN(self):
        # For a host named by FQDN, the PTR target is that absolute FQDN.
        h = Host.build('host.example.net', None,
                ip          = '',
                ip6         = '2001:db8::',

        self.assertZoneEquals((rr for ip, rr in zone.host_reverse(h, ipaddr.IPNetwork(''))), {
            ('3', 'PTR'): ['host.example.net.'],

        self.assertZoneEquals((rr for ip, rr in zone.host_reverse(h, ipaddr.IPNetwork('2001:db8::/64'))), {
            ('', 'PTR'): ['host.example.net.'],

    def testHostDelegate(self):
        # reverse= replaces the PTR with a CNAME to the given target. The
        # forward/reverse values and several expectations appear truncated
        # in this copy.
        h = Host.build('host', 'example.com',
                ip      = '',
                ip6     = '2001:db8::',
                forward = '',
                reverse = '1.0/',

        self.assertZoneEquals(zone.host_forward(h, 'example.com'), {


        self.assertZoneEquals((rr for ip, rr in zone.host_reverse(h, ipaddr.IPNetwork(''))), {
            ('1', 'CNAME'): ['1.0/'],
        self.assertZoneEquals((rr for ip, rr in zone.host_reverse(h, ipaddr.IPNetwork('2001:db8::/64'))), {


    def testHosts(self):
        # apply_hosts_reverse emits one PTR per host within the prefix.
        hosts = [
                Host.build('foo', 'domain',
                    ip      = '',
                Host.build('bar', 'domain',
                    ip      = '',
        self.assertZoneEquals(zone.apply_hosts_reverse(hosts, ipaddr.IPNetwork('')), {
            ('1', 'PTR'): ['foo.domain.'],
            ('2', 'PTR'): ['bar.domain.'],
        # Input order does not matter: reversed hosts give the same records.
        # (assertZoneEquals compares a dict, so output order is not checked.)
        self.assertZoneEquals(zone.apply_hosts_reverse(reversed(hosts), ipaddr.IPNetwork('')), {
            ('1', 'PTR'): ['foo.domain.'],
            ('2', 'PTR'): ['bar.domain.'],

    def testHostsConflict(self):
        # Two hosts presumably sharing one address (the IP literals are blank
        # in this copy) must raise HostZoneError.
        hosts = [
                Host.build('foo', 'domain',
                    ip      = '',
                Host.build('bar', 'domain',
                    ip      = '',
        with self.assertRaises(zone.HostZoneError):
            self.assertZoneEquals(zone.apply_hosts_reverse(hosts, ipaddr.IPNetwork('')), { })

    def testHostsGenerateUnknown(self):
        # With unknown_host / unknown_domain set, each address in the prefix
        # without a host gets a placeholder PTR ('ufc.domain.').
        hosts = [
                Host.build('foo', 'domain',
                    ip      = '',
                Host.build('bar', 'domain',
                    ip      = '',
        self.assertZoneEquals(zone.apply_hosts_reverse(hosts, ipaddr.IPNetwork(''),
                unknown_host = 'ufc',
                unknown_domain = 'domain',
        ), {
            ('1', 'PTR'): ['foo.domain.'],
            ('2', 'PTR'): ['ufc.domain.'],
            ('3', 'PTR'): ['ufc.domain.'],
            ('4', 'PTR'): ['ufc.domain.'],
            ('5', 'PTR'): ['bar.domain.'],
            ('6', 'PTR'): ['ufc.domain.'],

class TestDhcp(unittest.TestCase):
    # Tests for dhcp.dhcp_host / dhcp.dhcp_hosts: the (block, items, opts)
    # triples generated for dhcpd host declarations.
    #
    # NOTE(review): in this copy many calls lack their closing brackets, IP
    # literals are empty strings, and some test bodies are truncated;
    # restore from version control.
    def assertBlocksEqual(self, blockdefs, expected):
        # Compare generated (block, items, opts) triples with the expected
        # ones, position by position.
        # - Items are compared order-insensitively (assertItemsEqual, Python 2).
        # - opts is only checked when the expected opts is not None.
        # - blockdefs must support len(), so callers pass a list.
        for (_block, _items, _opts), (block, items, opts) in zip(blockdefs, expected):
            self.assertEqual(_block, block)
            self.assertItemsEqual(_items, items)

            if opts is not None:
                self.assertEqual(_opts, opts)
        self.assertEqual(len(blockdefs), len(expected))
    def testHost(self):
        # Host with ip, ethernet and owner: host-name, fixed-address and
        # hardware items, plus an "Owner: ..." comment option.
        host = Host.build('foo', 'test',
                ip          = '',
                ethernet    = '00:11:22:33:44:55',
                owner       = 'foo',

        self.assertBlocksEqual(list(dhcp.dhcp_host(host)), [
            (('host', 'foo'), [
                ('option', 'host-name', "foo"),
                ('fixed-address', ''),
                ('hardware', 'ethernet', '00:11:22:33:44:55'),
                ], dict(comment="Owner: foo"))

    def testHostStatic(self):
        # Host with only an ip (no ethernet); the expected blocks are missing
        # from this copy.
        host = Host.build('foo', 'test',
                ip          = '',

        self.assertBlocksEqual(list(dhcp.dhcp_host(host)), [


    def testHostDynamic(self):
        # Ethernet without an ip: no fixed-address item is generated.
        host = Host.build('foo', 'test',
                ethernet    = '00:11:22:33:44:55',

        self.assertBlocksEqual(list(dhcp.dhcp_host(host)), [
            (('host', 'foo'), [
                ('option', 'host-name', "foo"),
                ('hardware', 'ethernet', '00:11:22:33:44:55'),
            ], None)

    def testHostBoot(self):
        # boot='server:path' becomes next-server and filename items.
        # Either part may be omitted, and a leading '/' is dropped from the
        # filename.
        hosts = [
                Host.build('foo1', 'test',
                        ethernet    = '00:11:22:33:44:55',
                        boot        = 'boot.lan:debian/wheezy/pxelinux.0',
                Host.build('foo2', 'test',
                        ethernet    = '00:11:22:33:44:55',
                        boot        = 'boot.lan:',
                Host.build('foo3', 'test',
                        ethernet    = '00:11:22:33:44:55',
                        boot        = '/debian/wheezy/pxelinux.0',

        self.assertBlocksEqual(list(dhcp.dhcp_hosts(hosts)), [
            (('host', 'foo1'), [
                ('option', 'host-name', "foo1"),
                ('hardware', 'ethernet', '00:11:22:33:44:55'),
                ('next-server', 'boot.lan'),
                ('filename', 'debian/wheezy/pxelinux.0'),
            ], None),
            (('host', 'foo2'), [
                ('option', 'host-name', "foo2"),
                ('hardware', 'ethernet', '00:11:22:33:44:55'),
                ('next-server', 'boot.lan'),
            ], None),
            (('host', 'foo3'), [
                ('option', 'host-name', "foo3"),
                ('hardware', 'ethernet', '00:11:22:33:44:55'),
                ('filename', 'debian/wheezy/pxelinux.0'),
            ], None),

    def testHosts(self):
        # dhcp_hosts emits one host block per input host, in input order.
        hosts = [
                Host.build('foo', 'test',
                        ip          = '',
                        ethernet    = '00:11:22:33:44:55',
                Host.build('bar', 'test',
                        ip          = '',
                        ethernet    = '01:23:45:67:89:ab',

        self.assertBlocksEqual(list(dhcp.dhcp_hosts(hosts)), [
            (('host', 'foo'), [
                ('option', 'host-name', "foo"),
                ('fixed-address', ''),
                ('hardware', 'ethernet', '00:11:22:33:44:55'),
            ], None),
            (('host', 'bar'), [
                ('option', 'host-name', "bar"),
                ('fixed-address', ''),
                ('hardware', 'ethernet', '01:23:45:67:89:ab'),
            ], None),

    def testHostConflict(self):
        # The same host name in two domains should raise HostDHCPError. The
        # call under test is missing from this copy.
        hosts = [
                Host.build('foo', 'test1',
                        ethernet    = '00:11:22:33:44:55',
                Host.build('foo', 'test2',
                        ethernet    = '01:23:45:67:89:ab',
        with self.assertRaises(dhcp.HostDHCPError):

    def testHostMultinet(self):
        # The same host name on distinct interfaces gets separate blocks
        # named '<host>-<iface>'.
        hosts = [
                Host.build('foo', 'test1',
                    ip              = '',
                    ethernet        = { 'eth1': '00:11:22:33:44:55' },
                Host.build('foo', 'test2',
                    ip              = '',
                    ethernet        = { 'eth2': '01:23:45:67:89:ab' },
        self.assertBlocksEqual(list(dhcp.dhcp_hosts(hosts)), [
                (('host', 'foo-eth1'), [
                    ('option', 'host-name', "foo"),
                    ('fixed-address', ''),
                    ('hardware', 'ethernet', '00:11:22:33:44:55'),
                ], None),
                (('host', 'foo-eth2'), [
                    ('option', 'host-name', "foo"),
                    ('fixed-address', ''),
                    ('hardware', 'ethernet', '01:23:45:67:89:ab'),
                ], None),

if __name__ == '__main__':