pvl/dns/tests.py
author Tero Marttila <tero.marttila@aalto.fi>
Tue, 03 Mar 2015 12:14:22 +0200
changeset 716 4fecd0d1cf23
parent 648 8e3e6be9ac70
permissions -rw-r--r--
pvl.dns.process: merge --include-trace into pvl.dns-process, replacing pvl.dns-includes
import itertools
import re
import unittest

from pvl.dns import zone, process
from StringIO import StringIO

class File(StringIO):
    """
        In-memory text file for feeding the zone parsers in tests.

        Carries a .name attribute (default 'test.file'), like a real file object would.
    """

    @classmethod
    def lines(cls, *lines):
        """
            Build a File whose content is the given lines joined by newlines, plus a final newline.
        """

        text = '%s\n' % '\n'.join(lines)

        return cls(text)

    def __init__(self, buffer, name='test.file'):
        # explicit base call rather than super(); presumably because Python 2's
        # StringIO.StringIO is an old-style class -- confirm before changing
        StringIO.__init__(self, buffer)

        self.name = name

class TestMixin(object):
    """
        Zone-specific assertion helpers, mixed into unittest.TestCase subclasses.

        Relies on the host class providing assertEqual().
    """

    @staticmethod
    def _normalizeWhitespace(value):
        """
            Collapse every run of whitespace in value to a single space; None is passed through.
        """

        if value is None:
            return None

        return re.sub(r'\s+', ' ', value)

    def assertEqualWhitespace(self, value, expected):
        """
            Assert that value equals expected, treating any run of whitespace as a single space.

            Both sides are normalized, so expected may itself use tabs or column alignment.
        """

        self.assertEqual(self._normalizeWhitespace(value), self._normalizeWhitespace(expected))

    def assertZoneLineEqual(self, zl, expected):
        """
            Assert that the exact unicode rendering of zl equals expected (no whitespace normalization).
        """

        self.assertEqual(unicode(zl), expected)

    def assertZoneLinesEqual(self, zls, expected):
        """
            Assert that the exact unicode renderings of the zls sequence equal the expected strings.

            Whole lists are compared, so a missing or extra line is reported by unittest as a
            length mismatch with a diff, rather than as a comparison against None.
        """

        self.assertEqual([unicode(zl) for zl in zls], list(expected))

    def assertZoneItemEqual(self, zr, expect):
        """
            Assert that the unicode rendering of zr equals expect, ignoring whitespace differences.
        """

        self.assertEqualWhitespace(unicode(zr), expect)

    def assertZoneEqual(self, zrs, expected):
        """
            Assert that the renderings of the zrs sequence equal the expected strings, ignoring
            whitespace differences.

            Whole (normalized) lists are compared, so a missing or extra item is reported as a
            length mismatch with a diff.
        """

        self.assertEqual(
            [self._normalizeWhitespace(unicode(zr)) for zr in zrs],
            [self._normalizeWhitespace(expect) for expect in expected],
        )

    # record-oriented aliases, used by the ZoneRecord tests
    assertZoneRecordEqual = assertZoneItemEqual
    assertZoneRecordsEqual = assertZoneEqual

