urltree.py
author Tero Marttila <terom@fixme.fi>
Mon, 16 Feb 2009 20:02:28 +0200
changeset 80 94c493b7c046
parent 79 747554808944
child 81 847da3c265b5
permissions -rw-r--r--
start writing urltree tests
"""
    Tree-based URL mapping
"""

import re
import os.path

from qmsk.web import handler

class URLError (Exception) :
    """
        Error with an URL definition (e.g. an invalid label mask), or with matching/handling a request URL against
        the defined URLs (no match, ambiguous match, missing or unrecognized arguments).
    """

class LabelValue (object) :
    """
        The value that a ValueLabel matched (or defaulted to) for one request URL... love these names
    """

    def __init__ (self, label, value, is_default) :
        """
            label       - the Label that produced this value; its .key names the value
            value       - the (parsed) value itself
            is_default  - True if this is the label's default value rather than one taken from the URL
        """

        self.label, self.value, self.is_default = label, value, is_default
    
    def __str__ (self) :
        # defaulted values are shown by key only, explicit ones as key=repr(value)
        if self.is_default :
            return "%s" % (self.label.key, )

        return "%s=%r" % (self.label.key, self.value)

    def __repr__ (self) :
        return "<%s>" % (self, )

class Label (object) :
    """
        Base class for URL labels (i.e. the segments of the URL between /s).

        Subclasses implement match(), build() and build_default(); the base versions raise NotImplementedError.
    """

    @staticmethod
    def parse (mask, defaults, config) :
        """
            Parse the given label-segment, and return a *Label instance.

            mask        - the text of a single URL segment: '' (EmptyLabel), a literal name (StaticLabel), or
                          '{key[:type][=default]}' (SimpleValueLabel)
            defaults    - dict of default values by key; an entry here (other than None) takes precedence over a
                          default given in the mask itself, and is used as-is (not parsed by the type)
            config      - the URLConfig to use for type lookups

            Raises URLError if the mask is not a valid label.
        """

        # empty?
        if not mask :
            return EmptyLabel()

        # simple value?
        match = SimpleValueLabel.EXPR.match(mask)

        if match :
            # key
            key = match.group('key')

            # type
            type_name = match.group("type")
            
            # lookup type, None -> default
            type = config.get_type(type_name)

            # defaults? Falsy values such as 0 or '' are valid explicit defaults, so only fall back to the mask's
            # default when there is no entry at all
            default = defaults.get(key)

            if default is None :
                default = match.group('default')

                if default :
                    # apply type to default
                    default = type.parse(default)

            # build
            return SimpleValueLabel(key, type_name, type, default)
        
        # static?
        match = StaticLabel.EXPR.match(mask)

        if match :
            return StaticLabel(match.group('name'))

        # invalid
        raise URLError("Invalid label: %r" % (mask, ))
    
    def match (self, value=None) :
        """
            Match this label against the given value, returning either True to match without a value, a LabelValue
            object, or boolean false to not match.

            If value is None, this means that only a default value should be returned.
        """

        raise NotImplementedError()
    
    def build (self, value_dict) :
        """
            Return a string representing this label, using the values in the given value_dict if needed
        """

        raise NotImplementedError()

    def build_default (self, value_dict) :
        """
            Return an (is_default, value) tuple
        """

        raise NotImplementedError()


class EmptyLabel (Label) :
    """
        An empty label, i.e. just a slash in the URL (an empty segment; also used as the label of the tree root)
    """
    
    def __eq__ (self, other) :
        """
            Just compares type: all EmptyLabels are equal
        """

        return isinstance(other, EmptyLabel)

    def __ne__ (self, other) :
        """
            Inverse of __eq__; Python 2 does not derive != from __eq__ automatically
        """

        return not self.__eq__(other)
    
    def match (self, value=None) :
        """
            Match empty string -> True, anything else -> False.

            There is no default value, so a None value never matches.
        """
        
        # no default
        if value is None :
            return False
        
        # only empty segments
        return value == ''
    
    def build (self, values) :
        return ''

    def build_default (self, values) :
        return (False, '')

    def __str__ (self) :
        return ''

