bin/process-zone
author Tero Marttila <terom@paivola.fi>
Fri, 10 May 2013 00:05:25 +0300
changeset 80 b58236f9ea7b
parent 35 840092ee4d97
permissions -rwxr-xr-x
process-zone: support AAAA/ip6.arpa for --reverse-zone, as well as implicit record names
#!/usr/bin/env python

"""
    Process zonefiles.
"""

__version__ = '0.0.1-dev'

import optparse
import codecs
from datetime import datetime
import logging

import ipaddr

# module logger; its name ('main') shows up in every message via the
# '%(name)s' field of the basicConfig format set in parse_options()
log = logging.getLogger('main')

# command-line options, global state; assigned by main() from parse_options()
options = None

def parse_options (argv) :
    """
        Parse command-line arguments, and configure logging accordingly.

            argv    - full argument vector; argv[0] is used as the program name

        Returns an (options, args) tuple from optparse; args are the input
        zonefile paths.
    """

    prog = argv[0]

    parser = optparse.OptionParser(
            prog        = prog,
            usage       = '%prog [options] [FILE ...]',
            version     = __version__,

            # module docstring
            description = __doc__,
    )

    # logging
    general = optparse.OptionGroup(parser, "General Options")

    general.add_option('-q', '--quiet',     dest='loglevel', action='store_const', const=logging.ERROR, help="Less output")
    general.add_option('-v', '--verbose',   dest='loglevel', action='store_const', const=logging.INFO,  help="More output")
    general.add_option('-D', '--debug',     dest='loglevel', action='store_const', const=logging.DEBUG, help="Even more output")

    parser.add_option_group(general)

    # input/output
    parser.add_option('-c', '--input-charset',  metavar='CHARSET',  default='utf-8', 
            help="Encoding used for input files")

    parser.add_option('-o', '--output',         metavar='FILE',     default='-',
            help="Write to output file; default stdout")

    parser.add_option('--output-charset',       metavar='CHARSET',  default='utf-8', 
            help="Encoding used for output files")

    # check stage
    parser.add_option('--check-hosts',          action='store_true',
            help="Check that host/IPs are unique. Use --quiet to silence warnings, and test exit status")

    parser.add_option('--check-exempt',         metavar='HOST', action='append',
            help="Allow given names to have multiple records")

    # meta stage
    parser.add_option('--meta-zone',            action='store_true',
            help="Generate host metadata zone; requires --input-line-date")

    parser.add_option('--meta-ignore',          metavar='HOST', action='append',
            help="Ignore given hostnames in metadata output")

    parser.add_option('--input-line-date',      action='store_true',
            help="Parse timestamp prefix from each input line (e.g. `hg blame | ...`)")

    # forward stage
    parser.add_option('--forward-zone',         action='store_true', 
            help="Generate forward zone")

    parser.add_option('--forward-txt',          action='store_true',
            help="Generate TXT records for forward zone")

    parser.add_option('--forward-mx',           metavar='MX',
            help="Generate MX records for forward zone")

    # reverse stage
    parser.add_option('--reverse-domain',       metavar='DOMAIN',
            help="Domain to use for hosts in reverse zone")

    parser.add_option('--reverse-zone',         metavar='NET',
            help="Generate reverse zone for given subnet (x.y.z | a:b:c:d)")

    # self-test
    parser.add_option('--doctest',              action='store_true',
            help="Run module doctests")

    # defaults
    parser.set_defaults(
        loglevel            = logging.WARN,

        # XXX: combine
        check_exempt        = [],
        meta_ignore         = [],
    )
    
    # parse
    options, args = parser.parse_args(argv[1:])

    # configure
    logging.basicConfig(
        format  = prog + ': %(name)s: %(levelname)s %(funcName)s : %(message)s',
        level   = options.loglevel,
    )

    return options, args

class ZoneError (Exception) :
    """
        Base class for errors raised while parsing or processing zone data.
    """

class ZoneLineError (ZoneError) :
    """
        Error tied to a specific zonefile line.

        The message is prefixed with str(line), and msg is expanded with
        msg.format(*args, **kwargs); literal braces in msg must be doubled.
    """

    def __init__ (self, line, msg, *args, **kwargs) :
        detail = msg.format(*args, **kwargs)
        message = "%s: %s" % (line, detail)

        super(ZoneLineError, self).__init__(message)

