#!/usr/bin/env python
"""
Process zonefiles.
"""
import codecs
import datetime
import logging
import math
import os.path
# Module-level logger shared by every class and function in this file.
log = logging.getLogger('pvl.dns.zone')
class ZoneError(Exception):
    """
    Base class for errors raised while processing zonefiles.
    """
class ZoneLineError(ZoneError):
    """
    Error tied to a specific zonefile line.

    The exception message is "<line>: <detail>", where <detail> is `msg`
    run through str.format() with the remaining positional and keyword
    arguments.
    """

    def __init__(self, line, msg, *args, **kwargs):
        # expand the caller's format string first, then prefix the line
        detail = msg.format(*args, **kwargs)
        message = "%s: %s" % (line, detail)

        super(ZoneLineError, self).__init__(message)
class ZoneLine(object):
    """
    A line parsed from a zonefile.

    Attributes:
        file        - name of the file the line came from
        lineno      - index of the line as counted by enumerate() (starts at 0)
        line        - raw text of the line (all raw lines joined for a multi-line statement)
        indent      - True if the raw line started with a space or tab
        parts       - list of whitespace-split (and quote-aware) fields
        timestamp   - datetime parsed from the optional line prefix, or None
        comment     - text after the first ';', or None
    """

    # source
    file = None
    lineno = None

    # data
    indent = None   # was the line indented?
    parts = None    # split line fields

    # optional
    timestamp = None
    comment = None

    # strptime() format used for the optional "<timestamp>: " line prefix
    PARSE_DATETIME_FORMAT = '%Y-%m-%d'

    @classmethod
    def load(cls, file, filename=None, line_timestamp_prefix=None):
        """
        Yield ZoneLines lexed from a file.

            file                    - iterable of text lines
            filename                - name used in messages; defaults to file.name
            line_timestamp_prefix   - if true, every line must start with a
                                      "<timestamp>: " prefix in PARSE_DATETIME_FORMAT

        Statements wrapped in stand-alone "(" ... ")" tokens are collected
        across lines and yielded as a single ZoneLine that carries the lineno,
        timestamp, indent and comment of the line holding the "(".

        Raises ZoneError for a missing timestamp prefix, an unbalanced quote,
        a nested "(" or a ")" without a matching "(".
        """
        if not filename:
            filename = file.name

        # state of the multi-line statement currently being collected, if any
        multiline_start = None
        multiline_line = None
        multiline_parts = None

        for lineno, raw_line in enumerate(file):
            # possible mtime prefix for line
            timestamp = None

            if line_timestamp_prefix:
                if ': ' not in raw_line:
                    raise ZoneError("%s:%d: Missing timestamp prefix: %s" % (filename, lineno, raw_line))

                # split prefix
                prefix, raw_line = raw_line.split(': ', 1)

                # parse it out
                timestamp = datetime.datetime.strptime(prefix, cls.PARSE_DATETIME_FORMAT)

                log.debug("%s:%d: ts=%r", filename, lineno, timestamp)

            log.debug("%s:%d: %s", filename, lineno, raw_line)

            # capture indent from raw line
            indent = raw_line.startswith((' ', '\t'))
            line = raw_line.strip()

            # parse comment
            if ';' in line:
                line, comment = line.split(';', 1)
                line = line.strip()
                comment = comment.strip()
            else:
                comment = None

            log.debug("%s:%d: indent=%r, line=%r, comment=%r", filename, lineno, indent, line, comment)

            # split (quoted) fields; only a single quoted field per line is handled
            if '"' in line:
                if line.count('"') < 2:
                    raise ZoneError("%s:%d: Unbalanced quote: %s" % (filename, lineno, line))

                pre, data, post = line.split('"', 2)
                parts = pre.split() + [data] + post.split()
            else:
                parts = line.split()

            # handle multi-line statements...
            if '(' in parts:
                if multiline_start is not None:
                    raise ZoneError("%s:%d: Nested multi-line statement: %s" % (filename, lineno, line))

                log.warning("%s:%d: Start of multi-line statement: %s", filename, lineno, line)

                multiline_start = (lineno, timestamp, indent, comment)
                # raw text and parts are accumulated below, including this first line
                multiline_line = ''
                multiline_parts = []

            if multiline_start is not None:
                log.warning("%s:%d: Multi-line statement: %s", filename, lineno, line)

                # the parens themselves are not part of the record
                multiline_parts.extend(part for part in parts if part not in ('(', ')'))
                multiline_line += raw_line

            if ')' in parts:
                if multiline_start is None:
                    raise ZoneError("%s:%d: End of multi-line statement without start: %s" % (filename, lineno, line))

                log.warning("%s:%d: End of multi-line statement: %s", filename, lineno, line)

                # report the statement using the metadata of its first line
                lineno, timestamp, indent, comment = multiline_start
                raw_line = multiline_line
                parts = multiline_parts

                multiline_start = multiline_line = multiline_parts = None

            # emit, unless still collecting a multi-line statement
            if multiline_start is None:
                yield ZoneLine(filename, lineno, raw_line, indent, parts, comment, timestamp=timestamp)

        if multiline_start is not None:
            # the collected lines are dropped; make that visible
            log.warning("%s:%d: Unterminated multi-line statement", filename, multiline_start[0])

    def __init__(self, file, lineno, line, indent, parts, comment=None, timestamp=None):
        """
        Store the source location, the lexed fields and the optional metadata.
        """
        # source
        self.file = file
        self.lineno = lineno
        self.line = line

        # parse data
        self.indent = indent
        self.parts = parts

        # metadata
        self.timestamp = timestamp
        self.comment = comment

    def __str__(self):
        """
        Return the "<file>:<lineno>" location used as a message prefix.
        """
        return "{file}:{lineno}".format(file=self.file, lineno=self.lineno)