class StaticLabel (Label) :
    """
        A simple literal Label, used for fixed terms in the URL
    """

    # names consist of letters, digits, '_', '.' and '-'
    EXPR = re.compile(r'^(?P<name>[a-zA-Z0-9_.-]+)$')

    def __init__ (self, name) :
        """
            The given name is the literal name of this label
        """

        self.name = name

    def __eq__ (self, other) :
        """
            Compares names
        """

        return isinstance(other, StaticLabel) and self.name == other.name

    def __ne__ (self, other) :
        """
            Inverse of __eq__; Python 2 does not derive != from __eq__ automatically
        """

        return not self.__eq__(other)
    
    def match (self, value=None) :
        """
            Match exactly -> True, anything else -> False.

            There is no default value, so a None value never matches.
        """

        # no defaults
        if value is None :
            return False
        
        # match name
        return value == self.name

    def build (self, values) :
        return self.name

    def build_default (self, values) :
        return (False, self.name)

    def __str__ (self) :
        return self.name

class ValueLabel (Label) :
    """
        A label with a key and a value.

        This class does not set self.type itself; subclasses must provide it (an URLType, used by build_default to
        convert values back to strings), and must implement match().

        XXX: do we even need this?
    """

    def __init__ (self, key, default) :
        """
            Set the key and default value. Default value may be None if there is no default value defined
        """

        self.key = key
        self.default = default

    def __eq__ (self, other) :
        """
            Compares keys
        """

        return isinstance(other, ValueLabel) and self.key == other.key

    def __ne__ (self, other) :
        """
            Inverse of __eq__; Python 2 does not derive != from __eq__ automatically
        """

        return not self.__eq__(other)
    
    def build (self, values) :
        """
            Return either the assigned value from values, our default value, or raise an error
        """
        
        # just proxy to build_default
        return self.build_default(values)[1]

    def build_default (self, values) :
        """
            Look up our value in values, and convert it to an url-value string using self.type.build.

            Returns an (is_default, value) tuple, where is_default is True if our default value was used because
            values has no entry (or a None entry) for our key.

            Raises URLError if there is no value and no default.
        """
 
        # state
        is_default = False
        
        # value given?
        if self.key not in values or values[self.key] is None :
            # error on missing non-default
            if self.default is None :
                raise URLError("No value given for label %r" % (self.key, ))
            
            # use default
            else :
                is_default = True
                value = self.default
        
        else :
            # lookup the value obj to use
            value = values[self.key]
        
        # convert value back to str
        value = self.type.build(value)
        
        # return
        return (is_default, value)

class SimpleValueLabel (ValueLabel) :
    """
        A label that captures a single URL segment as a named value, converted using an URLType.

        Mask syntax: '{key}', '{key:type}', '{key=default}' or '{key:type=default}'.
    """

    EXPR = re.compile(r'^\{(?P<key>[a-zA-Z_][a-zA-Z0-9_]*)(:(?P<type>[a-zA-Z_][a-zA-Z0-9_]*))?(=(?P<default>[^}]*))?\}$')

    def __init__ (self, key, type_name, type, default) :
        """
            key         - the name of this label's value
            type_name   - None for the default type, otherwise the type's name as written in the mask
            type        - the URLType used to test, parse and build values
            default     - the default value, or None if there is none
        """

        # key + default, as for any ValueLabel
        ValueLabel.__init__(self, key, default)

        # the type, and the name it was given by
        self.type_name = type_name
        self.type = type
        
    def match (self, value=None) :
        """
            Match -> LabelValue.

            A None value matches only if there is a default; an empty value never matches; any other value must
            pass the type's test(), in which case it is converted with the type's parse().
        """

        if value is None :
            if self.default is not None :
                return LabelValue(self, self.default, True)

            return None

        # only non-empty values!
        if not value :
            return None

        if not self.type.test(value) :
            return False

        return LabelValue(self, self.type.parse(value), False)

    def __str__ (self) :
        parts = ['%s' % (self.key, )]

        if self.type_name is not None :
            parts.append(':%s' % (self.type_name, ))

        if self.default is not None :
            parts.append('=%s' % (self.default, ))

        return '{%s}' % (''.join(parts), )

