pvl/dns/zone.py
author Tero Marttila <terom@paivola.fi>
Tue, 10 Mar 2015 00:26:31 +0200
changeset 740 74352351d6f5
parent 647 90a0790adf8a
permissions -rw-r--r--
replace ipaddr with ipaddress
#!/usr/bin/env python

"""
    Process zonefiles.
"""

import codecs
import datetime
import ipaddress
import logging
import math
import os.path
import pvl.dns.labels

# module-level logger used by the zonefile lexer, directive handlers and record parser below
log = logging.getLogger('pvl.dns.zone')

class ZoneError (Exception) :
    """
        Base class for errors raised while processing zonefiles.
    """

class ZoneLineError (ZoneError) :
    """
        Error tied to a specific line of a zonefile.

        The exception message is the given line, followed by msg expanded with str.format()
        using any extra positional and keyword arguments.
    """

    def __init__ (self, line, msg, *args, **kwargs) :
        detail = msg.format(*args, **kwargs)
        message = "%s: %s" % (line, detail)

        super(ZoneLineError, self).__init__(message)

def zone_quote (field):
    r"""
        Quote a value for inclusion into a zonefile, such as TXT record data or a directive argument.

        The value is returned unchanged when it can be written bare. It is wrapped in double quotes,
        with embedded backslashes and double quotes backslash-escaped, when it contains whitespace or
        a character that is special in zonefile syntax: backslash, double quote, semicolon (comment)
        or parentheses (multi-line grouping).

        >>> print zone_quote("foo")
        foo
        >>> print zone_quote("foo\\bar")
        "foo\\bar"
        >>> print zone_quote("foo \"bar\" quux")
        "foo \"bar\" quux"
        >>> print zone_quote("k=rsa;p=abc")
        "k=rsa;p=abc"
    """

    # plain str literal, so that both byte and unicode fields compare without implicit decoding errors
    special = '\\";()'

    if any(c.isspace() or c in special for c in field):
        # escape backslashes first, so that the backslashes added for quotes are not doubled
        return u'"' + field.replace('\\', '\\\\').replace('"', '\\"') + u'"'

    return field