class ZoneRecord(object):
    """
    A record from a zonefile.
    """

    # context
    line = None     # the underlying line
    origin = None   # possible $ORIGIN context

    # record fields
    name = None
    ttl = None      # optional
    cls = None      # optional
    type = None
    data = None     # list of data fields

    @classmethod
    def load(cls, file, ttl=None, origin=None, **opts):
        """
        Parse ZoneRecord items from the given zonefile, ignoring non-record lines.

            file    - open zonefile; file.name is used to resolve $INCLUDE paths
            ttl     - initial default TTL, replaced by $TTL directives
            origin  - initial origin, replaced by $ORIGIN directives
            **opts  - passed through to ZoneLine.load()

        $GENERATE directives are expanded via process_generate(). $INCLUDE
        files are resolved relative to the directory of `file` and loaded
        with default settings (ttl, origin and opts are not passed on).
        Other directives are skipped with a warning.
        """
        name = None

        for line in ZoneLine.load(file, **opts):
            if not line.parts:
                log.debug("%s: skip empty line", line)

            elif line.line.startswith('$'):
                # control record
                args = list(line.parts)
                directive = args.pop(0)[1:]

                if directive == 'ORIGIN':
                    # update
                    origin, = args

                    log.info("%s: origin: %s", line, origin)

                elif directive == 'TTL':
                    ttl, = args

                    log.info("%s: ttl: %s", line, ttl)

                elif directive == 'GENERATE':
                    # process...
                    log.info("%s: generate: %s", line, args)

                    for record in process_generate(line, origin, args):
                        yield record

                elif directive == 'INCLUDE':
                    include, = args

                    # XXX: this is probably not what we want...
                    path = os.path.join(os.path.dirname(file.name), include)

                    log.info("%s: include: %s: %s", line, include, path)

                    # close the included file once its records are consumed
                    with open(path) as include_file:
                        for record in cls.load(include_file):
                            yield record

                else:
                    log.warning("%s: skip control record: %s", line, line.line)

            else:
                # normal record?
                record = ZoneRecord.parse(line,
                    name=name,
                    origin=origin,
                    ttl=ttl,
                    comment=line.comment,
                )

                if record:
                    yield record

                    # keep name across lines
                    name = record.name

                else:
                    # unknown
                    log.warning("%s: skip unknown line: %s", line, line.line)

    @classmethod
    def parse(cls, line, name=None, parts=None, ttl=None, **opts):
        """
        Build a ZoneRecord from a ZoneLine.

            line    - the ZoneLine being parsed
            name    - default for name, if continuing previous line
            parts   - fields to parse instead of line.parts (consumed in place)
            ttl     - default ttl, used when the line does not give one
            **opts  - passed through to build()

        Return the record created by cls.build(), or None if there are no
        parts to parse.

        Raises ZoneLineError if the line has too few fields.
        """
        if parts is None:
            parts = list(line.parts)

        if not parts:
            # skip
            return None

        if not line.indent:
            # indented lines keep name from previous record
            name = parts.pop(0)

        if len(parts) < 2:
            raise ZoneLineError(line, "Too few parts to parse: {0!r}", line.parts)

        # parse ttl/cls/type
        _cls = None

        # a leading digit marks the optional ttl field
        if parts and parts[0][:1].isdigit():
            ttl = parts.pop(0)

        if parts and parts[0].upper() in ('IN', 'CH'):
            _cls = parts.pop(0)

        if not parts:
            raise ZoneLineError(line, "Missing record type: {0!r}", line.parts)

        # always have type
        rtype = parts.pop(0)

        # remaining parts are data
        data = parts

        log.debug("  ttl=%r, cls=%r, type=%r, data=%r", ttl, _cls, rtype, data)

        return cls.build(line, name, ttl, _cls, rtype, data, **opts)

    @classmethod
    def build(_cls, line, name, ttl, cls, type, data, **opts):
        """
        Construct the record from parsed fields; hook for subclasses with a
        different constructor signature.
        """
        return _cls(name, type, data,
            ttl=ttl,
            cls=cls,
            line=line,
            **opts
        )

    @classmethod
    def A(cls, name, ip, **opts):
        """Build an A record for the given name and IP address."""
        return cls(str(name), 'A', [str(ip)], **opts)

    @classmethod
    def CNAME(cls, name, host, **opts):
        """Build a CNAME record pointing name at host."""
        return cls(str(name), 'CNAME', [str(host)], **opts)

    @classmethod
    def TXT(cls, name, text, **opts):
        """Build a TXT record; the text is double-quoted and embedded quotes are backslash-escaped."""
        return cls(str(name), 'TXT',
            [u'"{0}"'.format(text.replace('"', '\\"'))],
            **opts
        )

    @classmethod
    def PTR(cls, name, ptr, **opts):
        """Build a PTR record pointing name at ptr."""
        return cls(str(name), 'PTR', [str(ptr)], **opts)

    @classmethod
    def MX(cls, name, priority, mx, **opts):
        """Build an MX record; priority is converted to int."""
        return cls(str(name), 'MX', [int(priority), str(mx)], **opts)

    def __init__(self, name, type, data, origin=None, ttl=None, cls=None, line=None, comment=None):
        """
        Store the record fields (name, type, data list, optional ttl/cls) and
        its context (origin, source line, comment).
        """
        self.name = name
        self.type = type
        self.data = data

        self.ttl = ttl
        self.cls = cls

        self.origin = origin
        self.line = line
        self.comment = comment

    def __unicode__(self):
        """
        Construct a zonefile-format line, with the comment (if any) appended
        after a tab as "; <comment>".
        """
        if self.comment:
            comment = '\t; ' + self.comment
        else:
            comment = ''

        return u"{name:25} {ttl:4} {cls:2} {type:5} {data}{comment}".format(
            name=self.name or '',
            ttl=self.ttl or '',
            cls=self.cls or '',
            type=self.type,
            # u'{0}'.format() converts each field to text without relying on
            # the Python 2-only unicode() builtin
            data=u' '.join(u'{0}'.format(item) for item in self.data),
            comment=comment,
        )

    def __repr__(self):
        return '%s(%s)' % (self.__class__.__name__, ', '.join(repr(arg) for arg in (
            self.name, self.type, self.data
        )))
