process-zone: support $ORIGIN, $GENERATE for --reverse-zone
authorTero Marttila <terom@paivola.fi>
Mon, 19 Mar 2012 16:50:28 +0200
changeset 35 840092ee4d97
parent 34 d2e187c1f548
child 36 3208cd6540dc
process-zone: support $ORIGIN, $GENERATE for --reverse-zone
bin/process-zone
--- a/bin/process-zone	Mon Mar 19 13:59:34 2012 +0200
+++ b/bin/process-zone	Mon Mar 19 16:50:28 2012 +0200
@@ -85,6 +85,10 @@
     parser.add_option('--reverse-zone',         metavar='NET',
             help="Generate forward zone for given subnet (x.z.y)")
 
+    # 
+    parser.add_option('--doctest',              action='store_true',
+            help="Run module doctests")
+
     # defaults
     parser.set_defaults(
         loglevel            = logging.WARN,
@@ -105,6 +109,17 @@
 
     return options, args
 
+class ZoneError (Exception) :
+    pass
+
+class ZoneLineError (ZoneError) :
+    """
+        ZoneLine-related error
+    """
+
+    def __init__ (self, line, msg, *args, **kwargs) :
+        super(ZoneLineError, self).__init__("%s: %s" % (line, msg.format(*args, **kwargs)))
+
 class ZoneLine (object) :
     """
         A line in a zonefile.
@@ -136,7 +151,7 @@
 
         if line_timestamp_prefix :
             if ': ' not in line :
-                raise Exception("Missing timestamp prefix on line: %s:%d: %s" % (file, lineno, line))
+                raise ZoneError("%s:%d: Missing timestamp prefix: %s" % (file, lineno, line))
 
             # split prefix
             prefix, line = line.split(': ', 1)
@@ -202,6 +217,9 @@
     # the underlying line
     line = None
 
+    # possible $ORIGIN context
+    origin = None
+
     # record fields
     name = None
     type = None
@@ -214,18 +232,18 @@
     cls = None
 
     @classmethod
-    def parse (cls, line) :
+    def parse (cls, line, parts=None, origin=None) :
         """
             Parse from ZoneLine. Returns None if there is no record on the line..
         """
 
-        if not line.parts :
+        if parts is None :
+            parts = list(line.parts)
+
+        if not parts :
             # skip
             return
         
-        # consume parts
-        parts = list(line.parts)
-
         # indented lines don't have name
         if line.indent :
             name = None
@@ -233,7 +251,10 @@
         else :
             name = parts.pop(0)
         
-        log.debug("  name=%r", name)
+        log.debug("  name=%r, origin=%r", name, origin)
+
+        if len(parts) < 2 :
+            raise ZoneLineError(line, "Too few parts to parse: {0!r}", line.data)
 
         # parse ttl/cls/type
         ttl = _cls = None
@@ -253,12 +274,13 @@
         log.debug("  ttl=%r, cls=%r, type=%r, data=%r", ttl, _cls, type, data)
 
         return cls(name, type, data,
+            origin  = origin,
             ttl     = ttl,
             cls     = _cls,
             line    = line,
         )
 
-    def __init__ (self, name, type, data, ttl=None, cls=None, line=None, comment=None) :
+    def __init__ (self, name, type, data, origin=None, ttl=None, cls=None, line=None, comment=None) :
         self.name = name
         self.type = type
         self.data = data
@@ -266,6 +288,7 @@
         self.ttl = ttl
         self.cls = cls
         
+        self.origin = origin
         self.line = line
 
         # XXX: within line
@@ -292,7 +315,7 @@
         )
 
     def __str__ (self) :
-        return ' '.join((self.name, self.type, ' '.join(self.data)))
+        return ' '.join((self.name or '', self.type, ' '.join(self.data)))
 
 class TXTRecord (ZoneRecord) :
     """
@@ -305,32 +328,234 @@
             **opts
         )
 