class ZoneLine (object) :
    """
        A line in a zonefile, split into fields, with any trailing ';' comment
        and optional timestamp prefix parsed out.
    """

    # source location, used for messages
    file = None
    lineno = None

    # data
    indent = None # was the line indented? (indented records keep the previous name)
    data = None # line content without the comment, stripped
    parts = None # split line fields; quoted strings are one field, quotes included

    # optional
    timestamp = None # datetime from the line prefix, if requested
    comment = None # text after ';', stripped; None if there was no comment

    PARSE_DATETIME_FORMAT = '%Y-%m-%d'

    @classmethod
    def parse (cls, file, lineno, line, line_timestamp_prefix=False) :
        """
            Parse out given line and build a ZoneLine.

                file                    - name of the source file, for messages
                lineno                  - line number within file, for messages
                line                    - raw line text
                line_timestamp_prefix   - expect a "<PARSE_DATETIME_FORMAT>: " prefix

            Raises ZoneError if the timestamp prefix is missing or invalid, or
            if the line has an unterminated quoted string.
        """

        log.debug("parse: %s:%d: %s", file, lineno, line)

        ts = None

        if line_timestamp_prefix :
            if ': ' not in line :
                raise ZoneError("%s:%d: Missing timestamp prefix: %s" % (file, lineno, line))

            # split prefix
            prefix, line = line.split(': ', 1)

            # parse it out
            try :
                ts = datetime.strptime(prefix, cls.PARSE_DATETIME_FORMAT)
            except ValueError as error :
                raise ZoneError("%s:%d: Invalid timestamp prefix %r: %s" % (file, lineno, prefix, error))

            log.debug("  ts=%r", ts)

        # was line indented?
        indent = line.startswith((' ', '\t'))

        try :
            data, parts, comment = cls._split_line(line.strip())
        except ValueError as error :
            raise ZoneError("%s:%d: %s: %s" % (file, lineno, error, line.strip()))

        log.debug("  indent=%r, line=%r, comment=%r, parts=%r", indent, data, comment, parts)

        # build
        return cls(file, lineno, indent, data, parts, timestamp=ts, comment=comment)

    @staticmethod
    def _split_line (line) :
        """
            Split a stripped line into (data, parts, comment).

            A ';' starts a comment only outside double quotes. A quoted string
            is kept as part of its field including the quotes, so it is written
            back out unchanged; backslash escapes inside quotes are kept as-is.

            Raises ValueError on an unterminated quoted string.
        """

        parts = []
        token = []
        quoted = False
        escaped = False
        comment = None
        end = len(line)

        for index, char in enumerate(line) :
            if quoted :
                token.append(char)

                if escaped :
                    escaped = False
                elif char == '\\' :
                    escaped = True
                elif char == '"' :
                    quoted = False

            elif char == '"' :
                token.append(char)
                quoted = True

            elif char == ';' :
                # rest of line is comment
                comment = line[index + 1:].strip()
                end = index
                break

            elif char.isspace() :
                if token :
                    parts.append(''.join(token))
                    token = []

            else :
                token.append(char)

        if quoted :
            raise ValueError("Unterminated quoted string")

        if token :
            parts.append(''.join(token))

        return line[:end].strip(), parts, comment

    def __init__ (self, file, lineno, indent, data, parts, timestamp=None, comment=None) :
        self.file = file
        self.lineno = lineno

        self.indent = indent
        self.data = data
        self.parts = parts

        self.timestamp = timestamp
        self.comment = comment

    def __str__ (self) :
        return "{file}:{lineno}".format(file=self.file, lineno=self.lineno)