class URLType (object) :
    """
        Handles the type-ness of values in the URL.

        Subclasses implement parse() and build(); the base versions raise NotImplementedError.
    """

    def test (self, value) :
        """
            Tests if the given value is valid for this type.

            Defaults to calling parse(), and returning False if it raises an exception, True otherwise.
        """
        
        try :
            self.parse(value)

        # only ordinary errors mean "invalid"; do not swallow KeyboardInterrupt/SystemExit
        except Exception :
            return False

        else :
            return True
    
    def parse (self, value) :
        """
            Parse the given value, which was tested earlier with test(), and return the value object
        """

        raise NotImplementedError()
    
    def append (self, old_value, value) :
        """
            Handle multiple values for this type, by combining the given old value and new value (both from parse).

            Defaults to raise an error
        """

        raise URLError("Multiple values for argument")
   
    def build (self, obj) :
        """
            Reverse of parse(), return an url-value built from the given object
        """

        raise NotImplementedError()

    def build_multi (self, obj) :
        """
            Return a list of string values for the given object value (as from parse/append).

            Defaults to return [self.build(obj)]
        """

        return [self.build(obj)]

class URLStringType (URLType) :
    """
        The default URLType: plain strings, used as-is.
        
        XXX: decoding here, or elsewhere?
    """

    def parse (self, value) :
        """
            Identity: the url-value is itself the value object
        """

        parsed = value

        return parsed

    def build (self, obj) :
        """
            Convert the value object back to a string with str()
        """

        return str(obj)

class URLIntegerType (URLType) :
    """
        A URLType for simple integers
    """

    def __init__ (self, allow_negative=True, allow_zero=True, max=None) :
        """
            Pass in allow_negative=False to disallow negative numbers, allow_zero=False to disallow zero, or a max
            other than None to specify an (inclusive) maximum value
        """

        self.allow_negative = allow_negative
        self.allow_zero = allow_zero
        self.max = max
    
    def _validate (self, value) :
        """
            Test to make sure value fits our criteria, returning it unchanged.

            Raises ValueError if it does not.
        """

        # negative?
        if not self.allow_negative and value < 0 :
            raise ValueError("value is negative")
        
        # zero?
        if not self.allow_zero and value == 0 :
            raise ValueError("value is zero")
        
        # max? compare against the configured limit, not the max() builtin
        if self.max is not None and value > self.max :
            raise ValueError("value is too large: %d" % value)
        
        return value

    def parse (self, value) :
        """
            Convert str -> int, validating the result
        """

        return self._validate(int(value))
    
    def build (self, obj) :
        """
            Convert int -> str, validating the value first
        """

        # NOTE(review): unicode() is Python 2 only, as is the rest of this module
        return unicode(self._validate(obj))
    
class URLListType (URLType) :
    """
        A list of strings: every occurrence of the value is collected into one list
    """

    def parse (self, value) :
        # wrap each raw string in a one-item list, so that append() can simply concatenate
        wrapped = [value]
        return wrapped
    
    def append (self, old_value, value) :
        # both arguments come from parse()/append(), i.e. are lists
        combined = old_value + value
        return combined

    def build_multi (self, obj) :
        # the list already holds one url-value per item
        return obj

