urltree.py
changeset 51 a1da82870a6b
parent 50 e4fbf480fbee
child 52 2136cdc95b86
--- a/urltree.py	Mon Feb 09 04:38:23 2009 +0200
+++ b/urltree.py	Mon Feb 09 05:44:12 2009 +0200
@@ -19,16 +19,17 @@
         Represents the value of a ValueLabel... love these names
     """
 
-    def __init__ (self, label, value) :
+    def __init__ (self, label, value, is_default) :
         """
             Just store
         """
 
         self.label = label
         self.value = value
+        self.is_default = is_default
     
     def __str__ (self) :
-        return "%s=%r" % (self.label.key, self.value)
+        return "%s%s" % (self.label.key, "=%r" % (self.value, ) if not self.is_default else '')
 
     def __repr__ (self) :
         return "<%s>" % self
@@ -100,6 +101,14 @@
 
         abstract
 
+    def build_default (self, value_dict) :
+        """
+            Return an (is_default, value) tuple
+        """
+
+        abstract
+
+
 class EmptyLabel (Label) :
     """
         An empty label, i.e. just a slash in the URL
@@ -126,7 +135,10 @@
             return True
     
     def build (self, values) :
-        return str(self)
+        return ''
+
+    def build_default (self, values) :
+        return (False, '')
 
     def __str__ (self) :
         return ''
@@ -166,7 +178,10 @@
             return True
 
     def build (self, values) :
-        return str(self)
+        return self.name
+
+    def build_default (self, values) :
+        return (False, self.name)
 
     def __str__ (self) :
         return self.name
@@ -201,7 +216,7 @@
         # lookup the value obj to use
         value = values.get(self.key)
         
-        if not value and self.default :
+        if not value and self.default is not None :
             value = self.default
         
         elif not value :
@@ -212,6 +227,33 @@
         
         return value
 
+    def build_default (self, values) :
+        """
+            Check if we have a value in values, and return based on that
+
+            XXX: copy-paste from build()
+        """
+
+        # state
+        is_default = False
+
+        # lookup the value obj to use
+        value = values.get(self.key)
+        
+        if not value and self.default is not None :
+            is_default = True
+            value = self.default
+        
+        elif not value :
+            raise URLError("No value given for label %r" % (self.key, ))
+        
+        # convert value back to str
+        value = self.type.build(value)
+        
+        # return
+        return (is_default, value)
+
+
 class SimpleValueLabel (ValueLabel) :
     """
         A label that has a name and a simple string value
@@ -241,7 +283,7 @@
         
         # default?
         if value is None and self.default is not None :
-            return LabelValue(self, self.default)
+            return LabelValue(self, self.default, True)
         
         # only non-empty values!
         elif value :
@@ -252,7 +294,7 @@
             # convert with type
             value = self.type.parse(value)
 
-            return LabelValue(self, value)
+            return LabelValue(self, value, False)
 
     def __str__ (self) :
         return '{%s%s%s}' % (
@@ -457,6 +499,9 @@
 
         # query string
         self.query_args = dict()
+
+        # remove prepending root /
+        url_mask = url_mask.lstrip('/')
         
         # parse any query string
         # XXX: conflicts with regexp syntax
@@ -511,23 +556,31 @@
         # start with the defaults
         kwargs = self.defaults.copy()
 
+        # ...dict of those label values which are set to defaults
+        default_labels = {}
+
         # then add all the values
         for label_value in label_values :
             kwargs[label_value.label.key] = label_value.value
+            
+            # add key to default_values?
+            if label_value.is_default :
+                default_labels[label_value.label.key] = label_value.label
        
         # then parse all query args
         for key, value in request.get_args() :
-            # unknown key?
-            if key not in self.query_args :
-                # ignore?
-                if self.config.ignore_extra_args :
-                    continue
-                
-                else :
-                    raise URLError("Unrecognized query argument: %r" % (key, ))
-
-            # lookup spec
-            type, default = self.query_args[key]
+            # lookup in our defined query args
+            if key in self.query_args :
+                # lookup spec
+                type, default = self.query_args[key]
+            
+            # override URL params if they were not given
+            elif key in default_labels :
+                type, default = default_labels[key].type, None
+            
+            # be strict about extraneous query args?
+            elif self.config.ignore_extra_args :
+                raise URLError("Unrecognized query argument: %r" % (key, ))
 
             # normalize empty value to None
             if not value :
@@ -575,11 +628,25 @@
     
     def build (self, request, **values) :
         """
-            Build an absolute URL pointing to this target, with the given values
+            Build an absolute URL pointing to this target, with the given values. Default values are left off if they
+            are at the end of the URL.
         """
+        
+        # collect segments as a list of (is_default, segment) values
+        segments = [(False, request.page_prefix)] + [label.build_default(values) for label in self.label_path]
+        
+        # trim default items off the end
+        for is_default, segment in segments[::-1] :
+            if is_default :
+                segments.pop(-1)
+            
+            else :
+                break
 
-        # build URL from request page prefix and our labels
-        return request.page_prefix + '/'.join(label.build(values) for label in self.label_path)
+        assert segments
+        
+        # join
+        return '/'.join(segment for is_default, segment in segments)
 
     def __str__ (self) :
         return '/'.join(str(label) for label in self.label_path)
@@ -785,13 +852,10 @@
         """
             Adds the given URL to the tree. The URL must begin with a root slash.
         """
+
         # get url's label path
         path = url.get_label_path()
 
-        # should begin with root
-        root_label = path.pop(0)
-        assert root_label == self.root.label, "URL must begin with root"
-
         # add to root
         self.root.add_url(url, path)