class ZoneRecord (object) :
    """
        A record from a zonefile, or one generated from zone data.
    """

    # the underlying line
    line = None

    # possible $ORIGIN context
    origin = None

    # record fields; name is None for records that keep the previous record's name
    name = None
    type = None

    # list of data fields
    data = None

    # optional
    ttl = None
    cls = None

    @classmethod
    def parse (cls, line, parts=None, origin=None) :
        """
            Parse from ZoneLine. Returns None if there is no record on the line.

                line    - the ZoneLine to parse
                parts   - fields to parse instead of line.parts (used for
                          $GENERATE); the given list is not modified
                origin  - current $ORIGIN, if any

            Accepts "[ttl] [class]" in either order after the name.

            Raises ZoneLineError if the line has too few fields for a record.
        """

        # copy: fields are consumed below
        parts = list(line.parts if parts is None else parts)

        if not parts :
            # skip
            return None
        
        # indented lines keep name from previous record
        if line.indent :
            name = None

        else :
            name = parts.pop(0)
        
        log.debug("  name=%r, origin=%r", name, origin)

        if len(parts) < 2 :
            raise ZoneLineError(line, "Too few parts to parse: {0!r}", line.data)

        # parse optional ttl/class, in either order
        ttl = _cls = None

        if parts[0][:1].isdigit() :
            ttl = parts.pop(0)

        if parts and parts[0].upper() in ('IN', 'CH', 'HS') :
            _cls = parts.pop(0)

        if ttl is None and parts and parts[0][:1].isdigit() :
            ttl = parts.pop(0)

        if not parts :
            raise ZoneLineError(line, "Missing record type: {0!r}", line.data)

        # always have type
        type = parts.pop(0)

        # remaining parts are data
        data = parts

        log.debug("  ttl=%r, cls=%r, type=%r, data=%r", ttl, _cls, type, data)

        return cls(name, type, data,
            origin  = origin,
            ttl     = ttl,
            cls     = _cls,
            line    = line,
        )

    def __init__ (self, name, type, data, origin=None, ttl=None, cls=None, line=None, comment=None) :
        self.name = name
        self.type = type
        self.data = data
        
        self.ttl = ttl
        self.cls = cls
        
        self.origin = origin
        self.line = line

        # XXX: within line
        self._comment = comment

    def build_line (self) :
        """
            Construct a zonefile-format line, with fixed-width name/ttl/class/type
            columns and an optional trailing comment.
        """

        # XXX: comment?
        if self._comment :
            comment = u'\t; ' + self._comment
        else :
            comment = u''
            
        return u"{name:25} {ttl:4} {cls:2} {type:5} {data}{comment}".format(
                name    = self.name or u'',
                ttl     = self.ttl or u'',
                cls     = self.cls or u'',
                type    = self.type,
                # data may hold non-strings, e.g. the MX preference
                data    = u' '.join(u'%s' % (data, ) for data in self.data),
                comment = comment,
        )

    def __str__ (self) :
        # data may hold non-strings, e.g. the MX preference
        data = ' '.join('%s' % (item, ) for item in self.data)

        return ' '.join((self.name or '', self.type, data))

class TXTRecord (ZoneRecord) :
    """
        TXT record holding a single text string, quoted and escaped for
        zonefile output.
    """

    def __init__ (self, name, text, **opts) :
        # escape backslashes first, so the escapes added for quotes are not doubled
        escaped = text.replace('\\', '\\\\').replace('"', '\\"')

        super(TXTRecord, self).__init__(name, 'TXT', 
            [u'"{0}"'.format(escaped)], 
            **opts
        )

class OffsetValue (object) :
    """
        Indexable wrapper for str.format() templates: value[offset] gives
        value + offset.

        str.format() only converts all-digit indexes to int, so a negative
        offset such as "{value[-1]}" arrives as the string '-1'; it is
        converted with int() here.

        >>> '{0[2]} {0[-1]}'.format(OffsetValue(10))
        '12 9'
    """

    def __init__ (self, value) :
        self.value = value

    def __getitem__ (self, offset) :
        return self.value + int(offset)

