pvl/dns/zone.py
author Tero Marttila <terom@paivola.fi>
Wed, 11 Sep 2013 14:06:16 +0300
changeset 251 d250f200dd7e
parent 250 65f0272ce458
child 252 0ea4450fdd40
permissions -rw-r--r--
fix process_generate for ZoneRecord.build
#!/usr/bin/env python

"""
    Process zonefiles.
"""

import codecs
from datetime import datetime
import logging

# module-level logger used by the zonefile lexing/parsing code in this module
log = logging.getLogger('pvl.dns.zone')

class ZoneError (Exception) :
    """
        Base class for errors raised while processing zonefiles.
    """

class ZoneLineError (ZoneError) :
    """
        A ZoneError tied to a specific line.

        The exception message is "<line>: <msg>", where msg is first expanded with str.format() using the remaining
        positional and keyword arguments.
    """

    def __init__ (self, line, msg, *args, **kwargs) :
        # expand the message template, then prefix it with the line's string form
        detail = msg.format(*args, **kwargs)
        message = "%s: %s" % (line, detail)

        super(ZoneLineError, self).__init__(message)

class ZoneLine (object) :
    """
        A line parsed from a zonefile.

        Holds the source location (file, lineno, raw line text), the lexed data (indent flag, list of fields)
        and the optional per-line metadata (timestamp, comment).
    """

    @classmethod
    def load (cls, file, filename=None, line_timestamp_prefix=None) :
        """
            Yield ZoneLines lexed from a file.

                file                    - iterable yielding the raw text lines
                filename                - name used in messages and stored on each ZoneLine; defaults to file.name
                line_timestamp_prefix   - if true, every line must start with a "<date>: " prefix in
                                          PARSE_DATETIME_FORMAT, which is parsed into ZoneLine.timestamp

            A statement wrapped in ( ... ) over several lines is merged into a single ZoneLine, which takes its
            lineno/timestamp/indent/comment from the first line, the concatenated raw text as its line, and the
            fields of all lines (without the parentheses) as its parts.

            Line numbers are 0-based, as counted by enumerate().

            Raises ZoneError on a missing timestamp prefix, unbalanced quotes, or mismatched parentheses.
        """
        
        if not filename :
            filename = file.name
        
        # state of the ( ... ) statement currently being collected, if any
        multiline_start = None
        multiline_line = None
        multiline_parts = None
        
        for lineno, raw_line in enumerate(file) :
            # possible mtime prefix for line
            timestamp = None

            if line_timestamp_prefix :
                if ': ' not in raw_line :
                    raise ZoneError("%s:%d: Missing timestamp prefix: %s" % (filename, lineno, raw_line))

                # split prefix
                prefix, raw_line = raw_line.split(': ', 1)

                # parse it out
                timestamp = datetime.strptime(prefix, cls.PARSE_DATETIME_FORMAT)

                log.debug("%s:%d: ts=%r", filename, lineno, timestamp)
            
            log.debug("%s:%d: %s", filename, lineno, raw_line)
            
            # capture indent from raw line
            indent = raw_line.startswith((' ', '\t'))

            # parse comment, leaving any ; inside quoted data alone
            line, comment = cls._split_comment(raw_line.strip())
           
            log.debug("%s:%d: indent=%r, line=%r, comment=%r", filename, lineno, indent, line, comment)

            # split (quoted) fields
            parts = cls._split_fields(line)

            if parts is None :
                raise ZoneError("%s:%d: Unbalanced quotes: %s" % (filename, lineno, line))

            # handle multi-line statements...
            if '(' in parts :
                if multiline_start is not None :
                    raise ZoneError("%s:%d: Nested multi-line statement: %s" % (filename, lineno, line))

                log.warning("%s:%d: Start of multi-line statement: %s", filename, lineno, line)

                multiline_start = (lineno, timestamp, indent, comment)

                # start empty: the opening line itself is appended just below, exactly once
                multiline_line = ''
                multiline_parts = []

            if multiline_start is not None :
                log.warning("%s:%d: Multi-line statement: %s", filename, lineno, line)
                
                multiline_parts.extend(part for part in parts if part not in ('(', ')'))
                multiline_line += raw_line

            if ')' in parts :
                if multiline_start is None :
                    raise ZoneError("%s:%d: End of multi-line statement without start: %s" % (filename, lineno, line))

                log.warning("%s:%d: End of multi-line statement: %s", filename, lineno, line)
                
                # the merged statement is reported at the position of its first line
                lineno, timestamp, indent, comment = multiline_start
                raw_line = multiline_line
                parts = multiline_parts

                multiline_start = multiline_line = multiline_parts = None
        
            # lines inside an unfinished ( ... ) are held back until the closing line
            if multiline_start is None :
                yield ZoneLine(filename, lineno, raw_line, indent, parts, comment, timestamp=timestamp)

        if multiline_start is not None :
            # input ended inside ( ... ); the collected lines are not yielded
            log.warning("%s:%d: Unterminated multi-line statement", filename, multiline_start[0])

    @staticmethod
    def _split_comment (line) :
        """
            Split a stripped line at the first ; that is not inside double quotes.

            Returns (line, comment) with both stripped, or (line, None) if there is no comment.
        """

        in_quotes = False

        for index, char in enumerate(line) :
            if char == '"' :
                in_quotes = not in_quotes

            elif char == ';' and not in_quotes :
                return line[:index].strip(), line[index + 1:].strip()

        return line, None

    @staticmethod
    def _split_fields (line) :
        """
            Split a line into whitespace-separated fields, keeping each double-quoted string as one field
            (without the quotes).

            Returns None if the quotes are unbalanced.
        """

        pieces = line.split('"')

        # N quotes give N+1 pieces, so balanced quotes always give an odd count
        if len(pieces) % 2 == 0 :
            return None

        parts = []

        for index, piece in enumerate(pieces) :
            if index % 2 :
                # inside quotes: keep verbatim
                parts.append(piece)

            else :
                parts.extend(piece.split())

        return parts

    file = None
    lineno = None

    # data
    indent = None # was the line indented?
    parts = None # split line fields

    # optional
    timestamp = None
    comment = None

    PARSE_DATETIME_FORMAT = '%Y-%m-%d'
    
    def __init__ (self, file, lineno, line, indent, parts, comment=None, timestamp=None) :
        """
            file        - name of the source file
            lineno      - line number within the source
            line        - raw line text
            indent      - was the line indented?
            parts       - list of split line fields
            comment     - optional comment text
            timestamp   - optional datetime parsed from the line's prefix
        """

        # source
        self.file = file
        self.lineno = lineno
        self.line = line
        
        # parse data
        self.indent = indent
        self.parts = parts
        
        # metadata
        self.timestamp = timestamp
        self.comment = comment

    def __str__ (self) :
        """
            The file:lineno location of this line, as used in messages.
        """

        return "{file}:{lineno}".format(file=self.file, lineno=self.lineno)