class SOA(ZoneRecord):
    """
    ZoneRecord for the SOA type, exposing its seven data fields as named
    attributes in addition to the generic `data` list.
    """

    @classmethod
    def build(_cls, line, name, ttl, cls, type, data, **opts):
        """
        Construct the SOA from parsed fields.

        The data fields are spread positionally into the constructor; `name`
        must be '@' and `type` is not used.
        """
        assert name == '@'

        return _cls(*data,
            ttl=ttl,
            cls=cls,
            line=line,
            **opts
        )

    def __init__(self, master, contact, serial, refresh, retry, expire, nxttl, **opts):
        """
        Create the '@' SOA record; **opts go to ZoneRecord.__init__().
        """
        fields = [master, contact, serial, refresh, retry, expire, nxttl]

        super(SOA, self).__init__('@', 'SOA', fields, **opts)

        # named access to the individual data fields
        self.master = master
        self.contact = contact
        self.serial = serial
        self.refresh = refresh
        self.retry = retry
        self.expire = expire
        self.nxttl = nxttl
class OffsetValue(object):
    """
    Wrapper around a base value whose item lookup adds the index to it.

    Lets a format string such as '{value[5]}' render "base + 5" for
    $GENERATE offsets.

        >>> OffsetValue(0)[1]
        1
        >>> OffsetValue(10)[5]
        15
    """

    def __init__(self, base):
        # the value that offsets are added to
        self.base = base

    def __getitem__(self, offset):
        shifted = self.base + offset

        return shifted
