pvl.dns-zone: update and split out pvl.dns-process for --serial and --include-path transforms
authorTero Marttila <tero.marttila@aalto.fi>
Fri, 27 Feb 2015 16:46:33 +0200
changeset 641 9d36e312e6a7
parent 640 620d5a3beec4
child 642 c25834508569
pvl.dns-zone: update and split out pvl.dns-process for --serial and --include-path transforms
bin/pvl.dns-process
bin/pvl.dns-zone
pvl/dns/process.py
pvl/dns/tests.py
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/bin/pvl.dns-process	Fri Feb 27 16:46:33 2015 +0200
@@ -0,0 +1,47 @@
+#!/usr/bin/env python
+
+"""
+    Process bind zonefiles, without altering their structure.
+
+    Takes a zonefile as input, and gives a semantically identical zonefile as output, with the given changes.
+"""
+
+import logging; log = logging.getLogger('pvl.dns-process')
+import optparse
+import pvl.args
+import pvl.dns.process
+import pvl.dns.zone
+
+def main (argv):
+    parser = optparse.OptionParser(main.__doc__)
+    parser.add_option_group(pvl.args.parser(parser))
+    parser.add_option_group(pvl.dns.process.optparser(parser))
+
+    parser.add_option('--serial',               metavar='YYMMDDXX',
+            help="Set serial for SOA record")
+
+    parser.add_option('--include-path',         metavar='PATH',
+            help="Rewrite includes to given absolute path")
+
+    # input
+    options, args = pvl.args.parse(parser, argv)
+    
+    # process
+    zone = list(pvl.dns.process.apply_zone(options, args))
+
+    if options.serial:
+        log.info("Set zone serial: %s", options.serial)
+
+        zone = list(pvl.dns.process.zone_serial(zone, options.serial))
+
+    if options.include_path:
+        log.info("Set zone include path: %s", options.include_path)
+
+        zone = list(pvl.dns.process.zone_includes(zone, options.include_path))
+    
+    pvl.dns.process.apply_zone_output(options, zone)
+
+    return 0
+
+if __name__ == '__main__':
+    pvl.args.main(main)
--- a/bin/pvl.dns-zone	Fri Feb 27 16:45:53 2015 +0200
+++ b/bin/pvl.dns-zone	Fri Feb 27 16:46:33 2015 +0200
@@ -1,104 +1,19 @@
 #!/usr/bin/env python
 
 """
-    Process bind zonefiles.
+    Generate bind zonefiles from a given input zonefile.
 
-    Takes a zonefile as input, and gives a zonefile as output.
+    Takes a zonefile as input, and gives a new zonefile as output.
 """
 
-import logging; log = logging.getLogger('pvl.dns-zone')
+import ipaddr
+import logging; log = logging.getLogger('pvl.dns-generate')
 import optparse
-import os.path
 import pvl.args
-import pvl.dns
-import pvl.dns.zone
-
-from pvl.hosts import __version__
-
-def parse_options (argv) :
-    """
-        Parse command-line arguments.
-    """
-
-    prog = argv[0]
-
-    parser = optparse.OptionParser(
-            prog        = prog,
-            usage       = '%prog: [options]',
-            version     = __version__,
-
-            # module docstring
-            description = __doc__,
-    )
-
-    # logging
-    parser.add_option_group(pvl.args.parser(parser))
-
-    # input/output
-    parser.add_option('--input-charset',        metavar='CHARSET',  default='utf-8', 
-            help="Encoding used for input files")
-
-    parser.add_option('-o', '--output',         metavar='FILE',     default=None,
-            help="Write to output file; default stdout")
-
-    parser.add_option('--output-charset',       metavar='CHARSET',  default='utf-8', 
-            help="Encoding used for output files")
-
-    # check stage
-    parser.add_option('--check-hosts',          action='store_true',
-            help="Check that host/IPs are unique. Use --quiet to silence warnings, and test exit status")
-
-    parser.add_option('--check-exempt',         metavar='HOST', action='append',
-            help="Allow given names to have multiple records")
+import pvl.dns.reverse
+import pvl.dns.process
 