class ZoneRecord (object) :
    """
        A record from a zonefile.

        Built either from a parsed ZoneLine (load/build) or directly from field values (constructor, TXT/PTR/MX).
    """
    
    # context
    line = None # the underlying line
    origin = None # possible $ORIGIN context

    # record fields
    name = None
    ttl = None  # optional
    cls = None  # optional
    type = None
    data = None # list of data fields
    
    @classmethod
    def load (cls, file, origin=None, **opts) :
        """
            Parse ZoneRecord items from the given zonefile, ignoring non-record lines.

                file        - passed to ZoneLine.load()
                origin      - initial $ORIGIN context; updated by $ORIGIN directives in the file
                **opts      - passed to ZoneLine.load()

            $GENERATE directives are expanded into their records; other control records are skipped with a
            warning.

            Raises ZoneLineError for lines that cannot be parsed.
        """

        for line in ZoneLine.load(file, **opts):
            if not line.parts :
                log.debug("%s: skip empty line", line)

            elif line.line.startswith('$') :
                # control record
                directive = line.parts[0]

                if directive == '$ORIGIN':
                    if len(line.parts) < 2 :
                        raise ZoneLineError(line, "Missing $ORIGIN argument")

                    # update
                    origin = line.parts[1]
                    
                    log.info("%s: origin: %s", line, origin)
                
                elif directive == '$GENERATE':
                    # process...
                    log.info("%s: generate: %s", line, line.parts)

                    for record in process_generate(line, origin, line.parts[1:]) :
                        yield record

                else :
                    log.warning("%s: skip control record: %s", line, line.line)
                
            else :
                # normal record?
                record = ZoneRecord.build(line, origin=origin)

                if record :
                    yield record

                else :
                    # unknown
                    log.warning("%s: skip unknown line: %s", line, line.line)
     
    @classmethod
    def build (cls, line, parts=None, origin=None) :
        """
            Build a ZoneRecord from a ZoneLine.

                line        - the ZoneLine; its indent flag decides whether the first field is a name
                parts       - fields to parse instead of line.parts (used for $GENERATE expansion)
                origin      - $ORIGIN context to store on the record

            Fields are parsed as [name] [ttl] [class] type data...; a leading field starting with a digit is
            taken as the ttl, and IN/CH as the class.

            Returns None if there are no fields.
            Raises ZoneLineError if there are too few fields to make up a record.
        """

        if parts is None :
            # copy, so that parsing does not consume the line's own fields
            parts = list(line.parts)

        if not parts :
            # skip
            return
        
        # XXX: indented lines keep name from previous record
        if line.indent :
            name = None

        else :
            name = parts.pop(0)
        
        log.debug("  name=%r, origin=%r", name, origin)

        if len(parts) < 2 :
            raise ZoneLineError(line, "Too few parts to parse: {0!r}", parts)

        # parse ttl/cls/type
        ttl = _cls = None

        # slice rather than index, so an empty (quoted) field does not raise
        if parts and parts[0][:1].isdigit() :
            ttl = parts.pop(0)

        if parts and parts[0].upper() in ('IN', 'CH') :
            _cls = parts.pop(0)

        # always have type
        if not parts :
            raise ZoneLineError(line, "Missing record type after ttl/class: ttl={0!r}, cls={1!r}", ttl, _cls)

        type = parts.pop(0)

        # remaining parts are data
        data = parts

        log.debug("  ttl=%r, cls=%r, type=%r, data=%r", ttl, _cls, type, data)

        return cls(name, type, data,
            origin  = origin,
            ttl     = ttl,
            cls     = _cls,
            line    = line,
        )

    @classmethod
    def TXT (cls, name, text, **opts) :
        """
            Build a TXT record; the text is wrapped in double quotes, with embedded quotes backslash-escaped.
        """

        return cls(name, 'TXT',
            [u'"{0}"'.format(text.replace('"', '\\"'))], 
            **opts
        )

    @classmethod
    def PTR (cls, name, ptr, **opts) :
        """
            Build a PTR record pointing at the given name.
        """

        return cls(name, 'PTR', [ptr], **opts)

    @classmethod
    def MX (cls, name, priority, mx, **opts) :
        """
            Build an MX record with the given priority and mail exchanger.
        """

        return cls(name, 'MX', [priority, mx], **opts)

    def __init__ (self, name, type, data, origin=None, ttl=None, cls=None, line=None, comment=None) :
        """
            name        - record name, or None
            type        - record type
            data        - list of data fields
            origin      - optional $ORIGIN context
            ttl         - optional ttl
            cls         - optional class
            line        - optional underlying ZoneLine
            comment     - optional comment, appended to the formatted line
        """

        self.name = name
        self.type = type
        self.data = data
        
        self.ttl = ttl
        self.cls = cls
        
        self.origin = origin
        self.line = line

        self.comment = comment

    def __unicode__ (self) :
        """
            Construct a zonefile-format line, with the comment (if any) appended after a tab.
        """

        if self.comment :
            comment = '\t; ' + self.comment
        else :
            comment = ''
            
        return u"{name:25} {ttl:4} {cls:2} {type:5} {data}{comment}".format(
                name    = self.name or '',
                ttl     = self.ttl or '',
                cls     = self.cls or '',
                type    = self.type,
                data    = ' '.join(unicode(data) for data in self.data),
                comment = comment,
        )

    def __repr__ (self) :
        """
            ZoneRecord(name, type, data)
        """

        return '%s(%s)' % (self.__class__.__name__, ', '.join(repr(arg) for arg in (
            self.name, self.type, self.data
        )))

