pvl/dns/zone.py
author Tero Marttila <terom@paivola.fi>
Wed, 11 Sep 2013 13:34:23 +0300
changeset 248 cce9cf4933ca
parent 247 08a63738f2d1
child 250 65f0272ce458
permissions -rw-r--r--
pvl.dns.zone: more multi-line support in the parser..
#!/usr/bin/env python

"""
    Process zonefiles.
"""

import codecs
from datetime import datetime
import logging

log = logging.getLogger('pvl.dns.zone')

class ZoneError (Exception) :
    """
        Base class for errors raised while processing zonefiles.
    """

class ZoneLineError (ZoneError) :
    """
        An error attributed to a particular zonefile line.

        The message is `msg` expanded with str.format(*args, **kwargs), prefixed
        with str(line) and ": ".
    """

    def __init__ (self, line, msg, *args, **kwargs) :
        text = msg.format(*args, **kwargs)
        ZoneError.__init__(self, "%s: %s" % (line, text))

class ZoneLine (object) :
    """
        One line (statement) read from a zonefile.

        Records where it came from (file name, line number, raw text), the
        split fields, and optional metadata (trailing comment, timestamp).
    """

    # source location
    file = None
    lineno = None

    # whether the raw line began with whitespace, and its split fields
    indent = None
    parts = None

    # optional metadata
    timestamp = None
    comment = None

    # strptime() format of per-line timestamp prefixes
    PARSE_DATETIME_FORMAT = '%Y-%m-%d'

    def __init__ (self, file, lineno, line, indent, parts, comment=None, timestamp=None) :
        # metadata
        self.comment, self.timestamp = comment, timestamp

        # parsed fields
        self.indent, self.parts = indent, parts

        # source location and raw text
        self.file, self.lineno, self.line = file, lineno, line

    def __str__ (self) :
        # "file:lineno", identifying the line in messages
        return "{file}:{lineno}".format(lineno=self.lineno, file=self.file)

class ZoneRecord (object) :
    """
        A resource record from a zonefile.

        Records are either parsed from a ZoneLine (see parse()) or built
        directly, e.g. via the TXT/PTR/MX helpers, and rendered back to
        zonefile syntax by __unicode__().
    """

    # the underlying line
    line = None

    # possible $ORIGIN context
    origin = None

    # record fields; name is None for an indented line (continues the previous record's name)
    name = None
    type = None

    # list of data fields
    data = None

    # optional
    ttl = None
    cls = None

    # optional comment, rendered after the data
    comment = None

    @classmethod
    def parse (cls, line, parts=None, origin=None) :
        """
            Parse a record from a ZoneLine.

                line    - the ZoneLine the record comes from
                parts   - fields to parse instead of line.parts; not modified
                origin  - current $ORIGIN, stored on the record

            Fields are: [name] [ttl] [class] type data...; the name is absent
            for indented lines, a ttl is recognized by a leading digit and a
            class by IN/CH.

            Returns None if there are no fields.
            Raises ZoneLineError if the fields are too few to form a record.
        """

        # work on a copy so the caller's list is left untouched
        parts = list(line.parts if parts is None else parts)

        if not parts :
            # skip
            return None

        # XXX: indented lines keep name from previous record
        if line.indent :
            name = None

        else :
            name = parts.pop(0)

        log.debug("  name=%r, origin=%r", name, origin)

        if len(parts) < 2 :
            raise ZoneLineError(line, "Too few parts to parse: {0!r}", line.line)

        # parse ttl/cls/type
        ttl = _cls = None

        # [:1] rather than [0], so an empty (quoted "") field cannot raise IndexError
        if parts and parts[0][:1].isdigit() :
            ttl = parts.pop(0)

        if parts and parts[0].upper() in ('IN', 'CH') :
            _cls = parts.pop(0)

        if not parts :
            raise ZoneLineError(line, "Missing record type: {0!r}", line.line)

        type = parts.pop(0)

        # remaining parts are data
        data = parts

        log.debug("  ttl=%r, cls=%r, type=%r, data=%r", ttl, _cls, type, data)

        return cls(name, type, data,
            origin  = origin,
            ttl     = ttl,
            cls     = _cls,
            line    = line,
        )

    @classmethod
    def TXT (cls, name, text, **opts) :
        """
            Build a TXT record holding `text` as one double-quoted string,
            with embedded double quotes backslash-escaped.
        """

        return cls(name, 'TXT',
            [u'"{0}"'.format(text.replace('"', '\\"'))],
            **opts
        )

    @classmethod
    def PTR (cls, name, ptr, **opts) :
        """
            Build a PTR record pointing to `ptr`.
        """

        return cls(name, 'PTR', [ptr], **opts)

    @classmethod
    def MX (cls, name, priority, mx, **opts) :
        """
            Build an MX record with the given priority and exchange host.
        """

        return cls(name, 'MX', [priority, mx], **opts)

    def __init__ (self, name, type, data, origin=None, ttl=None, cls=None, line=None, comment=None) :
        self.name = name
        self.type = type
        self.data = data

        self.ttl = ttl
        self.cls = cls

        self.origin = origin
        self.line = line

        self.comment = comment

    def __unicode__ (self) :
        """
            Construct a zonefile-format line, with the comment (if any)
            appended after a tab and "; ".
        """

        if self.comment :
            comment = '\t; ' + self.comment
        else :
            comment = ''

        return u"{name:25} {ttl:4} {cls:2} {type:5} {data}{comment}".format(
                name    = self.name or '',
                ttl     = self.ttl or '',
                cls     = self.cls or '',
                type    = self.type,
                data    = ' '.join(unicode(data) for data in self.data),
                comment = comment,
        )

    def __repr__ (self) :
        return '%s(%s)' % (self.__class__.__name__, ', '.join(repr(arg) for arg in (
            self.name, self.type, self.data
        )))

