pvl/dns/zone.py
author Tero Marttila <terom@paivola.fi>
Sun, 22 Dec 2013 19:03:57 +0200
changeset 336 edaa5d0aa57d
parent 323 9b3cbd8687eb
child 378 3fed153a1fe6
permissions -rw-r--r--
version 0.6.1: pvl.hosts forward/reverse delegation, and include= support
#!/usr/bin/env python

"""
    Process zonefiles.
"""

import codecs
import datetime
import logging
import math
import os.path

log = logging.getLogger('pvl.dns.zone')

class ZoneError (Exception) :
    """
        Base class for the errors raised by this module while processing zonefiles.
    """

class ZoneLineError (ZoneError) :
    """
        A ZoneError that is tied to a specific line.

        The exception message is "<line>: <msg>", where msg is a str.format()
        template expanded with the remaining positional/keyword arguments.
    """

    def __init__ (self, line, msg, *args, **kwargs) :
        # expand the template first, then prefix it with the line it refers to
        detail = msg.format(*args, **kwargs)
        message = "%s: %s" % (line, detail)

        super(ZoneLineError, self).__init__(message)

class ZoneLine (object) :
    """
        A line parsed from a zonefile.

        Holds the source location (file, lineno, raw line), the lexed fields (indent, parts)
        and optional metadata (comment, timestamp).
    """

    @classmethod
    def load (cls, file, ttl=None, origin=None, expand_generate=False, expand_include=False, **opts) :
        """
            Parse ZoneLine, ZoneRecord items from the given zonefile.

                file            - open zonefile; file.name is used to resolve $INCLUDE paths
                ttl             - initial default TTL, updated by $TTL directives
                origin          - initial origin, updated by $ORIGIN directives
                expand_generate - yield the records generated by $GENERATE directives
                expand_include  - load $INCLUDE files inline
                **opts          - passed on to parse()

            Yields (ZoneLine, ZoneRecord) pairs; the record is None for directives that are not
            expanded and for lines that do not parse as a record. Lines without any fields are
            skipped entirely.

            With expand_include, the pairs loaded from the included file are yielded as-is, so each
            record stays paired with the line it was parsed from. The included file inherits the
            current origin/ttl and the expand_* flags.
        """

        # name of the previous record, inherited by indented lines
        name = None

        for line in cls.parse(file, **opts) :
            if not line.parts :
                log.debug("%s: skip empty line", line)

            elif line.line.startswith('$') :
                # control record
                directive = ZoneDirective.parse(line,
                        origin      = origin,
                        comment     = line.comment,
                )

                if directive.directive == 'ORIGIN' :
                    # update
                    origin, = directive.arguments

                    log.info("%s: origin: %s", line, origin)

                    yield line, None

                elif directive.directive == 'TTL' :
                    ttl, = directive.arguments

                    log.info("%s: ttl: %s", line, ttl)

                    yield line, None

                elif directive.directive == 'GENERATE' :
                    if expand_generate :
                        # process...
                        log.info("%s: generate: %s", line, directive.arguments)

                        for record in process_generate(line, origin, directive.arguments) :
                            yield line, record
                    else :
                        yield line, None

                elif directive.directive == 'INCLUDE' :
                    if expand_include :
                        include, = directive.arguments

                        # relative to the including file
                        path = os.path.join(os.path.dirname(file.name), include)

                        log.info("%s: include: %s: %s", line, include, path)

                        # load() yields (line, record) pairs, which are passed through unchanged;
                        # the file is closed once the included lines are exhausted
                        with open(path) as include_file :
                            for include_item in cls.load(include_file,
                                    ttl             = ttl,
                                    origin          = origin,
                                    expand_generate = expand_generate,
                                    expand_include  = expand_include,
                            ) :
                                yield include_item
                    else :
                        yield line, None

                else :
                    log.warning("%s: skip unknown control record: %r", line, directive)
                    yield line, None

            else :
                # normal record?
                record = ZoneRecord.parse(line,
                        name    = name,
                        origin  = origin,
                        ttl     = ttl,
                        comment = line.comment,
                )

                if record :
                    yield line, record

                    # keep name across lines
                    name = record.name

                else :
                    # unknown
                    log.warning("%s: skip unknown line: %s", line, line.line)

                    yield line, None

    @classmethod
    def parse (cls, file, filename=None, line_timestamp_prefix=None) :
        """
            Yield ZoneLines lexed from a file.

                file                    - iterable of text lines
                filename                - name used for the yielded lines and in messages;
                                          defaults to file.name
                line_timestamp_prefix   - if true, every line must start with a
                                          "<PARSE_DATETIME_FORMAT>: " prefix, which is parsed
                                          into ZoneLine.timestamp

            A statement spanning multiple lines using ( ... ) is yielded as a single ZoneLine,
            located at its first line.

            Line numbers count from zero.

            Raises ZoneError on a missing timestamp prefix, or on unbalanced parentheses.
        """

        if not filename :
            filename = file.name

        # state of the ( ... ) statement being collected, if any
        multiline_start = None  # (lineno, timestamp, indent, comment) of the first line
        multiline_line = None   # raw lines collected so far
        multiline_parts = None  # fields collected so far

        for lineno, raw_line in enumerate(file) :
            raw_line = raw_line.rstrip('\n')

            # possible mtime prefix for line
            timestamp = None

            if line_timestamp_prefix :
                if ': ' not in raw_line :
                    raise ZoneError("%s:%d: Missing timestamp prefix: %s" % (filename, lineno, raw_line))

                # split prefix
                prefix, raw_line = raw_line.split(': ', 1)

                # parse it out
                timestamp = datetime.datetime.strptime(prefix, cls.PARSE_DATETIME_FORMAT)

                log.debug("%s:%d: ts=%r", filename, lineno, timestamp)

            log.debug("%s:%d: %s", filename, lineno, raw_line)

            # capture indent from raw line
            indent = raw_line.startswith(' ') or raw_line.startswith('\t')
            line = raw_line.strip()

            # parse comment
            # NOTE: a ';' inside a quoted field is also treated as the start of a comment
            if ';' in line :
                line, comment = line.split(';', 1)

                line = line.strip()
                comment = comment.strip()
            else :
                comment = None

            log.debug("%s:%d: indent=%r, line=%r, comment=%r", filename, lineno, indent, line, comment)

            # split (quoted) fields
            # NOTE: only a single, terminated quoted field per line is handled
            if '"' in line :
                pre, data, post = line.split('"', 2)
                parts = pre.split() + [data] + post.split()

            else :
                parts = line.split()

            # handle multi-line statements...
            if '(' in parts :
                if multiline_start :
                    raise ZoneError("%s:%d: Nested multi-line statement: %s" % (filename, lineno, raw_line))

                log.debug("%s:%d: Start of multi-line statement: %s", filename, lineno, line)

                multiline_start = (lineno, timestamp, indent, comment)
                # starts out empty: the first line is added below, like every other line
                multiline_line = ''
                multiline_parts = []

            if multiline_start :
                log.debug("%s:%d: Multi-line statement: %s", filename, lineno, line)

                # collect fields without the parentheses themselves
                multiline_parts.extend(part for part in parts if part not in ('(', ')'))
                multiline_line += raw_line

            if ')' in parts :
                if not multiline_start :
                    raise ZoneError("%s:%d: Unexpected end of multi-line statement: %s" % (filename, lineno, raw_line))

                log.debug("%s:%d: End of multi-line statement: %s", filename, lineno, line)

                # the complete statement is reported at its first line
                lineno, timestamp, indent, comment = multiline_start
                raw_line = multiline_line
                parts = multiline_parts

                multiline_start = multiline_line = multiline_parts = None

            # parse
            if not multiline_start :
                yield ZoneLine(filename, lineno, raw_line, indent, parts, comment, timestamp=timestamp)

        if multiline_start :
            # ran out of lines before the closing )
            raise ZoneError("%s:%d: Unterminated multi-line statement: %s" % (filename, multiline_start[0], multiline_line))

    # source
    file = None
    lineno = None

    # data
    indent = None # was the line indented?
    parts = None # split line fields

    # optional
    timestamp = None
    comment = None

    # strptime() format of the line_timestamp_prefix
    PARSE_DATETIME_FORMAT = '%Y-%m-%d'

    def __init__ (self, file, lineno, line, indent, parts, comment=None, timestamp=None) :
        """
            file        - name of the source file
            lineno      - line number within the source file
            line        - raw line text
            indent      - was the line indented?
            parts       - split line fields
            comment     - optional trailing comment
            timestamp   - optional timestamp parsed from the line prefix
        """

        # source
        self.file = file
        self.lineno = lineno
        self.line = line

        # parse data
        self.indent = indent
        self.parts = parts

        # metadata
        self.timestamp = timestamp
        self.comment = comment

    def __unicode__ (self) :
        """
            Tab-separated fields, with a leading tab for an indented line.
        """

        return u"{indent}{parts}".format(
                indent      = u"\t" if self.indent else '',
                parts       = u'\t'.join(self.parts),
        )

    def __repr__ (self) :
        """
            The <file>:<lineno> source location, as used in log messages.
        """

        return "{file}:{lineno}".format(file=self.file, lineno=self.lineno)