class URLConfig (object) :
    """
        Global configuration relevant to all URLs. This can be used to construct a set of URLs and then create an
        URLTree out of them. Simply call the url_config() instance with the normal URL arguments (except, of course,
        config), and finally just pass the url_config to URLTree (it's iterable).

        XXX: rename to URLFactory?
    """

    # built-in type codes; these instances are shared by every URLConfig
    BUILTIN_TYPES = {
        # default - string
        None    : URLStringType(),

        # string
        'str'   : URLStringType(),

        # integer
        'int'   : URLIntegerType(),

        # list of strs
        'list'  : URLListType(),
    }

    def __init__ (self, type_dict=None, ignore_extra_args=True) :
        """
            Create an URLConfig for use with URL

            If type_dict is given, it should be a dict of { type_name: URLType }. These types are available for type
            specifications in addition to the built-in ones, and replace any built-in type of the same name. The
            given dict itself is not modified.

            If ignore_extra_args is true (the default), unrecognized query arguments are ignored; otherwise they
            cause an URLError when a request is executed.
        """

        # build our type_dict, starting from a copy of the built-ins
        self.type_dict = self.BUILTIN_TYPES.copy()
        
        # apply the given type_dict
        if type_dict :
            # merge
            self.type_dict.update(type_dict)

        # init
        self.ignore_extra_args = ignore_extra_args
        self.urls = []
        
    def get_type (self, type_name=None) :
        """
            Lookup an URLType by type_name, None for default.

            Raises KeyError for an unknown type_name.
        """
        
        # lookup + return
        return self.type_dict[type_name]

    def __call__ (self, *args, **kwargs) :
        """
            Return new URL object with this config and the given args, adding it to our list of urls
        """
        
        # build
        url = URL(self, *args, **kwargs)
        
        # store
        self.urls.append(url)

        # return
        return url
    
    def __iter__ (self) :
        """
            Returns an iterator over all URLs defined via this config, in definition order
        """

        return iter(self.urls)

class URL (object) :
    """
        Represents a specific URL: a path of labels plus optional query arguments, bound to a handler
    """

    def __init__ (self, config, url_mask, handler, **defaults) :
        """
            Create an URL using the given URLConfig, with the given url mask, handler, and default values.

            The url_mask is a '/'-separated list of label masks (see Label.parse), optionally followed by '/?' and a
            '&'-separated list of query argument specs of the form 'key[:type][=default]'.

            The defaults are used as label defaults, and are also passed to the handler as keyword arguments.
        """

        # store
        self.config = config
        self.url_mask = url_mask
        self.handler = handler
        self.defaults = defaults

        # query string: { key -> (URLType, default) }
        self.query_args = dict()

        # remove prepending root /
        url_mask = url_mask.lstrip('/')
        
        # parse any query string
        # XXX: conflicts with regexp syntax
        if '/?' in url_mask :
            url_mask, query_mask = url_mask.split('/?')
        
        else :
            query_mask = None

        # build our label path
        self.label_path = [Label.parse(mask, defaults, config) for mask in url_mask.split('/')]

        # build our query args list
        if query_mask :
            # split into items
            for query_item in query_mask.split('&') :
                # parse default; split only on the first '=' so that the default itself may contain '='
                if '=' in query_item :
                    query_item, default = query_item.split('=', 1)

                else :
                    default = None
                
                # parse type
                if ':' in query_item :
                    query_item, type_name = query_item.split(':')

                else :
                    type_name = None
                
                # parse key
                key = query_item

                # type
                type = self.config.get_type(type_name)

                # add to query_args as (type, default) tuple; an empty default ('key=') stays '', meaning "optional,
                # do not pass at all if missing"
                self.query_args[key] = (type, type.parse(default) if default else default)
         
    def get_label_path (self) :
        """
            Returns a list containing the labels in this url (a copy, which the caller may consume)
        """
        
        # copy self.label_path
        return list(self.label_path)

    def execute (self, request, label_values) :
        """
            Invoke the handler, using the given label values and the request's query arguments.

            The handler is called as handler(request, **kwargs), with kwargs built from, in order:
                - the defaults given to this URL
                - the values of the matched labels
                - the request's query arguments (request.get_args()), which may also override label values that
                  were left at their defaults
                - the defaults of declared query arguments that were not given

            A query argument given with an empty value uses its declared default; if that default is '' the key is
            not passed at all, and if there is no default, URLError is raised.

            Raises URLError for missing required query arguments, and for unrecognized query arguments unless
            config.ignore_extra_args is set.
        """
        
        # start with the defaults
        kwargs = self.defaults.copy()

        # ...dict of those label values which are set to defaults
        default_labels = {}

        # then add all the values
        for label_value in label_values :
            key = label_value.label.key

            kwargs[key] = label_value.value
            
            # add key to default_labels?
            if label_value.is_default :
                default_labels[key] = label_value.label
       
        # then parse all query args
        for key, value in request.get_args() :
            # lookup in our defined query args
            if key in self.query_args :
                # lookup spec
                type, default = self.query_args[key]
            
            # override URL params if they were not given
            elif key in default_labels :
                type, default = default_labels[key].type, None
            
            # be strict about extraneous query args?
            elif not self.config.ignore_extra_args :
                raise URLError("Unrecognized query argument: %r" % (key, ))
            
            # ignore
            else :
                continue

            # normalize empty value to None
            if not value :
                value = None

            else :
                # parse value
                value = type.parse(value)

            # empty value: fall back to the default, skip the key for an empty default, or fail
            if value is None :
                if default is None :
                    raise URLError("No value given for required argument: %r" % (key, ))

                elif default == '' :
                    # do not pass key at all
                    continue

                else :
                    value = default
            
            # already have a non-default value?
            if key in kwargs and key not in default_labels :
                # append to old value
                kwargs[key] = type.append(kwargs[key], value)

            else :
                # set key
                kwargs[key] = value
        
        # then check all query args
        for key, (type, default) in self.query_args.items() :
            # skip those already present
            if key in kwargs :
                continue

            # apply default?
            if default is None :
                raise URLError("Missing required argument: %r" % (key, ))
            
            elif default == '' :
                # skip empty default
                continue

            else :
                # set default
                kwargs[key] = default

        # execute the handler
        return self.handler(request, **kwargs)
    
    def build (self, request, **values) :
        """
            Build an absolute URL pointing to this target, with the given values. Default values are left off if they
            are at the end of the URL.

            Values given as None are ignored.

            NOTE(review): query argument values are inserted as-is, without URL-encoding.
        """
        
        # collect segments as a list of (is_default, segment) values
        segments = [(False, request.page_prefix)] + [label.build_default(values) for label in self.label_path]
        
        # trim default items off the end; the page_prefix segment is never a default, so this stops there
        while segments and segments[-1][0] :
            segments.pop()

        assert segments
        
        # join
        url = '/'.join(segment for is_default, segment in segments if segment is not None)
        
        # build query args as { key -> [value] }
        query_args = dict(
            (key, type.build_multi(values[key]))
            for key, (type, default) in self.query_args.items()
            if values.get(key) is not None
        )

        if not query_args :
            return url

        query_string = '&'.join(
            '%s=%s' % (key, value) for key, key_values in query_args.items() for value in key_values
        )

        return "%s?%s" % (url, query_string)

    def __str__ (self) :
        return '/'.join(str(label) for label in self.label_path)
    
    def __repr__ (self) :
        return "URL(%r, %r)" % (str(self), self.handler)