def parse_generate_field (line, field) :
    r"""
        Parse a $GENERATE lhs/rhs field:
            $
            ${<offset>[,<width>[,<base>]]}
            \$
            $$

        The field is scanned left to right, so any mix of these forms expands
        in order of appearance. <base> is one of d, o, x, X.

        Returns a wrapper that builds the field-value when called with the index.

        Raises ZoneLineError (with line as context) on a malformed ${...}.
        
        >>> parse_generate_field(None, "foo")(1)
        'foo'
        >>> parse_generate_field(None, "foo-$")(1)
        'foo-1'
        >>> parse_generate_field(None, "foo-$$")(1)
        'foo-$'
        >>> parse_generate_field(None, "\\$")(1)
        '$'
        >>> parse_generate_field(None, "10.0.0.${100}")(1)
        '10.0.0.101'
        >>> parse_generate_field(None, "foo-${0,2,d}")(1)
        'foo-01'
        >>> parse_generate_field(None, "host-${-1,3,x}")(256)
        'host-0ff'
        >>> parse_generate_field(None, "a$-$$-$")(3)
        'a3-$-3'

    """

    # alternating literal strings and (offset, width, base) substitutions
    segments = []
    literal = []
    pos = 0

    while True :
        index = field.find('$', pos)

        if index < 0 :
            # no more substitutions
            literal.append(field[pos:])
            break

        pre = field[pos:index]

        if pre.endswith('\\') :
            # \$ -> literal $
            literal.append(pre[:-1] + '$')
            pos = index + 1
            continue

        literal.append(pre)

        if field.startswith('$$', index) :
            # $$ -> literal $
            literal.append('$')
            pos = index + 2
            continue

        # defaults
        offset, width, base = 0, 0, 'd'

        if field.startswith('${', index) :
            end = field.find('}', index + 2)

            if end < 0 :
                raise ZoneLineError(line, "unterminated ${{...}} in field: {0!r}", field)

            body = field[index + 2:end].split(',')

            if len(body) > 3 :
                raise ZoneLineError(line, "extra data in ${{...}} body: {0!r}", body[3:])

            try :
                offset = int(body[0])

                if len(body) > 1 :
                    width = int(body[1])

            except ValueError :
                raise ZoneLineError(line, "invalid ${{...}} body: {0!r}", field[index:end + 1])

            if len(body) > 2 :
                base = body[2]

            if width < 0 or base not in ('d', 'o', 'x', 'X') :
                raise ZoneLineError(line, "unsupported width/base in ${{...}} body: {0!r}", field[index:end + 1])

            pos = end + 1

        else :
            # bare $
            pos = index + 1

        segments.append(''.join(literal))
        segments.append((offset, width, base))
        literal = []

    # final
    segments.append(''.join(literal))

    log.debug("%s: %r", field, segments)

    def value_func (value) :
        # build the field value for one $GENERATE iteration value
        out = []

        for segment in segments :
            if isinstance(segment, tuple) :
                offset, width, base = segment
                spec = ('0%d' % width if width else '') + base

                out.append(format(value + offset, spec))

            else :
                out.append(segment)

        return ''.join(out)
    
    return value_func

def process_generate (line, origin, parts) :
    """
        Process a 
            $GENERATE <start>-<stop>[/<step>] lhs [ttl] [class] type rhs [comment]
        directive into a series of ZoneRecords, one per value of the inclusive
        range.

            line    - the ZoneLine holding the directive, for error context
            origin  - current $ORIGIN, passed on to each record
            parts   - directive fields following $GENERATE; not modified

        Raises ZoneLineError if the directive is malformed.
    """

    # range, lhs, type and rhs are mandatory
    if len(parts) < 4 :
        raise ZoneLineError(line, "Too few fields for $GENERATE: {0!r}", line.data)

    spec = parts[0]

    # parse range
    try :
        if '/' in spec :
            spec, step = spec.split('/', 1)
            step = int(step)
        else :
            step = 1

        start, stop = spec.split('-', 1)
        start = int(start)
        stop = int(stop)

    except ValueError :
        raise ZoneLineError(line, "Invalid $GENERATE range: {0!r}", parts[0])

    if step < 1 :
        raise ZoneLineError(line, "Invalid $GENERATE step: {0!r}", parts[0])

    log.debug("  range: start=%r, stop=%r, step=%r", start, stop, step)

    lhs_func = parse_generate_field(line, parts[1])
    rhs_func = parse_generate_field(line, parts[-1])
    body = list(parts[2:-1])

    # inclusive
    for i in xrange(start, stop + 1, step) :
        # build
        fields = [lhs_func(i)] + body + [rhs_func(i)]

        log.debug(" %03d: %r", i, fields)

        # parse
        yield ZoneRecord.parse(line, parts=fields, origin=origin)