class ZoneLine (object) :
    """
        A line lexed from a zonefile.

            parts       - list of fields; the first field is '' if the line had a leading indent
            file        - name of the source file
            lineno      - 1-based line number; for a multi-line statement, the line it starts on
            line        - raw source text; the lines of a multi-line statement are joined with newlines
            comment     - trailing comment without the leading ';', or None
    """

    @classmethod
    def load (cls, file) :
        """
            Yield ZoneDirective/ZoneRecord items parsed from the given zonefile.

            Lines whose first field starts with '$' are parsed as a ZoneDirective, every other non-empty
            line as a ZoneRecord. Directives are yielded as-is; they are not interpreted here.
        """

        for line in cls.parse(file) :
            if not line.parts:
                log.debug("%s: Skip empty line: %r", line, line)
                continue

            elif line.parts[0].startswith('$'):
                # control directive
                line_type = ZoneDirective

            else :
                # normal record
                line_type = ZoneRecord

            yield line_type.parse(line.parts,
                    line        = line,
                    comment     = line.comment,
            )

    @staticmethod
    def _split_comment (line) :
        """
            Split a stripped zonefile line into (content, comment).

            The comment starts at the first ';' that is neither inside a double-quoted string nor
            escaped with a backslash. The comment is None if the line has none.
        """

        quoted = False
        escaped = False

        for index, char in enumerate(line) :
            if escaped :
                # character following a backslash is taken literally
                escaped = False

            elif char == '\\' :
                escaped = True

            elif char == '"' :
                quoted = not quoted

            elif char == ';' and not quoted :
                return line[:index].strip(), line[index + 1:].strip()

        return line, None

    @classmethod
    def parse (cls, file, filename=None):
        """
            Yield ZoneLines lexed from a file.

                file        - iterable of text lines
                filename    - name used for the ZoneLines and log messages, defaults to file.name if available

            Statements grouped with ( ... ) across several lines are yielded as one ZoneLine, carrying
            the line number and comment of the line they start on.

            Raises ZoneLineError on unbalanced parentheses or an unterminated quoted string.
        """

        if not filename :
            # plain iterables (lists, StringIO) have no name
            filename = getattr(file, 'name', None)

        multiline_start = None
        multiline_lines = None
        multiline_parts = None

        # line numbers are 1-based, as shown by editors and other tools
        for lineno, raw_line in enumerate(file, 1) :
            raw_line = raw_line.rstrip('\n')

            log.debug("%s:%d: %s", filename, lineno, raw_line)

            # capture indent from raw line
            indent = raw_line.startswith(' ') or raw_line.startswith('\t')
            line = raw_line.strip()

            if not line:
                # empty line
                continue

            # parse comment, ignoring any ';' within a quoted string
            line, comment = cls._split_comment(line)

            log.debug("%s:%d: indent=%r, line=%r, comment=%r", filename, lineno, indent, line, comment)

            # split (quoted) fields
            # an indented comment-only line has no fields at all, so it gets no leading '' either
            if indent and line and not multiline_start:
                parts = ['']
            else:
                parts = []

            if '"' in line :
                quoted = line.split('"', 2)

                if len(quoted) < 3 :
                    raise ZoneLineError("%s:%d" % (filename, lineno), "Unterminated quoted string: {text}", text=line)

                pre, data, post = quoted
                parts.extend(pre.split() + [data] + post.split())

            else :
                parts.extend(line.split())

            # handle multi-line statements...
            if '(' in parts :
                if multiline_start :
                    raise ZoneLineError("%s:%d" % (filename, lineno), "Nested multi-line statement")

                log.debug("%s:%d: Start of multi-line statement: %s", filename, lineno, line)

                multiline_start = (lineno, comment)
                multiline_lines = []
                multiline_parts = []

            if multiline_start:
                log.debug("%s:%d: Multi-line statement: %s", filename, lineno, line)

                # the parentheses only group the lines, they are not fields
                multiline_parts.extend(part for part in parts if part not in ('(', ')'))
                multiline_lines.append(raw_line)

            if ')' in parts :
                if not multiline_start :
                    raise ZoneLineError("%s:%d" % (filename, lineno), "End of multi-line statement without start")

                log.debug("%s:%d: End of multi-line statement: %s", filename, lineno, line)

                # report the statement at the line it started on, with that line's comment
                lineno, comment = multiline_start
                raw_line = '\n'.join(multiline_lines)
                parts = multiline_parts

                multiline_start = multiline_lines = multiline_parts = None

            # yield, unless still collecting a multi-line statement
            if not multiline_start:
                yield ZoneLine(parts,
                        file    = filename,
                        lineno  = lineno,
                        line    = raw_line,
                        comment = comment,
                )

        if multiline_start :
            log.warning("%s:%d: Unterminated multi-line statement", filename, multiline_start[0])

    def __init__ (self, parts, file=None, lineno=None, line=None, comment=None) :
        """
            parts : [str]       - list of parsed (quoted) fields
                                  if the line has a leading indent, the first field should be an empty string
            file                - name of the source file
            lineno              - line number within the source file
            line                - raw source line
            comment             - trailing comment, if any
        """

        self.parts = parts

        # source
        self.file = file
        self.lineno = lineno
        self.line = line

        # meta
        self.comment = comment

    def __unicode__ (self) :
        """
            Tab-separated fields, followed by the comment if there is one.
        """

        if self.comment is not None :
            comment = u'\t; ' + self.comment
        else :
            comment = ''

        return u"{parts}{comment}".format(
                parts       = u'\t'.join(self.parts),
                comment     = comment,
        )

    def __repr__ (self) :
        """
            Source location as file:lineno, as used in log and error messages.
        """

        return "{file}:{lineno}".format(file=self.file, lineno=self.lineno)


class ZoneDirective (object) :
    """
        A $DIRECTIVE control entry of a zonefile.
    """

    # context
    line    = None
    origin  = None

    # fields
    directive   = None
    arguments   = None

    def __init__ (self, directive, arguments, comment=None, line=None, origin=None):
        """
            directive       - uppercase directive name, without the leading $
            arguments [str] - list of directive arguments
            comment         - optional trailing comment
            line            - optional associated ZoneLine
            origin          - context origin
        """

        self.directive  = directive
        self.arguments  = arguments

        self.comment = comment
        self.line = line
        self.origin = origin

    @classmethod
    def parse (cls, parts, **opts) :
        """
            Build from ZoneLine.parts: the first part is the $NAME, the rest are its arguments.

            The leading $ is dropped and the name is uppercased.
        """

        head, rest = parts[0], parts[1:]

        return cls(head[1:].upper(), rest, **opts)

    @classmethod
    def build (cls, directive, *arguments, **opts):
        """
            Build from a directive name and any number of positional arguments.

            The arguments are stored as a tuple; the directive name is used as given.
        """

        return cls(directive, arguments, **opts)

    @classmethod
    def INCLUDE (cls, path, origin=None, **opts):
        """
            Build $INCLUDE "path" [origin]

            Note that origin should be a FQDN, or it will be interpreted relative to the current origin!
        """

        # the origin argument is only included when given
        arguments = (path, origin) if origin else (path, )

        return cls.build('INCLUDE', *arguments, **opts)

    def __unicode__ (self):
        """
            Format as a zonefile line: $NAME, tab-separated quoted arguments, optional comment.
        """

        quoted = [zone_quote(argument) for argument in self.arguments]
        comment = ('\t; ' + self.comment) if self.comment else ''

        return u"${directive}\t{arguments}{comment}".format(
                directive   = self.directive,
                arguments   = '\t'.join(quoted),
                comment     = comment,
        )

    def __repr__ (self) :
        fields = [self.directive] + list(self.arguments)

        return '%s(%s)' % (self.__class__.__name__, ', '.join(repr(field) for field in fields))


