pvl/login/pubtkt.py
author Tero Marttila <terom@paivola.fi>
Sun, 07 Sep 2014 14:21:56 +0300
changeset 424 e77e967d59b0
parent 371 8c17eb11858f
permissions -rw-r--r--
hgignore: use glob; ignore snmp mibs
import base64
import calendar
import datetime
import ipaddr
import hashlib
import M2Crypto

import logging; log = logging.getLogger('pvl.login.pubtkt')

def datetime2unix (dt) :
    """
        datetime.datetime -> int (unix timestamp)

        A naive datetime is interpreted as UTC; an aware one is first converted
        to UTC. Sub-second precision is discarded (utctimetuple has no
        microsecond field).
    """

    return calendar.timegm(dt.utctimetuple())

def unix2datetime (unix) :
    """
        unix timestamp (int/float) -> naive datetime.datetime in UTC

        Inverse of datetime2unix (up to sub-second precision).
    """

    utc = datetime.datetime.utcfromtimestamp(unix)
    return utc

class Error (Exception) :
    """
        Error
    """
    # NOTE: the (stripped) class docstring is used at runtime as the message
    # prefix in __unicode__, so subclass docstrings are user-visible text.

    def __init__ (self, error) :
        # Detail for this error: usually a message string, but any object
        # whose formatting is meaningful may be stored here.
        self.error = error

    def __unicode__ (self) :
        """
            Render as u"<class docstring>: <error detail>".
        """

        # __doc__ is None for a subclass without a docstring and for every
        # class under `python -OO`; fall back to the class name so formatting
        # the error can never itself raise AttributeError.
        doc = (self.__doc__ or type(self).__name__).strip()

        return u"{doc}: {self.error}".format(self=self, doc=doc)

class ParseError (Error) :
    """
        Unable to parse PubTkt from cookie
    """
    # Raised by PubTkt.parse() (and therefore PubTkt.load()) for a malformed
    # cookie: missing/undecodable signature, bad field syntax, missing or
    # unknown fields, or unconvertible field values.
    # NOTE: the docstring above is the message prefix used by Error.__unicode__,
    # so editing it changes user-visible error text.

class VerifyError (Error) :
    """
        Invalid login token signature
    """
    # NOTE: the docstring above is the user-visible message prefix rendered by
    # Error.__unicode__ (previously misspelled "sigunature").

    def __init__ (self, pubtkt, error) :
        """
            pubtkt  - the parsed ticket whose signature failed to verify; its
                      contents are NOT trustworthy
            error   - description of the verification failure
        """

        self.pubtkt = pubtkt
        self.error = error

class ExpiredError (Error) :
    """
        Login token has expired
    """

    def __init__ (self, pubtkt, expire) :
        # `expire` is exposed as the generic `.error` detail: PubTkt.load passes
        # the ticket's validuntil datetime, PubTkt.renew passes a message string.
        self.pubtkt, self.error = pubtkt, expire

class RenewError (Error) :
    """
        Unable to renew login token
    """
    # Not raised anywhere in this module's visible code (PubTkt.renew raises
    # ExpiredError for an expired ticket); presumably raised by calling code.
    # NOTE: the docstring above is the message prefix used by Error.__unicode__.

class ServerError (Error) :
    """
        Login request from invalid server
    """
    # Not raised anywhere in this module's visible code; presumably raised by
    # calling code that validates where a login request came from — TODO confirm.
    # NOTE: the docstring above is the message prefix used by Error.__unicode__.
    
class ServerKeys (object) :
    """
        Holder for an RSA public/private key pair.

        Presumably the public key is used with PubTkt.load() (verification)
        and the private key with PubTkt.sign() — verify against callers.
    """

    def __init__ (self, public, private) :
        self.public, self.private = public, private

    @classmethod
    def config (cls, public_key, private_key) :
        """
            Load both keys with M2Crypto and wrap them in a new instance.

            public_key  - passed to M2Crypto.RSA.load_pub_key
            private_key - passed to M2Crypto.RSA.load_key
            (presumably filesystem paths to PEM files — TODO confirm)
        """

        loaded_public = M2Crypto.RSA.load_pub_key(public_key)
        loaded_private = M2Crypto.RSA.load_key(private_key)

        return cls(public=loaded_public, private=loaded_private)