class ZoneDirective (object) :
    """
        An $DIRECTIVE in a zonefile.
    """

    # context
    line    = None  # the underlying line
    origin  = None  # possible $ORIGIN context

    # fields
    directive   = None  # directive name, without the leading $
    arguments   = None  # sequence of argument fields

    @classmethod
    def parse (cls, line, **opts) :
        """
            Build a ZoneDirective from a line whose first field is the $DIRECTIVE name.

                line        - object with the split fields in .parts
                **opts      - passed on to the constructor (origin, comment)
        """

        # control record
        args = list(line.parts)

        # strip the leading $
        directive = args[0][1:]
        arguments = args[1:]

        return cls.build(line, directive, *arguments, **opts)

    @classmethod
    def build (cls, line, directive, *arguments, **opts) :
        """
            Build a ZoneDirective from the directive name and its arguments.
        """

        return cls(directive, arguments,
                line        = line,
                **opts
        )

    def __init__ (self, directive, arguments, line=None, origin=None, comment=None) :
        """
            directive   - directive name, without the leading $
            arguments   - sequence of argument fields
            line        - optional underlying line
            origin      - optional $ORIGIN context
            comment     - optional trailing comment
        """

        self.directive  = directive
        self.arguments  = arguments

        self.line = line
        self.origin = origin
        self.comment = comment

    def __unicode__ (self) :
        """
            Construct a zonefile-format line...
        """

        if self.comment :
            comment = '\t; ' + self.comment
        else :
            comment = ''

        return u"${directive}\t{arguments}{comment}".format(
                directive   = self.directive,
                arguments   = '\t'.join(self.arguments),
                comment     = comment,
        )

    def __repr__ (self) :
        """
            ZoneDirective(<directive>, <arguments...>)
        """

        # one-element tuple for the name, so that it concatenates with the arguments
        fields = (self.directive, ) + tuple(self.arguments)

        return '%s(%s)' % (self.__class__.__name__, ', '.join(repr(arg) for arg in fields))