def parse_generate_field(field, line=None):
    r"""
    Parse a $GENERATE lhs/rhs field:
        $                               the iteration value
        ${<offset>[,<width>[,<base>]]}  the value plus offset, zero-padded to
                                        width, in the given format base
        \$                              a literal $
        $$                              a literal $

    Markers are processed left to right; all other text (including braces)
    is copied through unchanged.

        field   - the field text
        line    - source line, only used in error messages

    Returns a wrapper that builds the field-value when called with the index.

    Raises ZoneLineError for a malformed ${...} body.

        >>> parse_generate_field("foo")(1)
        'foo'
        >>> parse_generate_field("foo-$")(1)
        'foo-1'
        >>> parse_generate_field("foo-$$")(1)
        'foo-$'
        >>> parse_generate_field("\$")(1)
        '$'
        >>> parse_generate_field("10.0.0.${100}")(1)
        '10.0.0.101'
        >>> parse_generate_field("foo-${0,2,d}")(1)
        'foo-01'
        >>> parse_generate_field("foo-${-1}")(5)
        'foo-4'
        >>> parse_generate_field("a$$-{b}-$")(3)
        'a$-{b}-3'
    """
    source = field

    # literal strings, and (offset, format-spec) tuples for substitutions
    pieces = []

    while '$' in field:
        idx = field.index('$')

        # escaped forms yield a literal $
        if idx > 0 and field[idx - 1] == '\\':
            pieces.append(field[:idx - 1] + '$')
            field = field[idx + 1:]
            continue

        if field.startswith('$$', idx):
            pieces.append(field[:idx] + '$')
            field = field[idx + 2:]
            continue

        # defaults
        offset = 0
        width = 0
        base = 'd'

        pre = field[:idx]

        if field.startswith('${', idx):
            body, closing, post = field[idx + 2:].partition('}')

            if not closing:
                # braces are doubled because ZoneLineError runs msg through str.format()
                raise ZoneLineError(line, "unterminated ${{...}} in field: {0!r}", source)

            # parse body
            parts = body.split(',')

            try:
                # offset
                offset = int(parts.pop(0))

                # width
                if parts:
                    width = int(parts.pop(0))

            except ValueError:
                raise ZoneLineError(line, "invalid number in ${{...}} body: {0!r}", body)

            # base
            if parts:
                base = parts.pop(0)

            if parts:
                # fail
                raise ZoneLineError(line, "extra data in ${{...}} body: {0!r}", parts)

        else:
            post = field[idx + 1:]

        # zero-padded format spec for format()
        spec = '0%d%s' % (width, base)

        log.debug("field=%r -> pre=%r, fmt=%r, post=%r", field, pre, spec, post)

        if pre:
            pieces.append(pre)

        pieces.append((offset, spec))

        field = post

    # final
    if field:
        pieces.append(field)

    log.debug("%s: %s", source, pieces)

    # processed
    def value_func(value):
        # the offset is added directly, so negative offsets work as well
        return ''.join(
            format(value + piece[0], piece[1]) if isinstance(piece, tuple) else piece
            for piece in pieces
        )

    return value_func