-def parse_record (path, lineno, line, **opts) :
+class OffsetValue (object) :
+    def __init__ (self, value) :
+        self.value = value
+
+    def __getitem__ (self, offset) :
+        value = self.value + offset
+
+        #log.debug("OffsetValue: %d[%d] -> %d", self.value, offset, value)
+
+        return value
+
+def parse_generate_field (line, field) :
     """
-        Parse (name, ttl, type, data, comment) from bind zonefile.
+        Parse a $GENERATE lhs/rhs field:
+            $
+            ${<offset>[,<width>[,<base>]]}
+            \$
+            $$
 
-        Returns None for empty/comment lines.
+        Returns a wrapper that builds the field-value when called with the index.
+        
+        >>> parse_generate_field(None, "foo")(1)
+        'foo'
+        >>> parse_generate_field(None, "foo-$")(1)
+        'foo-1'
+        >>> parse_generate_field(None, "foo-$$")(1)
+        'foo-$'
+        >>> parse_generate_field(None, "\$")(1)
+        '$'
+        >>> parse_generate_field(None, "10.0.0.${100}")(1)
+        '10.0.0.101'
+        >>> parse_generate_field(None, "foo-${0,2,d}")(1)
+        'foo-01'
+
     """
 
-    # line
-    line = ZoneLine.parse(path, lineno, line, **opts)
-    record = ZoneRecord.parse(line)
+    input = field
+    expr = []
 
-    if record :
-        return record
+    while '$' in field :
+        # defaults
+        offset = 0
+        width = 0
+        base = 'd'
+        escape = False
 
-def parse_zone_records (file, **opts) :
+        # different forms
+        if '${' in field :
+            pre, body = field.split('${', 1)
+            body, post = body.split('}', 1)
+
+            # parse body
+            parts = body.split(',')
+
+            # offset
+            offset = int(parts.pop(0))
+
+            # width
+            if parts :
+                width = int(parts.pop(0))
+
+            # base
+            if parts :
+                base = parts.pop(0)
+            
+            if parts:
+                # fail
+                raise ZoneLineError(line, "extra data in ${...} body: {0!r}", parts)
+
+        elif '$$' in field :
+            pre, post = field.split('$$', 1)
+            escape = True
+
+        elif '\\$' in field :
+            pre, post = field.split('\\$', 1)
+            escape = True
+
+        else :
+            pre, post = field.split('$', 1)
+        
+        expr.append(pre)
+
+        if escape :
+            expr.append('$')
+
+        else :
+            # meta-format
+            fmt = '{value[%d]:0%d%s}' % (offset, width, base)
+
+            log.debug("field=%r -> pre=%r, fmt=%r, post=%r", field, expr, fmt, post)
+
+            expr.append(fmt)
+
+        field = post
+
+    # final
+    if field :
+        expr.append(field)
+    
+    # combine
+    expr = ''.join(expr)
+
+    log.debug("%s: %s", input, expr)
+
+    # processed
+    def value_func (value) :
+        # magic wrapper to implement offsets
+        return expr.format(value=OffsetValue(value))
+    
+    return value_func
+
+def process_generate (line, origin, parts) :
+    """
+        Process a 
+            $GENERATE <start>-<stop>[/<step>] lhs [ttl] [class] type rhs [comment]
+        directive into a series of ZoneResource's.
+    """
+
+    range = parts.pop(0)
+
+    # parse range
+    if '/' in range :
+        range, step = range.split('/')
+        step = int(step)
+    else :
+        step = 1
+
+    start, stop = range.split('-')
+    start = int(start)
+    stop = int(stop)
+
+    log.debug("  range: start=%r, stop=%r, step=%r", start, stop, step)
+
+    # inclusive
+    range = xrange(start, stop + 1, step)
+
+    lhs_func = parse_generate_field(line, parts.pop(0))
+    rhs_func = parse_generate_field(line, parts.pop(-1))
+    body = parts
+
+    for i in range :
+        # build
+        parts = [lhs_func(i)] + body + [rhs_func(i)]
+
+        log.debug(" %03d: %r", i, parts)
+
+        # parse
+        yield ZoneRecord.parse(line, parts=parts, origin=origin)
+
+def parse_zone_records (file, origin=None, **opts) :
     """
         Parse ZoneRecord items from the given zonefile, ignoring non-record lines.
     """
+
+    ttl = None
+
+    skip_multiline = False
     
