bin/pvl.dns-zone
changeset 293 6351acf3eb3b
parent 258 1ad9cec4f556
child 294 29720bbc5379
--- a/bin/pvl.dns-zone	Mon Dec 16 23:12:28 2013 +0200
+++ b/bin/pvl.dns-zone	Mon Dec 16 23:16:32 2013 +0200
@@ -97,6 +97,20 @@
 
     return options, args
 
+def apply_zone_input (options, args) :
+    """
+        Yield ZoneLine, ZoneRecord pairs from files.
+    """
+
+    for file in pvl.args.apply_files(args, 'r', options.input_charset) :
+        log.info("Reading zone: %s", file)
+
+        for line, record in pvl.dns.zone.ZoneLine.load(file, 
+                line_timestamp_prefix   = options.input_line_date,
+        ) :
+            yield line, record
+
+# TODO: --check-types to limit this to A/AAAA/CNAME etc
 def check_zone_hosts (zone, whitelist=None, whitelist_types=set(['TXT'])) :
     """
         Parse host/IP pairs from the zone, and verify that they are unique.
@@ -111,74 +125,82 @@
 
     last_name = None
 
-    for r in zone :
-        name = r.name or last_name
-
-        name = (r.origin, name)
-
-        # name
-        if r.type not in whitelist_types :
-            if name not in by_name :
-                by_name[name] = r
-
-            elif r.name in whitelist :
-                log.debug("Duplicate whitelist entry: %s", r)
-
-            else :
-                # fail!
-                log.warn("%s: Duplicate name: %s <-> %s", r.line, r, by_name[name])
-                fail = True
+    for l, r in zone :
+        if r :
+            name = r.name or last_name
 
-        # ip
-        if r.type == 'A' :
-            ip, = r.data
-
-            if ip not in by_ip :
-                by_ip[ip] = r
+            name = (r.origin, name)
 
-            else :
-                # fail!
-                log.warn("%s: Duplicate IP: %s <-> %s", r.line, r, by_ip[ip])
-                fail = True
+            # name
+            if r.type not in whitelist_types :
+                if name not in by_name :
+                    by_name[name] = r
 
-    return fail
+                elif r.name in whitelist :
+                    log.debug("Duplicate whitelist entry: %s", r)
 
-def process_zone_soa (soa, serial) :
-    return pvl.dns.zone.SOA(
-        soa.master, soa.contact,
-        serial, soa.refresh, soa.retry, soa.expire, soa.nxttl
-    )
+                else :
+                    # fail!
+                    log.warn("%s: Duplicate name: %s <-> %s", r.line, r, by_name[name])
+                    fail = True
+
+            # ip
+            if r.type == 'A' :
+                ip, = r.data
+
+                if ip not in by_ip :
+                    by_ip[ip] = r
+
+                else :
+                    # fail!
+                    log.warn("%s: Duplicate IP: %s <-> %s", r.line, r, by_ip[ip])
+                    fail = True
+    
+    if fail :
+        log.error("Check failed, see warnings")
+        sys.exit(2)
+
+    yield l, r
 
 def process_zone_serial (zone, serial) :
-    for rr in zone :
-        if rr.type == 'SOA' :
+    """
+        Update the serial in the SOA record.
+    """
+
+    for line, rr in zone :
+        if rr and rr.type == 'SOA' :
             # XXX: as SOA record..
-            yield process_zone_soa(pvl.dns.zone.SOA.parse(rr.line), serial)
+            soa = pvl.dns.zone.SOA.parse(line)
+
+            yield line, pvl.dns.zone.SOA(
+                    soa.master, soa.contact,
+                    serial, soa.refresh, soa.retry, soa.expire, soa.nxttl
+            )
         else :
-            yield rr
+            yield line, rr
 
 def process_zone_forwards (zone, txt=False, mx=False) :
     """
         Process zone data -> forward zone data.
     """
 
-    for r in zone :
-        yield r
+    for line, r in zone :
+        yield line, r
 
-        if r.type == 'A' :
+        if r and r.type == 'A' :
             if txt :
                 # comment?
                 comment = r.line.comment
 
                 if comment :
-                    yield ZoneRecord.TXT(None, comment, ttl=r.ttl)
+                    yield line, ZoneRecord.TXT(None, comment, ttl=r.ttl)
 
            
             # XXX: RP, do we need it?
 
             if mx :
                 # XXX: is this even a good idea?
-                yield ZoneRecord.MX(None, 10, mx, ttl=r.ttl)
+                yield line, ZoneRecord.MX(None, 10, mx, ttl=r.ttl)
 
 def process_zone_meta (zone, ignore=None) :
     """