class URLNode (object) :
    """
        Represents a node in the URLTree: a Label, an optional URL, and a list of child nodes
    """

    def __init__ (self, parent, label) :
        """
            Initialize with the given parent and label, no children and no URL
        """
        
        # the parent URLNode (None for the root)
        self.parent = parent

        # this node's Label
        self.label = label

        # list of child URLNodes
        self.children = []

        # this node's URL, set by add_url for an empty label_path
        self.url = None

    def _build_child (self, label) :
        """
            Build, insert and return a new child Node
        """
        
        child = URLNode(self, label)
        
        self.children.append(child)

        return child

    def add_url (self, url, label_path) :
        """
            Add a URL object to this node under the given path. Uses recursion to process the path.

            The label_path argument is a (partial) label path as returned by URL.get_label_path. It is consumed:
            labels are popped off the front of the list.

            If label_path is empty (len zero, or begins with EmptyLabel), then the given url is assigned to this node, if no
            url was assigned before; otherwise URLError is raised.
        """
        
        # matches this node?
        if not label_path or isinstance(label_path[0], EmptyLabel) :
            # NOTE(review): any labels after an EmptyLabel (e.g. from a mask like 'foo//bar') are ignored here
            if self.url :
                raise URLError(url, "node already defined")

            else :
                self.url = url

        else :
            # pop child label from label_path
            child_label = label_path.pop(0)

            # look for an existing child with an equal label...
            for child in self.children :
                if child.label == child_label :
                    break

            # ...or build a new one
            else :
                child = self._build_child(child_label)

            # recurse to handle the rest of the label_path
            child.add_url(url, label_path)
    
    def match (self, label_path) :
        """
            Locate the URL object corresponding to the given label_path value under this node.

            The label_path is a list of URL segment strings; it is consumed (segments are popped off the front).

            Returns a (url, label_values) tuple, where label_values lists the LabelValues of the matched labels,
            deepest first.

            Raises URLError if no child matches, if more than one child matches, or if the path ends at a node
            that has neither an URL nor children.
        """

        # empty label_path?
        if not label_path or label_path[0] == '' :
            # the search ends at this node
            if self.url :
                # this URL is the best match
                return (self.url, [])
            
            elif not self.children :
                # incomplete URL
                raise URLError("no URL handler defined for this Node")
            
            else :
                # use default value, i.e. Label.match(None)
                label = None

        else :
            # pop the next label from the label path
            label = label_path.pop(0)

        # the single matching child, and the value its label's match() returned
        match = match_value = None

        # check all of our children, to make sure there's no ambiguous URLs
        for child in self.children :
            value = child.label.match(label)

            # skip those that don't match at all
            if not value :
                continue
            
            # already found a match? :/
            if match is not None :
                raise URLError("Ambiguous URL")

            # remember the match together with ITS value - the loop variable is overwritten by later children
            match, match_value = child, value
        
        # found something?
        if match is None :
            raise URLError("No child found for label: %s + %s + %s" % (self.get_url(), label, '/'.join(str(l) for l in label_path)))

        # ok, recurse into the match
        url, label_values = match.match(label_path)

        # add our value?
        if isinstance(match_value, LabelValue) :
            label_values.append(match_value)

        # return the match
        return url, label_values

    def get_url (self) :
        """
            Returns the URL for this node, by iterating over our parents
        """
        
        # URL segments in reverse order
        segments = ['']
        
        # start with ourself, iterate up to root
        node = self
        
        while node is not None :
            segments.append(str(node.label))

            node = node.parent

        segments.reverse()

        return '/'.join(segments)

    def dump (self, indent=0) :
        """
            Returns a multi-line string representation of this Node and its children
        """

        return '\n'.join([
            "%-45s%s" % (
                ' '*indent + str(self.label) + ('/' if self.children else ''), 
                (' -> %r' % self.url) if self.url else ''
            )
        ] + [
            child.dump(indent + 4) for child in self.children
        ])

    def __str__ (self) :
        return "%s/[%s]" % (self.label, ','.join(str(child) for child in self.children))