def process_generate (line, range, parts, **opts) :
    """
        Expand a
            $GENERATE <start>-<stop>[/<step>] lhs [ttl] [class] type rhs [comment]
        directive and yield the resulting ZoneRecords.

            line        - ZoneLine of the directive; used as error context and attached to each record
            range       - the range argument, parsed using pvl.dns.generate.parse_generate_range()
            parts       - the remaining arguments: lhs, optional ttl/class, type, rhs
            **opts      - passed on to ZoneRecord.parse()

        Raises ZoneLineError
    """

    # imported locally: the module only imports pvl.dns.labels at the top level
    import pvl.dns.generate

    # a list, so that the fields can be concatenated below even if given as a tuple
    parts = list(parts)

    if len(parts) < 3 :
        raise ZoneLineError(line, "Too few parts for $GENERATE: {parts!r}", parts=parts)

    try:
        generate_range = pvl.dns.generate.parse_generate_range(range)

        lhs_func = pvl.dns.generate.parse_generate_field(parts[0], line=line)
        body = parts[1:-1]
        rhs_func = pvl.dns.generate.parse_generate_field(parts[-1], line=line)

    except ValueError as error:
        raise ZoneLineError(line, "{error}", error=error)

    log.info("%s: generate %s: %s ... %s", line, generate_range, lhs_func, rhs_func)

    for i in generate_range :
        # only the first and last field are expanded per iteration
        record_parts = [lhs_func(i)] + body + [rhs_func(i)]

        log.debug(" %03d: %r", i, record_parts)

        yield ZoneRecord.parse(record_parts, line=line, **opts)

def process_include (line, origin, include_filename, include_origin=None, **opts):
    """
        Process a
            $INCLUDE filename [ origin ]
        directive and yield the ZoneRecords loaded from the included file.

            line                - ZoneLine of the directive; the filename is resolved relative to the directory of line.file
            origin              - current origin
            include_filename    - path of the file to include
            include_origin      - optional origin for the included file, joined with the current origin
            **opts              - passed on to ZoneRecord.load()

        Raises ZoneLineError.
    """

    path = os.path.join(os.path.dirname(line.file), include_filename)

    if include_origin:
        # NOTE(review): ZoneRecord.load() joins $ORIGIN as join(origin, directive_origin), the opposite
        # argument order - confirm which one pvl.dns.labels.join() expects
        origin = pvl.dns.labels.join(include_origin, origin)

    # the included file is closed once all of its records have been yielded
    with open(path) as file :
        for record in ZoneRecord.load(file, origin, **opts) :
            yield record