class ZoneRecord (object) :
    """
        A record from a zonefile.
    """

    # context
    line = None # the underlying line
    origin = None # possible $ORIGIN context

    # record fields
    name = None
    ttl = None  # optional
    cls = None  # optional
    type = None
    data = None # list of data fields

    @classmethod
    def load (cls, file, **opts) :
        """
            Yield ZoneRecords from a file.

            Lines that do not yield a record are logged and skipped.

                **opts      - passed on to ZoneLine.load()
        """

        for line, record in ZoneLine.load(file, **opts) :
            if record :
                yield record
            else :
                log.warning("%s: unparsed line: %s", file.name, line)

    @classmethod
    def parse (cls, line, name=None, parts=None, ttl=None, **opts) :
        """
            Build a ZoneRecord from a ZoneLine.

                line        - the underlying line; line.indent decides if the name is inherited
                name        - default for name, if continuing previous line
                parts       - fields to parse instead of line.parts; consumed in the process
                ttl         - default TTL, if the line does not specify one
                **opts      - passed on to the constructor (origin, comment)

            Returns the ZoneRecord, or None if there are no fields to parse.

            Raises ZoneLineError if there are too few fields for a record.
        """

        if parts is None :
            parts = list(line.parts)

        if not parts :
            # skip
            return

        # for error messages, as parts is consumed below
        fields = list(parts)

        if line.indent :
            # indented lines keep name from previous record
            pass

        else :
            name = parts.pop(0)

        if len(parts) < 2 :
            raise ZoneLineError(line, "Too few parts to parse: {0!r}", fields)

        # parse ttl/cls/type
        _cls = None

        # [:1] instead of [0], as a quoted field may be empty
        if parts and parts[0][:1].isdigit() :
            ttl = parts.pop(0)

        if parts and parts[0].upper() in ('IN', 'CH') :
            _cls = parts.pop(0)

        if not parts :
            # only ttl/class fields given
            raise ZoneLineError(line, "Too few parts to parse: {0!r}", fields)

        # always have type
        rtype = parts.pop(0)

        # remaining parts are data
        data = parts

        log.debug("  ttl=%r, cls=%r, type=%r, data=%r", ttl, _cls, rtype, data)

        return cls.build(line, name, ttl, _cls, rtype, data, **opts)

    @classmethod
    def build (cls, line, name, ttl, _cls, type, data, **opts) :
        """
            Build a ZoneRecord from parsed fields; _cls is the record class (IN/CH).
        """

        return cls(name, type, data,
            ttl     = ttl,
            cls     = _cls,
            line    = line,
            **opts
        )

    @classmethod
    def A (cls, name, ip4, **opts) :
        """
            Build an A record.
        """

        return cls(str(name), 'A', [str(ip4)], **opts)

    @classmethod
    def AAAA (cls, name, ip6, **opts) :
        """
            Build an AAAA record.
        """

        return cls(str(name), 'AAAA', [str(ip6)], **opts)

    @classmethod
    def CNAME (cls, name, host, **opts) :
        """
            Build a CNAME record.
        """

        return cls(str(name), 'CNAME', [str(host)], **opts)

    @classmethod
    def TXT (cls, name, text, **opts) :
        """
            Build a TXT record; the text is quoted, with any embedded quotes escaped.
        """

        return cls(str(name), 'TXT',
            [u'"{0}"'.format(text.replace('"', '\\"'))],
            **opts
        )

    @classmethod
    def PTR (cls, name, ptr, **opts) :
        """
            Build a PTR record.
        """

        return cls(str(name), 'PTR', [str(ptr)], **opts)

    @classmethod
    def MX (cls, name, priority, mx, **opts) :
        """
            Build an MX record; the priority is converted to an int.
        """

        return cls(str(name), 'MX', [int(priority), str(mx)], **opts)

    def __init__ (self, name, type, data, ttl=None, cls=None, line=None, origin=None, comment=None) :
        """
            name        - record name
            type        - record type
            data        - list of data fields
            ttl         - optional TTL
            cls         - optional record class
            line        - optional underlying line
            origin      - optional $ORIGIN context
            comment     - optional trailing comment
        """

        self.name = name
        self.type = type
        self.data = data

        self.ttl = ttl
        self.cls = cls

        self.line = line
        self.origin = origin
        self.comment = comment

    def __unicode__ (self) :
        """
            Construct a zonefile-format line...
        """

        if self.comment :
            comment = '\t; ' + self.comment
        else :
            comment = ''

        return u"{name:25} {ttl:4} {cls:2} {type:5} {data}{comment}".format(
                name    = self.name or '',
                ttl     = self.ttl or '',
                cls     = self.cls or '',
                type    = self.type,
                data    = ' '.join(unicode(data) for data in self.data),
                comment = comment,
        )

    def __repr__ (self) :
        """
            <class name>(<name>, <type>, <data>)
        """

        return '%s(%s)' % (self.__class__.__name__, ', '.join(repr(arg) for arg in (
            self.name, self.type, self.data
        )))