class URLTree (handler.RequestHandler) :
    """
        Map requests to handlers, using a defined tree of URLs
    """

    def __init__ (self, url_list) :
        """
            Initialize the tree using the given list (or other iterable, such as an URLConfig) of URLs
        """

        # root node
        self.root = URLNode(None, EmptyLabel())
        
        # just add each URL
        for url in url_list :
            self.add_url(url)

    def add_url (self, url) :
        """
            Adds the given URL to the tree. The URL must begin with a root slash.
        """

        # add the url's (copied) label path to root
        self.root.add_url(url, url.get_label_path())
        
    def match (self, url) :
        """
            Find the URL object best corresponding to the given url, matching any ValueLabels.

            The url is a page name relative to the root, so it must not begin with a '/'.

            Returns an (URL, [LabelValue]) tuple. Raises URLError if the url begins with a '/', or if it does not
            match exactly one defined URL.

            XXX: handle unicode on URLs
        """

        # split it into labels
        path = url.split('/')
        
        # a leading '/' gives an empty first segment, which the root's EmptyLabel would match; reject it explicitly
        # rather than with assert, which is stripped under python -O
        if url and self.root.label.match(path[0]) :
            raise URLError("URL must not begin with root: %r" % (url, ))

        # just match starting at root
        return self.root.match(path)

    def handle_request (self, request) :
        """
            Looks up the request's URL, and invokes its handler
        """
        
        # find the URL+values to use for the requested page
        url, label_values = self.match(request.get_page_name())

        # let the URL handle it
        return url.execute(request, label_values)