-    # forward stage
-    parser.add_option('--forward-zone',         action='store_true', 
-            help="Generate forward zone")
-
-    parser.add_option('--forward-txt',          action='store_true',
-            help="Generate TXT records for forward zone")
-
-    # reverse stage
-    parser.add_option('--reverse-domain',       metavar='DOMAIN',
-            help="Domain to use for hosts in reverse zone")
-
-    parser.add_option('--reverse-zone',         metavar='NET',
-            help="Generate forward zone for given subnet (x.z.y | a:b:c:d)")
-
-    # other
-    parser.add_option('--serial',               metavar='YYMMDDXX',
-            help="Set serial for SOA record")
-
-    parser.add_option('--include-path',         metavar='PATH',
-            help="Rewrite includes to given absolute path")
-
-    # defaults
-    parser.set_defaults(
-        # XXX: combine
-        check_exempt        = [],
-    )
-    
-    # parse
-    options, args = parser.parse_args(argv[1:])
-
-    # apply
-    pvl.args.apply(options, prog)
-
-    return options, args
-
-def apply_zone_input (options, args) :
-    """
-        Yield ZoneLine, ZoneRecord pairs from files.
-    """
-
-    for file in pvl.args.apply_files(args, 'r', options.input_charset) :
-        log.info("Reading zone: %s", file)
-
-        for line, record in pvl.dns.zone.ZoneLine.load(file):
-            yield line, record
-
-# TODO: --check-types to limit this to A/AAAA/CNAME etc
-def check_zone_hosts (zone, whitelist=None, whitelist_types=set(['TXT'])) :
+def check_zone (rrs, whitelist_names=set(), whitelist_types=set()):
     """
         Parse host/IP pairs from the zone, and verify that they are unique.
 
@@ -108,198 +23,124 @@
     by_name = {}
     by_ip = {}
 
-    fail = None
-
-    last_name = None
-
-    for l, r in zone :
-        if r :
-            name = r.name or last_name
-
-            name = (r.origin, name)
-
-            # name
-            if r.type not in whitelist_types :
-                if name not in by_name :
-                    by_name[name] = r
-
-                elif r.name in whitelist :
-                    log.debug("Duplicate whitelist entry: %s", r)
-
-                else :
-                    # fail!
-                    log.warn("%s: Duplicate name: %s <-> %s", r.line, r, by_name[name])
-                    fail = True
-
-            # ip
-            if r.type == 'A' :
-                ip, = r.data
-
-                if ip not in by_ip :
-                    by_ip[ip] = r
-
-                else :
-                    # fail!
-                    log.warn("%s: Duplicate IP: %s <-> %s", r.line, r, by_ip[ip])
-                    fail = True
-    
-    if fail :
-        log.error("Check failed, see warnings")
-        sys.exit(2)
+    check = True
 
-    yield l, r
-
-def process_zone_serial (zone, serial) :
-    """
-        Update the serial in the SOA record.
-    """
-
-    for line, rr in zone :
-        if rr and rr.type == 'SOA' :
-            # XXX: as SOA record..
-            try :
-                soa = pvl.dns.zone.SOA.parse(line)
-            except TypeError as error :
-                log.exception("%s: unable to parse SOA: %s", rr.name, rr)
-                sys.exit(2)
+    for rr in rrs:
+        name = (rr.origin, rr.name)
 
-            yield line, pvl.dns.zone.SOA(
-                    soa.master, soa.contact,
-                    serial, soa.refresh, soa.retry, soa.expire, soa.nxttl
-            )
-        else :
-            yield line, rr
-
-def process_zone_forwards (zone, txt=False) :
-    """
-        Process zone data -> forward zone data.
-    """
+        # name
+        if name not in by_name:
+            pass
 
-    for line, r in zone :
-        yield line, r
+        elif rr.type in whitelist_types:
+            log.debug("%s: Whitelist type duplicate: %s", rr, by_name[name])
 
-        if r and r.type == 'A' :
-            if txt :
-                # comment?
-                comment = r.line.comment
+        elif rr.name in whitelist_names:
+            log.debug("%s: Whitelist name duplicate: %s", rr, by_name[name])
 
-                if comment :
-                    yield line, ZoneRecord.TXT(None, comment, ttl=r.ttl)
-           
-def process_zone_reverse (zone, origin, domain) :
+        else:
+            log.warn("%s: Duplicate name: %s <-> %s", rr.line, rr, by_name[name])
+            check = False
+            
+        by_name[name] = rr
+
+        # ip
+        if rr.type in ('A', 'AAAA'):
+            ip, = rr.data
+
+            if ip in by_ip:
+                log.warn("%s: Duplicate IP: %s <-> %s", rr.line, rr, by_ip[ip])
+                check = False
+                
+            by_ip[ip] = rr
+
+    return check
+
+def process_zone_reverse (rrs, prefix):
     """
         Process zone data -> reverse zone data.
     """
 
-    for line, r in zone :
-        if r and r.type == 'A' :
+    for r in rrs:
+        if r.type == 'A':
             ip, = r.data
-            ptr = reverse_ipv4(ip)
 
-        elif r and r.type == 'AAAA' :
+            ip = ipaddr.IPv4Address(ip)
+
+        elif r.type == 'AAAA':
             ip, = r.data
-            ptr = reverse_ipv6(ip)
             
-        else :
-            yield line, r
-            continue
-
-        # verify
-        if zone and ptr.endswith(origin) :
-            ptr = ptr[:-(len(origin) + 1)]
-
-        else :
-            log.warning("Reverse does not match zone origin, skipping: (%s) -> %s <-> %s", ip, ptr, origin)
+            ip = ipaddr.IPv6Address(ip)
+            
+        else:
             continue
 
-        # domain to use
-        host_domain = r.origin or domain
-        host_fqdn = fqdn(name, host_domain)
-
-        yield line, ZoneRecord.PTR(ptr, host_fqdn)
-
-def process_zone_includes (options, zone, path) :
-    """
-        Rewrite include paths in zones.
-    """
+        if ip not in prefix:
+            log.debug("%s: skip: %s not in %s", rr, ip, prefix)
+            continue
 
-    for line, rr in zone :
-        if line.parts[0] == '$INCLUDE' :
-            _, include = line.parts
+        ptr = pvl.dns.reverse_label(prefix, ip)
+        fqdn = pvl.dns.fqdn(r.name, r.origin)
 
-            yield pvl.dns.zone.ZoneLine(
-                    line.file,
-                    line.lineno, 
-                    line.line,
-                    line.indent,
-                    ['$INCLUDE', '"{path}"'.format(path=os.path.join(path, include))],
-            ), rr
-        else :
-            yield line, rr
+        yield pvl.dns.ZoneRecord.PTR(ptr, fqdn)
+
+def main (argv):
+    parser = optparse.OptionParser(main.__doc__)
+    parser.add_option_group(pvl.args.parser(parser))
+    parser.add_option_group(pvl.dns.process.optparser(parser))
+
+    parser.add_option('--zone-origin',          metavar='DOMAIN',
+            help="Domain to use for hosts in zone")
+
+    # check stage
+    parser.add_option('--check-hosts',          action='store_true',
+            help="Check that host/IPs are unique. Use --quiet to silence warnings, and test exit status")
+
+    parser.add_option('--check-exempt',         metavar='HOST', action='append',
+            help="Allow given names to have multiple records")
+
+    # reverse stage
+    parser.add_option('--reverse-prefix',         metavar='NET',
+            help="Generate forward zone for given subnet (192.0.2, 2001:db8)")
 
 
-def apply_zone_output (options, zone) :
-    """
-        Write out the resulting zonefile.
-    """
-
-    file = pvl.args.apply_file(options.output, 'w', options.output_charset)
-
-    for line, r in zone :
-        if r :
-            file.write(unicode(r))
-        else :
-            file.write(unicode(line))
-        file.write('\n')
-
-def main (argv) :
-    options, args = parse_options(argv)
-    
-    # input
-    zone = apply_zone_input(options, args)
-   
-    if options.check_hosts :
-        whitelist = set(options.check_exempt)
-
-        log.info("Checking hosts: whitelist=%r", whitelist)
-
-        zone = list(check_zone_hosts(zone, whitelist=whitelist))
+    parser.set_defaults(
+        check_exempt        = [
+            '@'
+        ],
+    )
 
-    if options.serial :
-        log.info("Set zone serial: %s", options.serial)
-
-        zone = list(process_zone_serial(zone, serial=options.serial))
-
-    if options.forward_zone :
-        log.info("Generate forward zone...")
-
-        zone = list(process_zone_forwards(zone, txt=options.forward_txt))
-
-    if options.reverse_zone :
-        if ':' in options.reverse_zone :
-            # IPv6
-            origin = reverse_ipv6(options.reverse_zone)
+    # input
+    options, args = pvl.args.parse(parser, argv)
 
-        else :
-            # IPv4
-            origin = reverse_ipv4(options.reverse_zone)
-
-        domain = options.reverse_domain
+    if options.reverse_prefix and not options.zone_origin:
+        log.error("--reverse-prefix requires --zone-origin")
+    
+    zone = list(pvl.dns.process.apply_zone_records(options, options.zone_origin, args))
 
-        if not domain :
-            log.error("--reverse-zone requires --reverse-domain")
-            return 1
+    # check
+    if options.check_hosts:
+        whitelist_names = set(options.check_exempt)
 
-        zone = list(process_zone_reverse(zone, origin=origin, domain=domain))
-    
-    if options.include_path :
-        zone = list(process_zone_includes(options, zone, options.include_path))
+        log.info("Checking hosts: whitelist_names=%r", whitelist_names)
 
-    # output
-    apply_zone_output(options, zone)
+        if not check_zone(zone, whitelist_names=whitelist_names):
+            log.error("Check zone failed, see warnings")
+            return 2
+
+    # transform
+    if options.reverse_prefix:
+        prefix = pvl.dns.reverse.parse_prefix(options.reverse_prefix)
+
+        zone = list(process_zone_reverse(zone, prefix))
+    else:
+        # pass through
+        pass
+
+    pvl.dns.process.apply_zone_output(options, zone)
 
     return 0
 
 if __name__ == '__main__':
-    import sys
-    sys.exit(main(sys.argv))
+    pvl.args.main(main)
+
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/pvl/dns/process.py	Fri Feb 27 16:46:33 2015 +0200
@@ -0,0 +1,86 @@
+import logging
+import optparse
+import os.path
+import pvl.args
+
+from pvl.dns import zone
+
+log = logging.getLogger('pvl.dns.process')
+
+def optparser (parser):
+    group = optparse.OptionGroup(parser, "Hosts config files")
+    group.add_option('--input-charset',        metavar='CHARSET',  default='utf-8', 
+            help="Encoding used for input files")
+
+    group.add_option('-o', '--output',         metavar='FILE',     default=None,
+            help="Write to output file; default stdout")
+
+    group.add_option('--output-charset',       metavar='CHARSET',  default='utf-8', 
+            help="Encoding used for output files")
+
+    return group
+
+def zone_serial (rrs, serial) :
+    """
+        Update the serial in the SOA record.
+    """
+
+    for rr in rrs:
+        if isinstance(rr, zone.ZoneRecordSOA) and rr.type == 'SOA':
+            yield zone.ZoneRecordSOA.build(rr.name, rr.type,
+                    rr.master, rr.contact,
+                    serial, rr.refresh, rr.retry, rr.expire, rr.nxttl,
+                    line=rr.line, origin=rr.origin, comment=rr.comment,
+            )
+        else:
+            yield rr
+
+def zone_includes (rrs, includes_path):
+    """
+        Rewrite include paths in zones.
+    """
+
+    for rr in rrs:
+        if isinstance(rr, zone.ZoneDirective) and rr.directive == 'INCLUDE':
+            include_path, = rr.arguments
+
+            yield zone.ZoneDirective.INCLUDE(os.path.join(includes_path, include_path))
+        else:
+            yield rr
+
+def apply_zone_output (options, zone):
+    """
+        Output given ZoneDirective/ZoneRecord items to the output file/stdout.
+    """
+
+    file = pvl.args.apply_file(options.output, 'w', options.output_charset)
+
+    for item in zone:
+        file.write(unicode(item))
+        file.write('\n')
+
+def apply_zone (options, args):
+    """
+        ZoneLine.load() in given zones.
+
+        Yields ZoneDirective/ZoneRecord items.
+    """
+
+    for file in pvl.args.apply_files(args, 'r', options.input_charset) :
+        log.info("%s: reading zone", file.name)
+
+        for item in zone.ZoneLine.load(file):
+            yield item
+
+def apply_zone_records (options, origin, args) :
+    """
+        ZoneRecord.load() in given zones.
+
+        Yields expanded ZoneRecord items.
+    """
+
+    for file in pvl.args.apply_files(args, 'r', options.input_charset) :
+        log.info("%s: expanding zone", file.name)
+
+        for item in zone.ZoneRecord.load(file, origin):
+            yield item
--- a/pvl/dns/tests.py	Fri Feb 27 16:45:53 2015 +0200
+++ b/pvl/dns/tests.py	Fri Feb 27 16:46:33 2015 +0200
@@ -2,7 +2,7 @@
 import re
 import unittest
 
-from pvl.dns import zone
+from pvl.dns import zone, process
 from StringIO import StringIO
 
 class File(StringIO):
@@ -28,12 +28,15 @@
         for zl, expect in itertools.izip_longest(zls, expected):
             self.assertZoneLineEqual(zl, expect)
 
-    def assertZoneRecordEqual(self, zr, expect):
+    def assertZoneItemEqual(self, zr, expect):
         self.assertEqualWhitespace(unicode(zr), expect)
 
-    def assertZoneRecordsEqual(self, zrs, expected):
+    def assertZoneEqual(self, zrs, expected):
         for zr, expect in itertools.izip_longest(zrs, expected):
-            self.assertZoneRecordEqual(zr, expect)
+            self.assertZoneItemEqual(zr, expect)
+
+    assertZoneRecordEqual = assertZoneItemEqual
+    assertZoneRecordsEqual = assertZoneEqual
 
 class ZoneLineTest(TestMixin, unittest.TestCase):
     def testZoneLine(self):
@@ -180,3 +183,38 @@
                 "bar 3600 A 192.0.2.2", # quux.test
             ]
         )
+
+class TestProcessZone(TestMixin, unittest.TestCase):
+    def testZoneRecordLoad(self):
+        rrs = zone.ZoneLine.load(File("""
+$TTL 3600
+@                   SOA     foo.test. hostmaster.test. (
+                            0               ; serial
+                            1d              ; refresh
+                            5m              ; retry
+                            10d             ; expiry
+                            300             ; negative
+                    )
+
+                    NS      foo
+                    NS      bar
+
+foo                 A       192.0.2.1
+bar                 A       192.0.2.2
+
+$INCLUDE "includes/test"
+"""))
+
+        rrs = list(process.zone_serial(rrs, 1337))
+        rrs = list(process.zone_includes(rrs, '...'))
+
+        self.assertZoneEqual(rrs, [
+            "$TTL 3600",
+            "@ SOA foo.test. hostmaster.test. 1337 1d 5m 10d 300",
+            " NS foo",
+            " NS bar",
+            "foo A 192.0.2.1",
+            "bar A 192.0.2.2",
+            "$INCLUDE \".../includes/test\"",
+        ])
+