class PubTkt (object) :
    """
        A signed login ticket made of key=value fields (uid, validuntil, cip,
        tokens, udata, graceperiod, bauth), serialized as a ';'-separated
        string with a trailing ';sig=<base64>' RSA/SHA1 signature.

        All timestamps are naive datetime.datetime objects in UTC.
    """

    @staticmethod
    def now () :
        """
            Current time as a naive UTC datetime; every validity check uses this.
        """

        return datetime.datetime.utcnow()

    @classmethod
    def load (cls, cookie, public_key) :
        """
            Load and verify a pubtkt from a cookie.

            cookie      - signed ticket string "<fields>;sig=<base64 signature>"
            public_key  - RSA public key exposing verify(digest, sig, 'sha1'),
                          e.g. from M2Crypto.RSA.load_pub_key

            Returns the PubTkt if its signature verifies and it has not expired.

            Raises ParseError, VerifyError, ExpiredError.
        """

        pubtkt, digest, sig = cls.parse(cookie)

        log.debug("parsed %s hash=%s sig=%s", pubtkt, digest.encode('hex'), sig.encode('hex'))

        try :
            # A bad signature may show up either as a falsy return value or as
            # an RSAError; both are reported as VerifyError.
            if not public_key.verify(digest, sig, 'sha1') :
                raise VerifyError(pubtkt, "Unable to verify signature")
        except M2Crypto.RSA.RSAError as ex :
            raise VerifyError(pubtkt, str(ex))

        log.debug("checking expiry %s", pubtkt.validuntil)

        if not pubtkt.valid() :
            raise ExpiredError(pubtkt, pubtkt.validuntil)

        return pubtkt

    @classmethod
    def parse (cls, cookie) :
        """
            Load a pubtkt from a cookie, WITHOUT verifying its signature.

            Returns (pubtkt, hash, sig): the parsed ticket, the SHA1 digest of
            the signed portion of the cookie, and the decoded signature bytes.

            Raises ParseError.
        """

        # The signature is the last field; everything before it is signed data.
        if ';sig=' not in cookie :
            raise ParseError("Missing signature")

        data, sig = cookie.rsplit(';sig=', 1)

        try :
            sig = base64.b64decode(sig)
        except (ValueError, TypeError) :
            # py2 raises TypeError on bad padding; py3 binascii.Error is a ValueError
            raise ParseError("Invalid signature")

        digest = hashlib.sha1(data).digest()

        try :
            attrs = dict(field.split('=', 1) for field in data.split(';'))
        except ValueError as ex :
            # a field without '=' splits into a 1-item sequence
            raise ParseError(str(ex))

        if 'uid' not in attrs or 'validuntil' not in attrs :
            raise ParseError("Missing parameters in cookie (uid, validuntil)")

        try :
            return cls.build(**attrs), digest, sig
        except TypeError :
            # unknown field names arrive as unexpected keyword arguments
            raise ParseError("Invalid or missing parameters in cookie")
        except (ValueError, OverflowError, OSError) as ex :
            # The timestamps come from an untrusted cookie: utcfromtimestamp()
            # raises OverflowError/OSError (not only ValueError) when they are
            # out of the platform's range, so report those as parse errors too.
            raise ParseError(str(ex))

    @classmethod
    def build (cls, uid, validuntil, cip=None, tokens=None, udata=None, graceperiod=None, bauth=None) :
        """
            Build a pubtkt from the raw (string) field values of a cookie.

            validuntil, graceperiod - unix timestamps (anything int() accepts)
            cip                     - IP address string, parsed with ipaddr
            tokens                  - ','-separated list of tokens

            Raises TypeError or ValueError (OverflowError/OSError for
            out-of-range timestamps).
        """

        return cls(uid,
                validuntil  = unix2datetime(int(validuntil)),
                cip         = ipaddr.IPAddress(cip) if cip else None,
                tokens      = tokens.split(',') if tokens else (),
                udata       = udata,
                graceperiod = unix2datetime(int(graceperiod)) if graceperiod else None,
                bauth       = bauth,
        )

    @classmethod
    def new (cls, uid, valid, grace=None, **opts) :
        """
            Issue a fresh ticket for uid.

            valid   - timedelta; validuntil = now + valid
            grace   - optional timedelta; graceperiod = now + grace, i.e. the
                      point after which the ticket is in its grace period
            opts    - passed through to __init__ (cip, tokens, udata, bauth)
        """

        now = cls.now()

        return cls(uid, now + valid,
            graceperiod = now + grace if grace else None,
            **opts
        )

    def update (self, valid, grace, cip=None, tokens=None, udata=None, bauth=None) :
        """
            Return a NEW ticket for the same uid with validity re-based on now.

            valid, grace are timedeltas as for new(). For cip/tokens/udata/bauth,
            None keeps this ticket's current value.
        """

        now = self.now()

        return type(self)(self.uid, now + valid,
            graceperiod = now + grace if grace else None,
            cip         = self.cip if cip is None else cip,
            tokens      = self.tokens if tokens is None else tokens,
            udata       = self.udata if udata is None else udata,
            bauth       = self.bauth if bauth is None else bauth,
        )

    def __init__ (self, uid, validuntil, cip=None, tokens=(), udata=None, graceperiod=None, bauth=None) :
        """
            uid         - user identifier
            validuntil  - naive UTC datetime at which the ticket expires
            cip         - optional client IP address
            tokens      - sequence of token strings
            udata       - optional opaque user data string
            graceperiod - optional naive UTC datetime at which the grace period starts
            bauth       - optional opaque string, serialized as-is
        """

        self.uid = uid
        self.validuntil = validuntil
        self.cip = cip
        self.tokens = tokens
        self.udata = udata
        self.graceperiod = graceperiod
        self.bauth = bauth

    def iteritems (self) :
        """
            Yield (key, value) pairs in serialization order; optional fields
            are omitted when empty. Timestamps are emitted as int unix time.
        """

        yield 'uid', self.uid
        yield 'validuntil', int(datetime2unix(self.validuntil))

        if self.cip :
            yield 'cip', self.cip

        if self.tokens :
            yield 'tokens', ','.join(str(token) for token in self.tokens)

        if self.udata :
            yield 'udata', self.udata

        if self.graceperiod :
            yield 'graceperiod', int(datetime2unix(self.graceperiod))

        if self.bauth :
            yield 'bauth', self.bauth

    def __str__ (self) :
        """
            The (unsigned) pubtkt
        """

        return ';'.join('%s=%s' % (key, value) for key, value in self.iteritems())

    def sign (self, private_key) :
        """
            Return the signed cookie string "<fields>;sig=<base64 signature>".

            private_key - RSA private key exposing sign(digest, 'sha1'),
                          e.g. from M2Crypto.RSA.load_key
        """

        data = str(self)
        digest = hashlib.sha1(data).digest()
        signature = private_key.sign(digest, 'sha1')

        # Reuse the exact string that was hashed instead of serializing again.
        return '%s;sig=%s' % (data, base64.b64encode(signature))

    def valid (self) :
        """
            Return remaining ticket validity as a timedelta, or False once
            validuntil has passed.
        """

        now = self.now()

        if self.validuntil > now :
            return self.validuntil - now
        else :
            return False

    def grace (self) :
        """
            Return remaining grace period:

                None        - no graceperiod, or the grace period has not started
                timedelta   - inside the grace period: time left until validuntil
                False       - past validuntil (expired)
        """

        now = self.now()

        if not self.graceperiod :
            return None

        elif now < self.graceperiod :
            # still valid, grace period not started yet
            return None

        elif now < self.validuntil :
            # positive
            return self.validuntil - now

        else :
            # expired
            return False

    def remaining (self) :
        """
            Return remaining validity before grace.

            Without a graceperiod this is valid(); otherwise a timedelta until
            graceperiod, or False once the grace period has started.
        """

        now = self.now()

        if not self.graceperiod :
            return self.valid()

        elif now < self.graceperiod :
            return self.graceperiod - now

        else :
            # in grace, or expired
            return False

    def grace_period (self) :
        """
            Return the length of the grace period (validuntil - graceperiod),
            or None if the ticket has no graceperiod.
        """

        if self.graceperiod :
            return self.validuntil - self.graceperiod
        else :
            return None

    def renew (self, valid, grace=None) :
        """
            Extend this ticket IN PLACE: validuntil = now + valid and
            graceperiod = now + grace (or None). Unlike update(), this mutates
            self and returns None.

            Raises ExpiredError if the ticket has already expired.
        """

        if not self.valid() :
            raise ExpiredError(self, "Unable to renew expired pubtkt")

        now = self.now()

        self.validuntil = now + valid
        self.graceperiod = now + grace if grace else None