@@ -187,7 +209,7 @@
     
     TIMESTAMP_FORMAT = '%Y/%m/%d'
     
-    for r in zone :
+    for line, r in zone :
         if ignore and r.name in ignore :
             # skip
             log.debug("Ignore record: %s", r)
@@ -199,29 +221,24 @@
             timestamp = r.line.timestamp
 
             if timestamp :
-                yield ZoneRecord.TXT(r.name, timestamp.strftime(TIMESTAMP_FORMAT), ttl=r.ttl)
+                yield line, ZoneRecord.TXT(r.name, timestamp.strftime(TIMESTAMP_FORMAT), ttl=r.ttl)
      
 def process_zone_reverse (zone, origin, domain) :
     """
         Process zone data -> reverse zone data.
     """
 
-    name = None
-
-    for r in zone :
-        # keep name from previous..
-        if r.name :
-            name = r.name
-
-        if r.type == 'A' :
+    for line, r in zone :
+        if r and r.type == 'A' :
             ip, = r.data
             ptr = reverse_ipv4(ip)
 
-        elif r.type == 'AAAA' :
+        elif r and r.type == 'AAAA' :
             ip, = r.data
             ptr = reverse_ipv6(ip)
             
         else :
+            yield line, r
             continue
 
         # verify
@@ -236,57 +253,47 @@
         host_domain = r.origin or domain
         host_fqdn = fqdn(name, host_domain)
 
-        yield ZoneRecord.PTR(ptr, host_fqdn)
+        yield line, ZoneRecord.PTR(ptr, host_fqdn)
 
-def write_zone_records (file, zone) :
-    for r in zone :
-        file.write(unicode(r))
+def apply_zone_output (options, zone) :
+    """
+        Write out the resulting zonefile.
+    """
+
+    file = pvl.args.apply_file(options.output, 'w', options.output_charset)
+
+    for line, r in zone :
+        if r :
+            file.write(unicode(r))
+        else :
+            file.write(line.line)
         file.write('\n')
 
 def main (argv) :
     options, args = parse_options(argv)
     
-    # open files, default to stdout
-    input_files = pvl.args.apply_files(args, 'r', options.input_charset)
+    # input
+    zone = apply_zone_input(options, args)
    
-    # process zone data
-    zone = []
-
-    for file in input_files :
-        log.info("Reading zone: %s", file)
-
-        zone += list(pvl.dns.zone.ZoneRecord.load(file, 
-            line_timestamp_prefix   = options.input_line_date,
-        ))
-
-    # check?
     if options.check_hosts :
         whitelist = set(options.check_exempt)
 
-        log.debug("checking hosts; whitelist=%r", whitelist)
+        log.info("Checking hosts: whitelist=%r", whitelist)
 
-        if check_zone_hosts(zone, whitelist=whitelist) :
-            log.warn("Hosts check failed")
-            return 2
-
-        else :
-            log.info("Hosts check OK")
+        zone = list(check_zone_hosts(zone, whitelist=whitelist))
 
     if options.serial :
         log.info("Set zone serial: %s", options.serial)
 
         zone = list(process_zone_serial(zone, serial=options.serial))
 
-    # output file
-    output = open_file(options.output, 'w', options.output_charset)
-
     if options.forward_zone :
-        log.info("Write forward zone: %s", output)
+        log.info("Generate forward zone...")
 
         zone = list(process_zone_forwards(zone, txt=options.forward_txt, mx=options.forward_mx))
 
-    elif options.meta_zone :
-        log.info("Write metadata zone: %s", output)
+    if options.meta_zone :
+        log.info("Generate metadata zone...")
 
         if not options.input_line_date :
             log.error("--meta-zone requires --input-line-date")
@@ -294,7 +301,7 @@
 
         zone = list(process_zone_meta(zone, ignore=set(options.meta_ignore)))
 
-    elif options.reverse_zone :
+    if options.reverse_zone :
         if ':' in options.reverse_zone :
             # IPv6
             origin = reverse_ipv6(options.reverse_zone)
@@ -310,16 +317,9 @@
             return 1
 
         zone = list(process_zone_reverse(zone, origin=origin, domain=domain))
-
-    elif options.check_hosts :
-        # we only did that, done
-        return 0
-
-    else :
-        # pass-through
-        log.info("Passing through zonefile")
-
-    write_zone_records(output, zone)
+    
+    # output
+    apply_zone_output(options, zone)
 
     return 0