class OffsetValue (object) :
    """
        Index-by-offset helper for $GENERATE templates: indexing with an
        offset returns base + offset, so a str.format() field such as
        "{value[5]}" renders the wrapped base plus 5.

        >>> OffsetValue(0)[1]
        1
        >>> OffsetValue(10)[5]
        15
    """

    def __init__ (self, base) :
        self.base = base

    def __getitem__ (self, offset) :
        shifted = self.base + offset
        return shifted

def parse_generate_field (line, field) :
    r"""
        Parse a $GENERATE lhs/rhs field:
            $
            ${<offset>[,<width>[,<base>]]}
            \$
            $$

        Returns a wrapper that builds the field-value when called with the index.

        The field is scanned left to right, so the forms may be freely mixed;
        all other text, including { and }, is copied literally. <offset> may be
        negative; <base> must be one of d, o, x, X (other bases, such as BIND's
        n/N nibble formats, are rejected).

        Raises ZoneLineError for an unterminated ${, an unsupported base, or
        extra values in a ${...} body; ValueError for a non-integer offset or
        width.

        >>> parse_generate_field(None, "foo")(1)
        'foo'
        >>> parse_generate_field(None, "foo-$")(1)
        'foo-1'
        >>> parse_generate_field(None, "foo-$$")(1)
        'foo-$'
        >>> parse_generate_field(None, "\$")(1)
        '$'
        >>> parse_generate_field(None, "10.0.0.${100}")(1)
        '10.0.0.101'
        >>> parse_generate_field(None, "foo-${0,2,d}")(1)
        'foo-01'
        >>> parse_generate_field(None, "$$-${1}")(1)
        '$-2'
        >>> parse_generate_field(None, "host-${-1,3,d}")(5)
        'host-004'
        >>> parse_generate_field(None, "a{b}-$")(1)
        'a{b}-1'

    """

    input = field
    expr = []       # pieces of the str.format() template
    offsets = []    # offset of each substitution, by positional field number

    def literal (text) :
        # copy plain text into the template, escaping str.format() braces
        return text.replace('{', '{{').replace('}', '}}')

    while field :
        dollar = field.find('$')

        if dollar < 0 :
            # no more substitutions; the rest is literal
            break

        if dollar > 0 and field[dollar - 1] == '\\' :
            # \$ -> literal $
            expr.append(literal(field[:dollar - 1]))
            expr.append('$')
            field = field[dollar + 1:]
            continue

        pre, rest = field[:dollar], field[dollar + 1:]
        expr.append(literal(pre))

        if rest.startswith('$') :
            # $$ -> literal $
            expr.append('$')
            field = rest[1:]
            continue

        # defaults for a bare $
        offset = 0
        width = 0
        base = 'd'

        if rest.startswith('{') :
            if '}' not in rest :
                raise ZoneLineError(line, "unterminated ${{...}} in field: {0!r}", input)

            body, post = rest[1:].split('}', 1)

            # parse body
            parts = body.split(',')

            # offset
            offset = int(parts.pop(0))

            # width
            if parts :
                width = int(parts.pop(0))

            # base; empty means the default
            if parts :
                base = parts.pop(0) or 'd'

            if parts :
                # fail
                raise ZoneLineError(line, "extra data in ${{...}} body: {0!r}", parts)

            if base not in ('d', 'o', 'x', 'X') :
                raise ZoneLineError(line, "unsupported base in ${{...}} body: {0!r}", base)

        else :
            # bare $
            post = rest

        # positional format field for this substitution
        fmt = '{%d:0%d%s}' % (len(offsets), width, base)
        offsets.append(offset)

        log.debug("field=%r -> pre=%r, fmt=%r, post=%r", field, pre, fmt, post)

        expr.append(fmt)
        field = post

    # final
    if field :
        expr.append(literal(field))

    # combine
    expr = ''.join(expr)

    log.debug("%s: %s", input, expr)

    # processed
    def value_func (value) :
        # each substitution renders the index plus its own offset
        return expr.format(*[value + delta for delta in offsets])

    return value_func