def parse_zone_records (file, origin=None, **opts) :
    """
        Parse ZoneRecord items from the given zonefile, ignoring non-record lines.

            file    - open file: an iterable of lines with a .name for messages
            origin  - initial $ORIGIN; updated by $ORIGIN directives in the file
            opts    - passed on to ZoneLine.parse (e.g. line_timestamp_prefix)

        Yields ZoneRecords in file order. $GENERATE directives are expanded in
        place; multi-line statements and other control records are skipped
        with a warning.
    """

    skip_multiline = False
    
    # 1-based line numbers, as in editors and other zone tools
    for lineno, raw_line in enumerate(file, 1) :
        # look for parentheses outside the comment only
        text = raw_line.split(';', 1)[0]

        # XXX: handle multi-line statements...
        # start
        if '(' in text :
            skip_multiline = True
            
            log.warn("%s:%d: Start of multi-line statement: %s", file.name, lineno, raw_line)

        # end?
        if ')' in text :
            skip_multiline = False
            
            log.warn("%s:%d: End of multi-line statement: %s", file.name, lineno, raw_line)
            
            continue

        elif skip_multiline :
            log.warn("%s:%d: Multi-line statement: %s", file.name, lineno, raw_line)

            continue
        
        # parse
        line = ZoneLine.parse(file.name, lineno, raw_line, **opts)

        if not line.data :
            log.debug("%s: skip empty line: %s", line, raw_line)

            continue

        elif line.data.startswith('$') :
            # control record
            type = line.parts[0]

            if type == '$ORIGIN':
                if len(line.parts) < 2 :
                    raise ZoneLineError(line, "$ORIGIN requires a domain name: {0!r}", line.data)

                # update
                origin = line.parts[1]
                
                log.info("%s: origin: %s", line, origin)
            
            elif type == '$GENERATE':
                # process...
                log.info("%s: generate: %s", line, line.parts)

                for record in process_generate(line, origin, line.parts[1:]) :
                    yield record

            else :
                log.warning("%s: skip control record: %s", line, line.data)
            
            # XXX: passthrough!
            continue

        # normal record?
        record = ZoneRecord.parse(line, origin=origin)

        if record :
            yield record

        else :
            # unknown
            log.warning("%s: skip unknown line: %s", line, line.data)

def check_zone_hosts (zone, whitelist=None, whitelist_types=frozenset(['TXT'])) :
    """
        Parse host/IP pairs from the zone, and verify that they are unique.

        Each name (within its $ORIGIN) may have only one record whose type is
        not in whitelist_types, and each A/AAAA address may appear only once.
        As an exception, names listed in the given whitelist may have multiple
        records. Records without a name of their own (indented lines) belong to
        the preceding name.

        Duplicates are logged as warnings. Returns True if any were found,
        False otherwise.
    """

    whitelist = whitelist or ()

    by_name = {}
    by_ip = {}

    fail = False

    last_name = None

    for r in zone :
        # implicit name: keep the previous one
        name = r.name or last_name
        last_name = name

        key = (r.origin, name)

        # name
        if r.type not in whitelist_types :
            if key not in by_name :
                by_name[key] = r

            elif name in whitelist :
                log.debug("Duplicate whitelist entry: %s", r)

            else :
                # fail!
                log.warn("%s: Duplicate name: %s <-> %s", r.line, r, by_name[key])
                fail = True

        # ip
        if r.type in ('A', 'AAAA') :
            ip, = r.data

            if ip not in by_ip :
                by_ip[ip] = r

            else :
                # fail!
                log.warn("%s: Duplicate IP: %s <-> %s", r.line, r, by_ip[ip])
                fail = True

    return fail

def process_zone_forwards (zone, txt=False, mx=False) :
    """
        Process zone data -> forward zone data.

        Every record is passed through unchanged. After each A record, this
        optionally adds:

            txt - a TXT record holding the source line's comment, if any
            mx  - an MX record with preference 10 pointing at the given host

        Both added records have no owner name (written blank, like an indented
        line) and copy the A record's TTL.
    """

    for record in zone :
        yield record

        if record.type != 'A' :
            continue

        comment = record.line.comment if txt else None

        if comment :
            yield TXTRecord(None, comment, ttl=record.ttl)

        # XXX: RP, do we need it?

        if mx :
            # XXX: is this a good idea?
            yield ZoneRecord(None, 'MX', [10, mx], ttl=record.ttl)