-    for lineno, line in enumerate(file) :
-        record = parse_record(file.name, lineno, line, **opts)
+    for lineno, raw_line in enumerate(file) :
+        # parse comment
+        if ';' in raw_line :
+            line, comment = raw_line.split(';', 1)
+        else :
+            line = raw_line
+            comment = None
+
+        # XXX: handle multi-line statements...
+        # start
+        if '(' in line :
+            skip_multiline = True
+            
+            log.warn("%s:%d: Start of multi-line statement: %s", file.name, lineno, raw_line)
+
+        # end?
+        if ')' in line :
+            skip_multiline = False
+            
+            log.warn("%s:%d: End of multi-line statement: %s", file.name, lineno, raw_line)
+            
+            continue
+
+        elif skip_multiline :
+            log.warn("%s:%d: Multi-line statement: %s", file.name, lineno, raw_line)
+
+            continue
+        
+        # parse
+        line = ZoneLine.parse(file.name, lineno, raw_line, **opts)
+
+        if not line.data :
+            log.debug("%s: skip empty line: %s", line, raw_line)
+
+            continue
+
+        elif line.data.startswith('$') :
+            # control record
+            type = line.parts[0]
+
+            if type == '$ORIGIN':
+                # update
+                origin = line.parts[1]
+                
+                log.info("%s: origin: %s", line, origin)
+            
+            elif type == '$GENERATE':
+                # process...
+                log.info("%s: generate: %s", line, line.parts)
+
+                for record in process_generate(line, origin, line.parts[1:]) :
+                    yield record
+
+            else :
+                log.warning("%s: skip control record: %s", line, line.data)
+            
+            # XXX: passthrough!
+            continue
+
+        # normal record?
+        record = ZoneRecord.parse(line, origin=origin)
 
         if record :
             yield record
 
-def check_zone_hosts (zone, whitelist=None) :
+        else :
+            # unknown
+            log.warning("%s: skip unknown line: %s", line, line.data)
+
+def check_zone_hosts (zone, whitelist=None, whitelist_types=set(['TXT'])) :
     """
         Parse host/IP pairs from the zone, and verify that they are unique.
 
@@ -342,20 +567,25 @@
 
     fail = None
 
+    last_name = None
+
     for r in zone :
-        name = r.name
+        name = r.name or last_name
+
+        name = (r.origin, name)
 
         # name
-        if name not in by_name :
-            by_name[name] = r
+        if r.type not in whitelist_types :
+            if name not in by_name :
+                by_name[name] = r
 
-        elif r.name in whitelist :
-            log.debug("Duplicate whitelist entry: %s", r)
+            elif r.name in whitelist :
+                log.debug("Duplicate whitelist entry: %s", r)
 
-        else :
-            # fail!
-            log.warn("%s: Duplicate name: %s <-> %s", r.line, r, by_name[name])
-            fail = True
+            else :
+                # fail!
+                log.warn("%s: Duplicate name: %s <-> %s", r.line, r, by_name[name])
+                fail = True
 
         # ip
         if r.type == 'A' :
@@ -429,8 +659,13 @@
     return '.'.join([str(octet) for octet in reversed(octets)] + ['in-addr', 'arpa'])
 
 def fqdn (*parts) :
-    return '.'.join(parts) + '.'
-
+    fqdn = '.'.join(parts)
+    
+    # we may be given an fqdn in parts
+    if not fqdn.endswith('.') :
+        fqdn += '.'
+    
+    return fqdn
 
 def process_zone_reverse (zone, origin, domain) :
     """
@@ -455,8 +690,8 @@
             continue
 
         # domain to use
-        host_domain = domain
-        host_fqdn = fqdn(r.name, domain)
+        host_domain = r.origin or domain
+        host_fqdn = fqdn(r.name, host_domain)
 
         yield ZoneRecord(reverse, 'PTR', [host_fqdn])
 
@@ -488,6 +723,11 @@
     
     options, args = parse_options(argv)
 
+    if options.doctest :
+        import doctest
+        fail, total = doctest.testmod()
+        return fail
+
     if args :
         # open files
         input_files = [open_file(path, 'r', options.input_charset) for path in args]