def process_generate (line, origin, parts) :
    """
        Process a
            $GENERATE <start>-<stop>[/<step>] lhs [ttl] [class] type rhs
        directive into a series of ZoneRecords (this is a generator).

            line    - ZoneLine of the directive, used for error reporting
            origin  - current $ORIGIN, passed on to each record
            parts   - the directive's fields following $GENERATE; not modified

        <stop> is inclusive. lhs and rhs are expanded per index by
        parse_generate_field(); the fields between them are copied into every
        record unchanged.

        Raises ZoneLineError (on first iteration) for too few fields, a
        malformed range, or a step below 1.
    """

    if len(parts) < 4 :
        raise ZoneLineError(line, "$GENERATE needs a range, lhs, type and rhs: {0!r}", parts)

    spec = parts[0]
    lhs = parts[1]
    body = list(parts[2:-1])
    rhs = parts[-1]

    # parse range
    try :
        if '/' in spec :
            span, step = spec.split('/', 1)
            step = int(step)
        else :
            span, step = spec, 1

        start, stop = span.split('-', 1)
        start = int(start)
        stop = int(stop)

    except ValueError :
        raise ZoneLineError(line, "invalid $GENERATE range: {0!r}", spec)

    if step < 1 :
        raise ZoneLineError(line, "invalid $GENERATE step: {0!r}", spec)

    log.debug("  range: start=%r, stop=%r, step=%r", start, stop, step)

    lhs_func = parse_generate_field(line, lhs)
    rhs_func = parse_generate_field(line, rhs)

    # inclusive
    for i in xrange(start, stop + 1, step) :
        # build
        record_parts = [lhs_func(i)] + body + [rhs_func(i)]

        log.debug(" %03d: %r", i, record_parts)

        # parse
        yield ZoneRecord.parse(line, parts=record_parts, origin=origin)

# markers for unquoted ( and ) in the output of _tokenize_zone_line()
_PAREN_OPEN = object()
_PAREN_CLOSE = object()