class SOA (ZoneRecord) :
    """
        An SOA record for the zone apex (@), with its data fields available as attributes.
    """

    @classmethod
    def build (cls, line, name, ttl, _cls, type, data, **opts) :
        """
            Build the SOA from parsed fields; data must hold exactly the seven SOA fields.
        """

        assert name == '@', "SOA record name must be @"

        return cls(*data,
            ttl     = ttl,
            # the record class (IN/CH), not this Python class
            cls     = _cls,
            line    = line,
            **opts
        )

    def __init__ (self, master, contact, serial, refresh, retry, expire, nxttl, **opts) :
        """
            The seven SOA data fields, in zonefile order.

                **opts      - passed on to ZoneRecord (ttl, cls, line, origin, comment)
        """

        super(SOA, self).__init__('@', 'SOA',
            [master, contact, serial, refresh, retry, expire, nxttl],
            **opts
        )

        self.master = master
        self.contact = contact
        self.serial = serial
        self.refresh = refresh
        self.retry = retry
        self.expire = expire
        self.nxttl = nxttl

class OffsetValue (object) :
    """
        Wrapper that adds the subscript to a base value, so that an offset
        can be applied using indexing syntax.

        >>> OffsetValue(0)[1]
        1
        >>> OffsetValue(10)[5]
        15
    """

    def __init__ (self, base) :
        # the value that subscripts are added to
        self.base = base

    def __getitem__ (self, offset) :
        shifted = self.base + offset

        return shifted