class ZoneRecord (object) :
    """
        A record from a zonefile.
    """

    @classmethod
    def load (cls, file, origin,
            ttl=None,
    ):
        """
            Yield ZoneRecords from a file. Processes any ZoneDirectives.

                file        - zonefile to lex using ZoneLine.parse()
                origin      - initial origin, updated by $ORIGIN
                ttl         - initial default ttl, updated by $TTL

            $GENERATE and $INCLUDE are expanded into the records they produce. Unknown directives are
            logged and skipped. The name of the previous record is carried over to indented lines.

            Raises ZoneLineError.
        """

        name = None

        for line in ZoneLine.parse(file):
            if not line.parts:
                log.debug("%s: Skip empty line: %r", line, line)
                continue

            elif line.parts[0].startswith('$'):
                directive = ZoneDirective.parse(line.parts,
                        origin      = origin,
                        line        = line,
                        comment     = line.comment,
                )

                log.debug("%s: %s", line, directive)

                if directive.directive == 'ORIGIN':
                    directive_origin, = directive.arguments

                    log.info("%s: $ORIGIN %s <- %s", line, directive_origin, origin)

                    # NOTE(review): process_include() joins as join(include_origin, origin), the opposite
                    # argument order - confirm which one pvl.dns.labels.join() expects
                    origin = pvl.dns.labels.join(origin, directive_origin)

                elif directive.directive == 'TTL' :
                    directive_ttl, = directive.arguments

                    log.info("%s: $TTL %s <- %s", line, directive_ttl, ttl)

                    ttl = int(directive_ttl)

                elif directive.directive == 'GENERATE' :
                    if not directive.arguments :
                        raise ZoneLineError(line, "$GENERATE without a range")

                    # the first argument is the range, the rest is the record template
                    generate_range = directive.arguments[0]
                    generate_parts = directive.arguments[1:]

                    for record in process_generate(line, generate_range, generate_parts,
                            name        = name,
                            ttl         = ttl,
                            origin      = origin,
                    ) :
                        yield record

                elif directive.directive == 'INCLUDE' :
                    for record in process_include(line, origin,
                            ttl         = ttl,
                            *directive.arguments
                    ):
                        yield record

                else :
                    # nothing is yielded: every item from load() is a ZoneRecord
                    log.warning("%s: skip unknown control record: %r", line, directive)

            else:
                record = ZoneRecord.parse(line.parts,
                        name        = name,
                        ttl         = ttl,
                        comment     = line.comment,
                        line        = line,
                        origin      = origin,
                )

                log.debug("%s: %s", line, record)

                # keep name across lines
                name = record.name

                yield record

    @classmethod
    def parse (cls, parts, name=None, ttl=None, line=None, **opts) :
        """
            Build a ZoneRecord from ZoneLine.parts.

                parts   - [name-or-empty] [ttl] [class] type data...
                          ttl and class are optional, and accepted in either order
                name    - context for name, if continuing previous line
                ttl     - context for ttl, if using $TTL
                line    - associated ZoneLine, also used as error context
                **opts  - passed on to build()

            Returns the ZoneRecord, built by the subclass registered in ZONE_RECORD_TYPES for the type, if any.

            Raises ZoneLineError if there are too few parts.
        """

        parts = list(parts)

        if not parts :
            raise ZoneLineError(line, "Too few parts to parse: {parts!r}", parts=parts)

        # first field is either name or leading whitespace
        leading = parts.pop(0)

        if leading:
            name = leading

        if len(parts) < 2 :
            raise ZoneLineError(line, "Too few parts to parse: {parts!r}", parts=parts)

        # parse ttl/cls/type
        _cls = None
        record_ttl = None

        # record types never start with a digit, so a leading digit means ttl
        if parts[0][:1].isdigit() :
            record_ttl = parts.pop(0)

        if parts and parts[0].upper() in ('IN', 'CH') :
            _cls = parts.pop(0)

        # the class may also come before the ttl
        if record_ttl is None and parts and parts[0][:1].isdigit() :
            record_ttl = parts.pop(0)

        if record_ttl is not None :
            ttl = record_ttl

        if not parts :
            # e.g. only "name ttl class"
            raise ZoneLineError(line, "Missing record type")

        # always have type
        type = parts.pop(0)

        # remaining parts are data
        data = parts

        log.debug("  ttl=%r, cls=%r, type=%r, data=%r", ttl, _cls, type, data)

        # optional subclass for build()
        cls = ZONE_RECORD_TYPES.get(type.upper(), cls)

        return cls.build(name, type, *data, ttl=ttl, cls=_cls, line=line, **opts)

    @classmethod
    def build (_cls, name, type, *data, **opts):
        """
            Simple interface to build ZoneRecord from required parts. All optional fields must be given as keyword arguments.

                name    - record name; an empty name is stored as None
                type    - record type, uppercased
                *data   - data fields, each converted to unicode
                ttl     - optional ttl, converted to int
                cls     - optional class, uppercased
                **opts  - remaining keyword arguments for the constructor (comment, origin, line)
        """

        # keyword-only arguments
        ttl = opts.pop('ttl', None)
        cls = opts.pop('cls', None)

        if name:
            name = str(name)
        else:
            name = None

        # a ttl of zero is a valid value, unlike None or an empty string
        if ttl or ttl == 0:
            ttl = int(ttl)
        else:
            ttl = None

        if cls:
            cls = cls.upper()
        else:
            cls = None

        type = type.upper()

        data = [unicode(item) for item in data]

        return _cls(name, ttl, cls, type, data, **opts)

    @classmethod
    def A (_cls, name, ip4, **opts):
        """
            Build from ipaddress.IPv4Address.
        """

        return _cls.build(name, 'A', ipaddress.IPv4Address(unicode(ip4)), **opts)

    @classmethod
    def AAAA (_cls, name, ip6, **opts):
        """
            Build from ipaddress.IPv6Address.
        """

        return _cls.build(name, 'AAAA', ipaddress.IPv6Address(unicode(ip6)), **opts)

    @classmethod
    def CNAME (_cls, name, alias, **opts):
        """
            Build a CNAME record for the given alias.
        """

        return _cls.build(name, 'CNAME', alias, **opts)

    @classmethod
    def TXT (_cls, name, text, **opts):
        """
            Build from quoted (unicode) value.
        """

        return _cls.build(name, 'TXT', zone_quote(unicode(text)), **opts)

    @classmethod
    def PTR (_cls, name, ptr, **opts):
        """
            Build a PTR record for the given target.
        """

        return _cls.build(name, 'PTR', ptr, **opts)

    @classmethod
    def MX (_cls, name, priority, mx, **opts):
        """
                priority    : int
                mx          : str       - hostname
        """

        return _cls.build(name, 'MX', int(priority), str(mx), **opts)

    def __init__ (self, name, ttl, cls, type, data, comment=None, origin=None, line=None):
        """
            Using strict field ordering.

                name            - local label with respect to $ORIGIN
                                  may also be @ to refer to $ORIGIN
                ttl             - int TTL for record, default to implicit $TTL
                cls             - uppercase class for record, default to IN
                type            - uppercase type for record, required
                data [...]      - list of data fields, interpretation varies by type
                comment         - optional comment to include in zone
                origin          - track implicit $ORIGIN or XXX: previous record state
                line            - associated ZoneLine
        """

        self.name = name
        self.ttl = ttl
        self.cls = cls
        self.type = type
        self.data = data

        self.comment = comment
        self.origin = origin
        self.line = line

    def __unicode__ (self) :
        """
            Construct a zonefile-format line...
        """

        if self.comment :
            comment = '\t; ' + self.comment
        else :
            comment = ''

        # unset fields are written as padding, to keep the columns aligned
        return u"{name:25} {ttl:4} {cls:2} {type:5} {data}{comment}".format(
                name    = '' if self.name is None else self.name,
                ttl     = '' if self.ttl is None else self.ttl,
                cls     = '' if self.cls is None else self.cls,
                type    = self.type,
                data    = ' '.join(unicode(data) for data in self.data),
                comment = comment,
        )

    def __repr__ (self) :
        return '%s(%s)' % (self.__class__.__name__, ', '.join(repr(arg) for arg in (
            self.name, self.ttl, self.cls, self.type, self.data
        )))