def parse_generate_range(field):
    """
    Parse a <start>-<stop>[/<step>] field.

    Returns range(start, stop + 1, step), i.e. the stop value is included;
    step is 1 when the field has no "/<step>" part.
    """
    step = 1

    if '/' in field:
        field, step_text = field.split('/')
        step = int(step_text)

    first, last = field.split('-')
    start = int(first)
    stop = int(last)

    log.debug("  range: start=%r, stop=%r, step=%r", start, stop, step)

    # inclusive
    upper = stop + 1

    return range(start, upper, step)
def process_generate(line, origin, parts):
    """
    Expand a

        $GENERATE <start>-<stop>[/<step>] lhs [ttl] [class] type rhs [comment]

    directive into a series of records.

        line    - the ZoneLine holding the directive
        origin  - origin passed on to every generated record
        parts   - directive arguments after "$GENERATE"; the range, lhs and
                  rhs entries are popped off this list

    Yields the result of ZoneRecord.parse() for every index in the range.
    """
    indexes = parse_generate_range(parts.pop(0))

    make_lhs = parse_generate_field(parts.pop(0), line=line)
    make_rhs = parse_generate_field(parts.pop(), line=line)

    # whatever is left sits between lhs and rhs: [ttl] [class] type
    middle = parts

    for index in indexes:
        # build a fresh list per record
        fields = [make_lhs(index)]
        fields.extend(middle)
        fields.append(make_rhs(index))

        log.debug(" %03d: %r", index, fields)

        # parse
        yield ZoneRecord.parse(line, parts=fields, origin=origin)
def reverse_ipv4(ip):
    """
    Return in-addr.arpa reverse for given IPv4 prefix.

    `ip` is a dotted string of decimal octets; the octets are emitted in
    reverse order, followed by the "in-addr.arpa" labels.
    """
    # parse
    octets = [int(part) for part in ip.split('.')]

    assert all(0 <= octet <= 255 for octet in octets)

    labels = [str(octet) for octet in octets[::-1]]
    labels.extend(['in-addr', 'arpa'])

    return '.'.join(labels)
def reverse_ipv6(ip6):
    """
    Return ip6.arpa reverse for given IPv6 prefix.

    Each ':'-separated group is parsed as hex and zero-padded to four
    digits; the resulting digits are emitted in reverse order, one per
    label, followed by the "ip6.arpa" labels.
    """
    digits = ''.join(
        '{0:04x}'.format(int(group, 16))
        for group in ip6.split(':')
    )

    labels = list(reversed(digits))
    labels.extend(['ip6', 'arpa'])

    return '.'.join(labels)
def fqdn(*parts):
    """
    Join the given parts with dots into a name that ends with a dot.
    """
    name = '.'.join(map(str, parts))

    # we may be given an fqdn in parts
    if name.endswith('.'):
        return name

    return name + '.'
def reverse_label(prefix, address):
    """
    Determine the correct label for the given IP address within the reverse zone for the given prefix.

    This includes all suffix octets (partially) covered by the prefix.

        prefix  - network object providing .version, .prefixlen and .max_prefixlen
        address - address object providing .version and .packed

    Returns the covered octets of the address in reverse order, joined by
    dots; the label is empty when the prefix leaves no host bits.

    Raises ValueError for anything but IPv4.
    """
    assert prefix.version == address.version

    hostbits = prefix.max_prefixlen - prefix.prefixlen

    if prefix.version == 4:
        # number of octets at least partially made up of host bits
        labelbytes = (hostbits + 7) // 8

        # bytearray yields integers on both Python 2 and Python 3
        octets = bytearray(address.packed)

        # slice from an explicit start: octets[-0:] would be the whole address
        covered = octets[len(octets) - labelbytes:]

        return '.'.join(str(octet) for octet in reversed(covered))
    else:
        raise ValueError("unsupported address version: %s" % (prefix, ))