class OffsetValue (object) :
    """
        Wrapper that turns subscripting into addition, used for $GENERATE offsets:
        indexing with an offset gives the wrapped base value plus that offset.

        >>> OffsetValue(0)[1]
        1
        >>> OffsetValue(10)[5]
        15
    """

    def __init__ (self, base) :
        # the value that offsets are added to
        self.base = base

    def __getitem__ (self, offset) :
        shifted = self.base + offset

        return shifted

def parse_generate_field (line, field) :
    r"""
        Parse a $GENERATE lhs/rhs field:
            $
            ${<offset>[,<width>[,<base>]]}
            \$
            $$

        The field is scanned left to right: $ is replaced by the iteration value, ${...} by the value plus
        offset, zero-padded to width and formatted using base as the format type (default: d), and both \$ and $$
        give a literal $. All other text, including any { or }, is copied as-is.

            line    - the ZoneLine being processed, used for error messages
            field   - the field text

        Returns a wrapper that builds the field-value when called with the index.

        Raises ZoneLineError for a malformed ${...}.
        
        >>> parse_generate_field(None, "foo")(1)
        'foo'
        >>> parse_generate_field(None, "foo-$")(1)
        'foo-1'
        >>> parse_generate_field(None, "foo-$$")(1)
        'foo-$'
        >>> parse_generate_field(None, r"\$")(1)
        '$'
        >>> parse_generate_field(None, "10.0.0.${100}")(1)
        '10.0.0.101'
        >>> parse_generate_field(None, "foo-${0,2,d}")(1)
        'foo-01'
        >>> parse_generate_field(None, "$-${1}")(1)
        '1-2'
        >>> parse_generate_field(None, "host${-1}")(5)
        'host4'

    """

    # literal text (strings) interleaved with substitutions, as (offset, format-spec) tuples
    segments = []
    rest = field

    while '$' in rest :
        pre, rest = rest.split('$', 1)

        if pre.endswith('\\') :
            # \$ escape: drop the backslash, keep the $
            segments.append(pre[:-1] + '$')
            continue

        if pre :
            segments.append(pre)

        if rest.startswith('$') :
            # $$ escape
            segments.append('$')
            rest = rest[1:]
            continue

        # defaults
        offset = 0
        width = 0
        base = 'd'

        if rest.startswith('{') :
            if '}' not in rest :
                # braces are doubled, as ZoneLineError runs the message through str.format()
                raise ZoneLineError(line, "unterminated ${{...}} in field: {0!r}", field)

            body, rest = rest[1:].split('}', 1)

            # parse body
            params = body.split(',')

            try :
                # offset
                offset = int(params.pop(0))

                # width
                if params :
                    width = int(params.pop(0))

            except ValueError :
                raise ZoneLineError(line, "invalid number in ${{...}} body: {0!r}", body)

            # base
            if params :
                base = params.pop(0)
            
            if params :
                # fail
                raise ZoneLineError(line, "extra data in ${{...}} body: {0!r}", params)

        # zero-padded to width, in the given base
        segments.append((offset, '0%d%s' % (width, base)))

    # final
    if rest :
        segments.append(rest)

    log.debug("%s: %r", field, segments)

    # processed
    def value_func (value) :
        # the offset is applied by plain addition, so negative offsets work as well
        return ''.join(
            format(value + segment[0], segment[1]) if isinstance(segment, tuple) else segment
            for segment in segments
        )
    
    return value_func