class ZoneRecordSOA (ZoneRecord):
    """
        Specialized SOA record.

        The seven SOA data fields are also available as the attributes
        master, contact, serial, refresh, retry, expire and nxttl.
    """

    @classmethod
    def build (_cls, name, type,
            master, contact, serial, refresh, retry, expire, nxttl, 
            line=None,
            **opts
    ):
        """
            Build the SOA record from its seven data fields.

                name    - must be @
                type    - must be SOA, in any case
                line    - associated ZoneLine, also used as error context
                **opts  - passed on to ZoneRecord.build()

            Raises ZoneLineError if the type is not SOA, or the name is not @.
        """

        # the type is given as written in the zonefile, so compare case-insensitively
        if type.upper() != 'SOA' :
            raise ZoneLineError(line, "Not a SOA record: {type}", type=type)

        if name != '@':
            raise ZoneLineError(line, "SOA should be @: {name}", name=name)

        # line is a named parameter here, so it must be passed on explicitly to reach the record
        return super(ZoneRecordSOA, _cls).build('@', 'SOA', 
            master, contact, serial, refresh, retry, expire, nxttl,
            line=line,
            **opts
        )

    def __init__ (self, name, ttl, cls, type, data, **opts):
        """
            Same as ZoneRecord, with data being exactly the seven SOA fields in order.
        """

        super(ZoneRecordSOA, self).__init__(name, ttl, cls, type, data, **opts)

        self.master, self.contact, self.serial, self.refresh, self.retry, self.expire, self.nxttl = data

# Specialized ZoneRecord subclasses, keyed by uppercase record type.
# ZoneRecord.parse() looks the type up here, and falls back to the class it was called on.
ZONE_RECORD_TYPES = {
    'SOA':      ZoneRecordSOA,
}