def process_zone_meta (zone, ignore=None) :
    """
        Process zone metadata -> output.

        For each A record whose source line has a timestamp (from
        --input-line-date), yield a TXT record with that date under the host's
        name. Records without a name of their own (indented lines) take the
        preceding record's name, for both the ignore check and the TXT record,
        so the output never depends on which records were skipped.

            zone    - iterable of ZoneRecords
            ignore  - optional collection of names to skip
    """
    
    TIMESTAMP_FORMAT = '%Y/%m/%d'

    name = None
    
    for r in zone :
        # implicit name: keep the previous one
        if r.name :
            name = r.name

        if ignore and name in ignore :
            # skip
            log.debug("Ignore record: %s", r)
            continue

        # for hosts..
        if r.type == 'A' :
            # timestamp?
            timestamp = r.line.timestamp

            if timestamp :
                yield TXTRecord(name, timestamp.strftime(TIMESTAMP_FORMAT), ttl=r.ttl)
     
def reverse_ipv4 (ip) :
    """
        Return in-addr.arpa reverse for given IPv4 address, or prefix of 1-4
        dotted decimal octets.

        >>> reverse_ipv4('192.0.2.1')
        '1.2.0.192.in-addr.arpa'
        >>> reverse_ipv4('10.1')
        '1.10.in-addr.arpa'

        Raises ValueError for malformed input.
    """

    labels = ip.split('.')

    if not 1 <= len(labels) <= 4 :
        raise ValueError("Invalid IPv4 prefix: %r" % (ip, ))

    octets = []

    for label in labels :
        # int() alone would also accept ' 1', '+1' or '-0'
        if not label.isdigit() or int(label) > 255 :
            raise ValueError("Invalid IPv4 prefix: %r" % (ip, ))

        octets.append(str(int(label)))

    return '.'.join(list(reversed(octets)) + ['in-addr', 'arpa'])

def reverse_ipv6 (ip6) :
    """
        Return ip6.arpa reverse for given IPv6 address or prefix.

        A full address may use "::" zero-compression. A prefix is given as one
        or more leading 16-bit groups without "::" (e.g. "2001:db8"). Each
        group is zero-padded to four nibbles.

        >>> reverse_ipv6('2001:db8')
        '8.b.d.0.1.0.0.2.ip6.arpa'
        >>> reverse_ipv6('::1') == '1.' + '0.' * 31 + 'ip6.arpa'
        True

        Raises ValueError for malformed input.
    """

    if '::' in ip6 :
        head, tail = ip6.split('::', 1)

        if '::' in tail :
            raise ValueError("Invalid IPv6 address: %r" % (ip6, ))

        head = head.split(':') if head else []
        tail = tail.split(':') if tail else []

        # "::" stands for one or more all-zero groups
        missing = 8 - len(head) - len(tail)

        if missing < 1 :
            raise ValueError("Invalid IPv6 address: %r" % (ip6, ))

        groups = head + ['0'] * missing + tail

    else :
        groups = ip6.split(':')

    if len(groups) > 8 :
        raise ValueError("Invalid IPv6 address: %r" % (ip6, ))

    nibbles = []

    for group in groups :
        # int(x, 16) alone would also accept '0x1', '+1' or ' 1'
        if not 1 <= len(group) <= 4 or group.strip('0123456789abcdefABCDEF') :
            raise ValueError("Invalid IPv6 group %r in %r" % (group, ip6))

        nibbles.extend('{0:04x}'.format(int(group, 16)))

    return '.'.join(list(reversed(nibbles)) + ['ip6', 'arpa'])

def fqdn (*parts) :
    """
        Join the given names with '.', and add the trailing root dot unless
        the result already ends with one (the last part may already be fully
        qualified).
    """

    name = '.'.join(parts)

    return name if name.endswith('.') else name + '.'

def process_zone_reverse (zone, origin, domain) :
    """
        Process zone data -> reverse zone data.

        For each A/AAAA record whose reverse name lies within the reverse zone
        `origin` (an in-addr.arpa / ip6.arpa name), yield a PTR record named
        relative to `origin` that points at the host's fully-qualified name.
        Host names are qualified with the record's $ORIGIN, or `domain` if it
        has none. Records without a name of their own take the preceding
        record's name.
    """

    name = None

    for r in zone :
        # keep name from previous..
        if r.name :
            name = r.name

        if r.type == 'A' :
            ip, = r.data
            ptr = reverse_ipv4(ip)

        elif r.type == 'AAAA' :
            ip, = r.data
            ptr = reverse_ipv6(ip)
            
        else :
            continue

        # verify: must be within the zone, on a label boundary
        if ptr == origin :
            ptr = '@'

        elif ptr.endswith('.' + origin) :
            ptr = ptr[:-(len(origin) + 1)]

        else :
            log.warning("Reverse does not match zone origin, skipping: (%s) -> %s <-> %s", ip, ptr, origin)
            continue

        if name is None :
            log.warning("Record has no host name, skipping: %s", r)
            continue

        # domain to use
        host_domain = r.origin or domain

        if name == '@' :
            # the domain itself
            host_fqdn = fqdn(host_domain)

        elif name.endswith('.') :
            # already fully-qualified
            host_fqdn = name

        else :
            host_fqdn = fqdn(name, host_domain)

        yield ZoneRecord(ptr, 'PTR', [host_fqdn])