class ZoneLineTest(TestMixin, unittest.TestCase):
    """
        Tokenization of raw zone file text by zone.ZoneLine, compared on the exact
        tab-separated rendering of each line.
    """

    def testZoneLine(self):
        self.assertZoneLineEqual(zone.ZoneLine(['foo', 'A', '192.0.2.1']), "foo\tA\t192.0.2.1")
    
    def testZoneLineParse(self):
        self.assertZoneLinesEqual(zone.ZoneLine.parse(File("foo A 192.0.2.1")), ["foo\tA\t192.0.2.1"])
    
    def testZoneLineParseWhitespace(self):
        # mixed runs of spaces and tabs collapse into single tab separators
        self.assertZoneLinesEqual(zone.ZoneLine.parse(File("foo \t A\t192.0.2.1")), ["foo\tA\t192.0.2.1"])
    
    def testZoneLineParseComment(self):
        self.assertZoneLinesEqual(zone.ZoneLine.parse(File("foo A 192.0.2.1   ;   bar")), ["foo\tA\t192.0.2.1\t; bar"])

    def testZoneLineParseQuoteSimple(self):
        # a quoted string is a single token, rendered without its quotes
        self.assertZoneLinesEqual(zone.ZoneLine.parse(File("foo TXT \"asdf quux\"")), ["foo\tTXT\tasdf quux"])

    def testZoneLineParseQuoteTrailing(self):
        self.assertZoneLinesEqual(zone.ZoneLine.parse(File("foo TXT \"asdf quux\" ok ; yay")), ["foo\tTXT\tasdf quux\tok\t; yay"])
    
    def testZoneLineParseIndent(self):
        # leading whitespace (implicit owner) is preserved as a leading tab
        self.assertZoneLinesEqual(zone.ZoneLine.parse(File(" A 192.0.2.2")), ["\tA\t192.0.2.2"])
    
    def testZoneLineParseMultilineSingle(self):
        self.assertZoneLinesEqual(zone.ZoneLine.parse(File("@ SOA ( 1 2 3 4 5 )")), ["@\tSOA\t1\t2\t3\t4\t5"])
    
    def testZoneLineParseMultiline(self):
        # a parenthesized continuation spanning several physical lines yields one logical line
        self.assertZoneLinesEqual(zone.ZoneLine.parse(File.lines("@ SOA (", "\t1\t", "2", "\t3\t\t4", "\t5 )")), ["@\tSOA\t1\t2\t3\t4\t5"])

    def testZoneLineLoad(self):
        # whitespace-insensitive comparison; the blank " " input line is expected to be dropped,
        # while $-directives are kept as lines
        self.assertZoneRecordsEqual(
            zone.ZoneLine.load(File.lines(
                "$TTL 3600",
                " ",
                "@ NS ns1",
                "  NS ns2",
                "$ORIGIN asdf", # relative: note lack of trailing .
                "foo A 192.0.2.1",
                "$ORIGIN quux.test.", # absolute: note trailing .
                "bar A 192.0.2.2",
            )), [
                "$TTL 3600",
                "@ NS ns1",
                " NS ns2",
                "$ORIGIN asdf",
                "foo A 192.0.2.1",
                "$ORIGIN quux.test.",
                "bar A 192.0.2.2",
            ]
        )

class ZoneDirectiveTest(TestMixin, unittest.TestCase):
    """
        Construction, parsing and rendering of $-directives via zone.ZoneDirective.
    """

    def testZoneDirective(self):
        directive = zone.ZoneDirective('ORIGIN', ['foo.'])

        self.assertZoneLineEqual(directive, "$ORIGIN\tfoo.")

    def testZoneDirectiveBuild(self):
        directive = zone.ZoneDirective.build('ORIGIN', 'foo.')

        self.assertZoneLineEqual(directive, "$ORIGIN\tfoo.")

    def testZoneDirectiveParse(self):
        directive = zone.ZoneDirective.parse(['$ORIGIN', 'foo.'])

        self.assertZoneLineEqual(directive, "$ORIGIN\tfoo.")

    def testZoneDirectiveParseUpper(self):
        # a lower-case directive name is rendered upper-case after parsing
        directive = zone.ZoneDirective.parse(['$include', 'foo.zone'])

        self.assertZoneLineEqual(directive, "$INCLUDE\tfoo.zone")

    def testZoneDirectiveComment(self):
        directive = zone.ZoneDirective('ORIGIN', ['foo.'], comment="bar")

        self.assertZoneLineEqual(directive, "$ORIGIN\tfoo.\t; bar")