def parse_generate_field (field, line=None) :
    r"""
        Parse a $GENERATE lhs/rhs field:
            $                               - the iteration value
            ${<offset>[,<width>[,<base>]]}  - the iteration value plus offset, zero-padded to
                                              width, formatted using the given format() type
            \$                              - a literal $
            $$                              - a literal $

        The field is scanned from left to right; anything else is copied as-is.

            field       - the field to parse
            line        - the line that the field is from, for error messages

        Returns a wrapper that builds the field-value when called with the index.

        Raises ZoneLineError for a malformed ${...}, and ValueError for a non-integer
        offset/width.

        >>> parse_generate_field("foo")(1)
        'foo'
        >>> parse_generate_field("foo-$")(1)
        'foo-1'
        >>> parse_generate_field("foo-$$")(1)
        'foo-$'
        >>> parse_generate_field(r"\$")(1)
        '$'
        >>> parse_generate_field("10.0.0.${100}")(1)
        '10.0.0.101'
        >>> parse_generate_field("foo-${0,2,d}")(1)
        'foo-01'
        >>> parse_generate_field("foo-${-1}")(1)
        'foo-0'
        >>> parse_generate_field("{$}")(1)
        '{1}'

    """

    # literal strings, and (offset, width, base) tuples for the substitutions
    pieces = []
    pos = 0

    while pos < len(field) :
        dollar = field.find('$', pos)

        if dollar < 0 :
            # only literal text left
            pieces.append(field[pos:])
            break

        literal = field[pos:dollar]

        if literal.endswith('\\') :
            # \$ escape: drop the backslash, keep the $
            pieces.append(literal[:-1] + '$')
            pos = dollar + 1

            continue

        if literal :
            pieces.append(literal)

        if field.startswith('$', dollar + 1) :
            # $$ escape
            pieces.append('$')
            pos = dollar + 2

        elif field.startswith('{', dollar + 1) :
            close = field.find('}', dollar + 2)

            if close < 0 :
                # braces are doubled, as ZoneLineError expands the message using str.format()
                raise ZoneLineError(line, "unterminated ${{...}} in field: {0!r}", field)

            # parse body
            modifiers = field[dollar + 2:close].split(',')

            if len(modifiers) > 3 :
                raise ZoneLineError(line, "extra data in ${{...}} body: {0!r}", modifiers[3:])

            offset = int(modifiers[0])
            width = int(modifiers[1]) if len(modifiers) > 1 else 0
            base = modifiers[2] if len(modifiers) > 2 else 'd'

            pieces.append((offset, width, base))
            pos = close + 1

        else :
            # plain $
            pieces.append((0, 0, 'd'))
            pos = dollar + 1

    log.debug("%s: %r", field, pieces)

    # processed
    def value_func (value) :
        return ''.join(
            # zero-padded value + offset for substitutions, literal text as-is
            format(value + piece[0], '0%d%s' % piece[1:]) if isinstance(piece, tuple) else piece
            for piece in pieces
        )

    return value_func

def parse_generate_range (field) :
    """
        Parse a $GENERATE range of the form <start>-<stop>[/<step>].

        Returns the range() of values from start up to and including stop; the step defaults to 1.
    """

    step = 1

    if '/' in field :
        # explicit step
        field, step_text = field.split('/')
        step = int(step_text)

    start_text, stop_text = field.split('-')
    start = int(start_text)
    stop = int(stop_text)

    log.debug("  range: start=%r, stop=%r, step=%r", start, stop, step)

    # range() excludes its end, whereas stop is part of the range
    last = stop + 1

    return range(start, last, step)