def process_generate (line, origin, parts) :
    """
        Process a 
            $GENERATE <start>-<stop>[/<step>] lhs [ttl] [class] type rhs [comment]
        directive into a series of ZoneRecords.

            line        - the ZoneLine holding the directive
            origin      - $ORIGIN context for the generated records
            parts       - the directive's fields, without the leading $GENERATE; not modified

        Yields one record per value in the inclusive start-stop range, built by ZoneRecord.build() from the
        expanded lhs, the fixed fields in between, and the expanded rhs.

        Raises ZoneLineError for a malformed directive.
    """

    # work on a copy, so the caller's list is left alone
    parts = list(parts)

    # range, lhs, type and rhs are always required
    if len(parts) < 4 :
        raise ZoneLineError(line, "Too few parts for $GENERATE: {0!r}", parts)

    spec = parts.pop(0)

    # parse range
    try :
        if '/' in spec :
            span, step = spec.split('/')
            step = int(step)
        else :
            span = spec
            step = 1

        start, stop = span.split('-')
        start = int(start)
        stop = int(stop)

    except ValueError :
        # wrong number of / or - separated parts, or a non-numeric value
        raise ZoneLineError(line, "Invalid $GENERATE range: {0!r}", spec)

    if step < 1 :
        raise ZoneLineError(line, "Invalid $GENERATE step: {0!r}", step)

    log.debug("  range: start=%r, stop=%r, step=%r", start, stop, step)

    # inclusive
    values = xrange(start, stop + 1, step)

    lhs_func = parse_generate_field(line, parts.pop(0))
    rhs_func = parse_generate_field(line, parts.pop(-1))
    body = parts

    for i in values :
        # build
        fields = [lhs_func(i)] + body + [rhs_func(i)]

        log.debug(" %03d: %r", i, fields)

        # parse
        yield ZoneRecord.build(line, parts=fields, origin=origin)

   