class ZoneRecordTest(TestMixin, unittest.TestCase):
    """
        Construction, parsing and loading of resource records via zone.ZoneRecord,
        compared on whitespace-normalized renderings.
    """

    def testZoneRecordShort(self):
        rr = zone.ZoneRecord('test', None, None, 'A', ['192.0.2.1'])

        self.assertZoneRecordEqual(rr, 'test A 192.0.2.1')

    def testZoneRecordImplicit(self):
        # no owner name: rendered with a leading separator
        rr = zone.ZoneRecord(None, None, None, 'A', ['192.0.2.1'])

        self.assertZoneRecordEqual(rr, ' A 192.0.2.1')

    def testZoneRecordFull(self):
        rr = zone.ZoneRecord('test', 60, 'IN', 'A', ['192.0.2.1'])

        self.assertZoneRecordEqual(rr, 'test 60 IN A 192.0.2.1')

    def testZoneRecordComment(self):
        rr = zone.ZoneRecord('test', None, None, 'A', ['192.0.2.1'], comment='Testing')

        self.assertZoneRecordEqual(rr, 'test A 192.0.2.1 ; Testing')

    def testZoneRecordA(self):
        self.assertZoneRecordEqual(zone.ZoneRecord.A('test', '192.0.2.1'), "test A 192.0.2.1")

    def testZoneRecordAAAA(self):
        self.assertZoneRecordEqual(zone.ZoneRecord.AAAA('test', '2001:db8::c000:201'), "test AAAA 2001:db8::c000:201")

    def testZoneRecordCNAME(self):
        self.assertZoneRecordEqual(zone.ZoneRecord.CNAME('test', 'test.example.net.'), "test CNAME test.example.net.")

    def testZoneRecordTXT(self):
        # TXT data is rendered quoted
        self.assertZoneRecordEqual(zone.ZoneRecord.TXT('test', u"Foo Bar"), u"test TXT \"Foo Bar\"")

    def testZoneRecordPTR(self):
        self.assertZoneRecordEqual(zone.ZoneRecord.PTR('1', 'test.example.com.'), "1 PTR test.example.com.")

    def testZoneRecordMX(self):
        self.assertZoneRecordEqual(zone.ZoneRecord.MX('@', 10, 'mail1'), "@ MX 10 mail1")

    def testZoneRecordBuildTTL(self):
        self.assertZoneRecordEqual(zone.ZoneRecord.A('test', '192.0.2.1', ttl=60), "test 60 A 192.0.2.1")

    def testZoneRecordBuildZeroTTL(self):
        # a TTL of 0 must still be rendered, not treated as "no TTL"
        self.assertZoneRecordEqual(zone.ZoneRecord.A('test', '192.0.2.1', ttl=0), "test 0 A 192.0.2.1")

    def testZoneRecordBuildImplicit(self):
        self.assertZoneRecordEqual(zone.ZoneRecord.A(None, '192.0.2.1'), " A 192.0.2.1")

    def testZoneRecordBuildCls(self):
        # the class is rendered upper-case
        self.assertZoneRecordEqual(zone.ZoneRecord.A('foo', '192.0.2.1', cls='in'), "foo IN A 192.0.2.1")

    def testZoneRecordParse(self):
        self.assertZoneRecordEqual(zone.ZoneRecord.parse(['test', 'A', '192.0.2.1']), "test A 192.0.2.1")
        self.assertZoneRecordEqual(zone.ZoneRecord.parse(['', 'A', '192.0.2.1']), " A 192.0.2.1")
        self.assertZoneRecordEqual(zone.ZoneRecord.parse(['test', '60', 'A', '192.0.2.1']), "test 60 A 192.0.2.1")
        self.assertZoneRecordEqual(zone.ZoneRecord.parse(['test', '60', 'in', 'A', '192.0.2.1']), "test 60 IN A 192.0.2.1")

    def testZoneRecordParseError(self):
        # a record without rdata must be rejected by parse() itself; no rendering is involved,
        # so a regression reports "ZoneLineError not raised" rather than a bogus "!= None"
        with self.assertRaises(zone.ZoneLineError):
            zone.ZoneRecord.parse(['test', 'A'])

    def testZoneRecordLoad(self):
        # the $TTL is applied to each record, implicit owners inherit the previous owner,
        # and directives themselves are not yielded as records
        self.assertZoneRecordsEqual(
            zone.ZoneRecord.load(File.lines(
                "$TTL 3600",
                " ",
                "@ NS ns1",
                "  NS ns2",
                "$ORIGIN asdf", # relative
                "foo A 192.0.2.1",
                "$ORIGIN quux.test.", # absolute
                "bar A 192.0.2.2",
            ), 'test'), [
                "@ 3600 NS ns1",
                "@ 3600 NS ns2",
                "foo 3600 A 192.0.2.1", # asdf.test
                "bar 3600 A 192.0.2.2", # quux.test
            ]
        )

class TestProcessZone(TestMixin, unittest.TestCase):
    """
        End-to-end checks for the pvl.dns.process zone-rewriting passes.
    """

    def testZoneRecordLoad(self):
        # NOTE: despite its name, this exercises process.zone_serial (the SOA serial 0 is
        # expected to become 1337) and process.zone_includes (the $INCLUDE path is expected
        # to be rewritten under '...' and recorded in the trace list).
        zone_file = File("""
$TTL 3600
@                   SOA     foo.test. hostmaster.test. (
                            0               ; serial
                            1d              ; refresh
                            5m              ; retry
                            10d             ; expiry
                            300             ; negative
                    )

                    NS      foo
                    NS      bar

foo                 A       192.0.2.1
bar                 A       192.0.2.2

$INCLUDE "includes/test"
""")
        lines = zone.ZoneLine.load(zone_file)

        traced = [ ]

        serialized = list(process.zone_serial(lines, 1337))
        processed = list(process.zone_includes(serialized, '...', traced))

        expected_lines = [
            "$TTL 3600",
            "@ SOA foo.test. hostmaster.test. 1337 1d 5m 10d 300",
            " NS foo",
            " NS bar",
            "foo A 192.0.2.1",
            "bar A 192.0.2.2",
            "$INCLUDE .../includes/test",
        ]
        expected_trace = [
            '.../includes/test',
        ]

        self.assertZoneEqual(processed, expected_lines)
        self.assertEqual(traced, expected_trace)