def _tokenize_zone_line (text) :
    """
        Split the text of one physical zonefile line into (tokens, comment).

        - whitespace separates tokens
        - a double-quoted string is one token with the quotes removed; it may
          contain whitespace, ; ( and )
        - an unquoted ( or ) becomes a _PAREN_OPEN / _PAREN_CLOSE marker
        - an unquoted ; starts the comment, returned stripped (None if absent)
        - a backslash makes the next character literal; both characters are
          kept in the token

        Raises ValueError for an unterminated quoted string.
    """

    tokens = []
    word = []           # characters of the unquoted token being collected
    quoted = None       # characters of the quoted string being collected; None outside quotes
    escaped = False     # previous character was a backslash
    comment = None

    def flush () :
        # finish the pending unquoted token, if any
        if word :
            tokens.append(''.join(word))
            del word[:]

    for index, char in enumerate(text) :
        current = word if quoted is None else quoted

        if escaped :
            current.append(char)
            escaped = False

        elif char == '\\' :
            current.append(char)
            escaped = True

        elif quoted is not None :
            if char == '"' :
                tokens.append(''.join(quoted))
                quoted = None
            else :
                quoted.append(char)

        elif char == ';' :
            comment = text[index + 1:].strip()
            break

        elif char == '"' :
            flush()
            quoted = []

        elif char in '()' :
            flush()
            tokens.append(_PAREN_OPEN if char == '(' else _PAREN_CLOSE)

        elif char.isspace() :
            flush()

        else :
            word.append(char)

    if quoted is not None :
        raise ValueError("Unterminated quoted string")

    flush()

    return tokens, comment

def parse_zone_lines (file, line_timestamp_prefix=None) :
    """
        Parse ZoneLines from a file.

        Yields one ZoneLine per statement, including blank and comment-only
        lines (with empty parts). A statement wrapped in ( ... ) across several
        physical lines is joined into a single ZoneLine carrying the lineno,
        timestamp, indent and comment of its first line and the concatenated
        raw text of all its lines; comments on continuation lines are dropped.

            file                    - iterable of lines with a .name attribute
            line_timestamp_prefix   - if true, every line starts with a
                                      ZoneLine.PARSE_DATETIME_FORMAT timestamp
                                      followed by ": "

        Line numbers are 1-based.

        Raises ZoneError for a missing timestamp prefix, an unterminated quoted
        string, unbalanced or nested parentheses, or a ( ... ) left open at EOF.
    """

    # state of the statement being assembled
    in_parens = False
    start = None    # (lineno, timestamp, indent, comment) of its first physical line
    text = None     # concatenated raw lines
    parts = None    # collected fields

    for lineno, raw_line in enumerate(file, 1) :
        where = "%s:%d" % (file.name, lineno)

        # possible mtime prefix for line
        timestamp = None

        if line_timestamp_prefix :
            if ': ' not in raw_line :
                raise ZoneError("%s: Missing timestamp prefix: %s" % (where, raw_line))

            # split prefix
            prefix, raw_line = raw_line.split(': ', 1)

            # parse it out
            timestamp = datetime.strptime(prefix, ZoneLine.PARSE_DATETIME_FORMAT)

            log.debug("%s: ts=%r", where, timestamp)

        log.debug("%s: %s", where, raw_line)

        # capture indent from raw line
        indent = raw_line.startswith((' ', '\t'))

        try :
            tokens, comment = _tokenize_zone_line(raw_line.strip())
        except ValueError as error :
            raise ZoneError("%s: %s: %s" % (where, error, raw_line))

        log.debug("%s: indent=%r, comment=%r", where, indent, comment)

        if not in_parens :
            # this physical line starts a new statement
            start = (lineno, timestamp, indent, comment)
            text = raw_line
            parts = []
        else :
            text += raw_line

        for token in tokens :
            if token is _PAREN_OPEN :
                if in_parens :
                    raise ZoneError("%s: Nested parentheses: %s" % (where, raw_line))

                log.debug("%s: Start of multi-line statement", where)
                in_parens = True

            elif token is _PAREN_CLOSE :
                if not in_parens :
                    raise ZoneError("%s: Unbalanced closing parenthesis: %s" % (where, raw_line))

                log.debug("%s: End of multi-line statement", where)
                in_parens = False

            else :
                parts.append(token)

        if in_parens :
            # statement continues on the following line(s)
            continue

        first_lineno, first_timestamp, first_indent, first_comment = start

        yield ZoneLine(file.name, first_lineno, text, first_indent, parts, first_comment, timestamp=first_timestamp)

    if in_parens :
        raise ZoneError("%s:%d: Unterminated multi-line statement" % (file.name, start[0]))