def process_generate (line, origin, parts) :
    """
        Expand the arguments of a
            $GENERATE <start>-<stop>[/<step>] lhs [ttl] [class] type rhs [comment]
        directive, yielding the result of ZoneRecord.parse() for every value in the range.

            line        - the line of the directive
            origin      - origin passed on to every record
            parts       - the directive arguments; not modified
    """

    remaining = list(parts)

    # leading range, followed by the lhs template
    values = parse_generate_range(remaining.pop(0))
    build_lhs = parse_generate_field(remaining.pop(0), line=line)

    # the rhs template is the final field; whatever is left over goes in between unchanged
    build_rhs = parse_generate_field(remaining.pop(), line=line)

    for value in values :
        # a fresh list per record, as parsing consumes it
        fields = [build_lhs(value)]
        fields.extend(remaining)
        fields.append(build_rhs(value))

        log.debug(" %03d: %r", value, fields)

        yield ZoneRecord.parse(line, parts=fields, origin=origin)

   
def reverse_ipv4 (ip) :
    """
        Return in-addr.arpa reverse for given IPv4 prefix.

            ip          - dotted-decimal string of one to four octets

        Raises ValueError if an octet is not an integer within 0..255, or if there are more
        than four octets.

        >>> reverse_ipv4('192.0.2')
        '2.0.192.in-addr.arpa'
    """

    # parse
    octets = [int(part) for part in ip.split('.')]

    # explicit checks instead of assert, which is stripped when running with -O
    if len(octets) > 4 :
        raise ValueError("too many octets in IPv4 prefix: %s" % (ip, ))

    for octet in octets :
        if not 0 <= octet <= 255 :
            raise ValueError("invalid octet %d in IPv4 prefix: %s" % (octet, ip))

    return '.'.join([str(octet) for octet in reversed(octets)] + ['in-addr', 'arpa'])

def reverse_ipv6 (ip6) :
    """
        Return ip6.arpa reverse for given IPv6 prefix.

        The prefix is given as colon-separated hex groups, each of which is padded to four
        nibbles; the nibbles are then listed in reverse order.
    """

    groups = [int(group, 16) for group in ip6.split(':')]

    # all nibbles as one string of hex digits
    nibbles = ''.join('{0:04x}'.format(group) for group in groups)

    labels = list(nibbles[::-1])
    labels.extend(('ip6', 'arpa'))

    return '.'.join(labels)

# TODO: support fqdns in parts
def join (*parts) :
    """
        Build a domain name out of the given labels, separated by dots.

        Each label is converted using str().
    """

    labels = map(str, parts)

    return '.'.join(labels)

def fqdn (*parts) :
    """
        Return an FQDN joined from parts, ending in a dot.
    """

    name = join(*parts)

    # the parts may already end in an fqdn, in which case the dot is there
    return name if name.endswith('.') else name + '.'

def reverse_label (prefix, address) :
    """
        Determine the correct label for the given IP address within the reverse zone for the given prefix.

        This includes all suffix octets (IPv4) or nibbles (IPv6) that are (partially) covered by
        the host bits of the prefix, in reverse order.

            prefix      - network with .version, .prefixlen and .max_prefixlen
            address     - address with .version and .packed

        Returns an empty label if the prefix has no host bits.

        Raises ValueError if the versions differ, or if the version is not 4 or 6.
    """

    if prefix.version != address.version :
        raise ValueError("address version mismatch: %s / %s" % (prefix, address))

    hostbits = prefix.max_prefixlen - prefix.prefixlen

    if prefix.version == 4 :
        # decimal octets; bytearray gives integer byte values on both python2 and python3
        digits = ["{0:d}".format(octet) for octet in bytearray(address.packed)]

        # octets needed to cover the host bits, rounding up
        count = (hostbits + 7) // 8

    elif prefix.version == 6 :
        # hex nibbles, high nibble of each byte first
        digits = ["{0:x}".format(nibble)
            for octet in bytearray(address.packed)
            for nibble in (octet >> 4, octet & 0xf)
        ]

        # nibbles needed to cover the host bits, rounding up
        count = (hostbits + 3) // 4

    else :
        raise ValueError("unsupported address version: %s" % (prefix, ))

    # take the suffix; slicing from the front, as [-0:] would select everything
    suffix = digits[len(digits) - count:]

    return '.'.join(reversed(suffix))