def write_zone_records (file, zone) :
    """
        Write each record of zone to file, one build_line() output per line.
    """

    for record in zone :
        text = record.build_line()

        file.write(text + u'\n')

def open_file (path, mode, charset) :
    """
        Open unicode-enabled file from path, with - using stdio.

            path    - filesystem path, or '-' for stdin (read) / stdout (write)
            mode    - file mode; for '-', must start with 'r' or 'w'
            charset - encoding used for the file contents

        Returns a codecs reader/writer. Raises ValueError for an unsupported
        mode on '-'.
    """

    # the module-level `import sys` only happens when run as a script
    import sys

    if path == '-' :
        # use stdin/out based on mode
        if mode.startswith('r') :
            return codecs.getreader(charset)(sys.stdin)

        elif mode.startswith('w') :
            return codecs.getwriter(charset)(sys.stdout)

        else :
            raise ValueError("Unsupported mode for stdio: %r" % (mode, ))

    # open
    return codecs.open(path, mode, charset)

def main (argv) :
    """
        Command-line entry point: parse options, read the input zonefiles, and
        run the requested check / processing stage.

        Returns the exit status:
            the failure count for --doctest,
            2 if --check-hosts found problems,
            1 on option errors or if there is nothing to do,
            0 otherwise.
    """

    global options
    
    options, args = parse_options(argv)

    if options.doctest :
        import doctest
        fail, total = doctest.testmod()
        return fail

    # Validate the stage options before reading input (which may block on
    # stdin) or opening --output (which truncates an existing file).
    origin = domain = None

    if options.forward_zone :
        pass

    elif options.meta_zone :
        if not options.input_line_date :
            log.error("--meta-zone requires --input-line-date")
            return 1

    elif options.reverse_zone :
        domain = options.reverse_domain

        if not domain :
            log.error("--reverse-zone requires --reverse-domain")
            return 1

        if ':' in options.reverse_zone :
            # IPv6
            origin = reverse_ipv6(options.reverse_zone)

        else :
            # IPv4
            origin = reverse_ipv4(options.reverse_zone)

    elif not options.check_hosts :
        log.warn("Nothing to do")
        return 1

    # process zone data; default to stdin
    zone = []

    for path in args or ['-'] :
        stream = open_file(path, 'r', options.input_charset)

        try :
            log.info("Reading zone: %s", path)

            zone += list(parse_zone_records(stream, 
                line_timestamp_prefix   = options.input_line_date,
            ))

        finally :
            # leave stdin open
            if path != '-' :
                stream.close()

    # check?
    if options.check_hosts :
        whitelist = set(options.check_exempt)

        log.debug("checking hosts; whitelist=%r", whitelist)

        if check_zone_hosts(zone, whitelist=whitelist) :
            log.warn("Hosts check failed")
            return 2

        else :
            log.info("Hosts check OK")

    if options.forward_zone :
        log.info("Write forward zone: %s", options.output)

        zone = list(process_zone_forwards(zone, txt=options.forward_txt, mx=options.forward_mx))

    elif options.meta_zone :
        log.info("Write metadata zone: %s", options.output)

        zone = list(process_zone_meta(zone, ignore=set(options.meta_ignore)))

    elif options.reverse_zone :
        zone = list(process_zone_reverse(zone, origin=origin, domain=domain))

    else :
        # only --check-hosts was requested, and it passed
        return 0

    # output file, opened only once there is something to write
    output = open_file(options.output, 'w', options.output_charset)

    try :
        write_zone_records(output, zone)

    finally :
        if options.output == '-' :
            output.flush()
        else :
            output.close()

    return 0

if __name__ == '__main__':
    import sys

    # main()'s return value becomes the process exit status
    exit_status = main(sys.argv)
    sys.exit(exit_status)