def parse_zone_records (file, origin=None, **opts) :
    """
        Parse ZoneRecord items from the given zonefile, ignoring non-record lines.

            file    - zonefile, passed to parse_zone_lines()
            origin  - initial $ORIGIN; replaced by any $ORIGIN directive
            opts    - passed through to parse_zone_lines()

        $ORIGIN and $GENERATE directives are processed; any other $-directive
        is logged and skipped.

        Raises ZoneLineError for an $ORIGIN directive without a name.
    """

    for line in parse_zone_lines(file, **opts):
        if not line.parts :
            log.debug("%s: skip empty line", line)

        elif line.line.startswith('$') :
            # control record
            directive = line.parts[0]

            if directive == '$ORIGIN':
                if len(line.parts) < 2 :
                    raise ZoneLineError(line, "$ORIGIN without a domain name")

                # update
                origin = line.parts[1]

                log.info("%s: origin: %s", line, origin)

            elif directive == '$GENERATE':
                # process...
                log.info("%s: generate: %s", line, line.parts)

                for record in process_generate(line, origin, line.parts[1:]) :
                    yield record

            else :
                log.warning("%s: skip control record: %s", line, line.line)

        else :
            # normal record?
            record = ZoneRecord.parse(line, origin=origin)

            if record :
                yield record

            else :
                # unknown
                log.warning("%s: skip unknown line: %s", line, line.line)
    
def reverse_ipv4 (ip) :
    """
        Return in-addr.arpa reverse for given IPv4 prefix.

        The prefix has one to four dotted-decimal octets:

        >>> reverse_ipv4('192.0.2')
        '2.0.192.in-addr.arpa'

        Raises ValueError if there are more than four octets or an octet is
        not a decimal integer in 0..255. (Explicit checks rather than assert,
        which is stripped under python -O.)
    """

    # parse
    octets = [int(part) for part in ip.split('.')]

    if len(octets) > 4 :
        raise ValueError("too many octets in IPv4 prefix {0!r}".format(ip))

    for octet in octets :
        if not 0 <= octet <= 255 :
            raise ValueError("invalid IPv4 octet {0!r} in {1!r}".format(octet, ip))

    return '.'.join([str(octet) for octet in reversed(octets)] + ['in-addr', 'arpa'])

def reverse_ipv6 (ip6) :
    """
        Return ip6.arpa reverse for given IPv6 prefix.

        Each colon-separated group (at most four hex digits) is zero-padded to
        four nibbles. Without "::" the prefix may have one to eight groups; a
        single "::" expands to as many zero groups as are needed to make eight.

        >>> reverse_ipv6('2001:db8')
        '8.b.d.0.1.0.0.2.ip6.arpa'

        Raises ValueError for malformed input.
    """

    if '::' in ip6 :
        head, _, tail = ip6.partition('::')

        if '::' in tail :
            raise ValueError("multiple '::' in {0!r}".format(ip6))

        head = head.split(':') if head else []
        tail = tail.split(':') if tail else []

        # '::' stands for one or more zero groups
        missing = 8 - len(head) - len(tail)

        if missing < 1 :
            raise ValueError("too many groups around '::' in {0!r}".format(ip6))

        groups = head + ['0'] * missing + tail

    else :
        groups = ip6.split(':')

        if len(groups) > 8 :
            raise ValueError("too many groups in {0!r}".format(ip6))

    nibbles = []

    for group in groups :
        # int() rejects empty and non-hex groups
        value = int(group, 16)

        if len(group) > 4 or not 0 <= value <= 0xffff :
            raise ValueError("invalid IPv6 group {0!r} in {1!r}".format(group, ip6))

        nibbles.append('{0:04x}'.format(value))

    return '.'.join(tuple(reversed(''.join(nibbles))) + ('ip6', 'arpa'))

def fqdn (*parts) :
    """
        Join the given name parts with '.' into an absolute domain name,
        appending a trailing '.' unless the joined name already ends in one.
    """

    name = '.'.join(parts)

    # the last part may itself already be fully qualified
    return name if name.endswith('.') else name + '.'