def reverse_ipv4 (ip) :
    """
        Return in-addr.arpa reverse for given IPv4 prefix.

            ip      - one to four dot-separated decimal octets, e.g. "10.0" or "10.0.0.1"

        Raises ValueError if a part is not a number in the range 0-255, or if there are more than four parts.
    """

    # parse; int() raises ValueError for empty/non-numeric parts
    octets = [int(part) for part in ip.split('.')]

    if len(octets) > 4 :
        raise ValueError("Too many octets in IPv4 prefix: %r" % (ip, ))

    # validate explicitly, as asserts are skipped when running with -O
    for octet in octets :
        if not 0 <= octet <= 255 :
            raise ValueError("Invalid octet %d in IPv4 prefix: %r" % (octet, ip))

    return '.'.join([str(octet) for octet in reversed(octets)] + ['in-addr', 'arpa'])

def reverse_ipv6 (ip6) :
    """
        Return ip6.arpa reverse for given IPv6 prefix.

            ip6     - colon-separated hexadecimal groups, e.g. "2001:db8"

        Every group is expanded to four hex digits, and the digits are then reversed one by one.

        The :: shorthand is not supported, as an empty group is rejected by int().

        Raises ValueError if a group is not a hexadecimal number in the range 0-ffff, or if there are more than
        eight groups.
    """

    # parse; int() raises ValueError for empty/non-hex groups
    groups = [int(part, 16) for part in ip6.split(':')]

    if len(groups) > 8 :
        raise ValueError("Too many groups in IPv6 prefix: %r" % (ip6, ))

    # a group outside 16 bits would not format as exactly four nibbles
    for group in groups :
        if not 0 <= group <= 0xffff :
            raise ValueError("Invalid group %x in IPv6 prefix: %r" % (group, ip6))

    nibbles = ''.join('{0:04x}'.format(group) for group in groups)

    return '.'.join(list(reversed(nibbles)) + ['ip6', 'arpa'])

def fqdn (*parts) :
    """
        Join the given name parts with dots into a fully-qualified name, ending in a dot.

        The trailing dot is only added if the joined name does not already end in one, as the last part may
        itself be an fqdn.
    """

    name = '.'.join(parts)

    return name if name.endswith('.') else name + '.'