#
# authlib.py:           misc functions for authentication purposes
#
# Author:		Christopher Arndt <chris.arndt@web.de>
# Version:		1.2
# Date:			29.04.2002
# Copyright:		LGPL
"""Helper functions for authentication issues.

o check_passwd(name, passwd, pwfile=None)

    Validate a user name/password pair against the password database:
    a passwd-format file ('/etc/passwd' by default) or an object that
    provides a get_passwd(name) method.
    
o getpwnam(name, pwfile=None)

    Get the password database entry for name ('/etc/passwd' by default).

o login(name=None, pwdb=None, user_prompt=None, pass_prompt=None, max_tries=3)
    
    Generate a login screen and validate the login.

o passcrypt(passwd, salt=None, method='des', magic='$1$')

    Encrypt a password with DES or MD5 (method='clear' returns it as is).

"""

# Public API for "from authlib import *"; the helpers getpass() and
# passcrypt_md5() are defined below but deliberately not exported.
__all__ = ['check_passwd', 'getpwnam', 'login', 'passcrypt']

__author__  = "Christopher Arndt <chris.arndt@web.de>"
# NOTE(review): the header comment says "Version: 1.2" while this says
# "1.0b" -- confirm which one is current and make them agree.
__version__ = "1.0b"

import string, sys, termios, time, whrandom

try:
    import crypt
except ImportError:
    try:
        import fcrypt
        crypt = fcrypt
    except ImportError:
        raise ImportError, "Could import neither crypt nor fcrypt module."

try:
    from Crypto.Hash import MD5
    md5 = MD5
except ImportError:
    import md5


# The 64-character crypt(3) alphabet, as a list of single characters.
# passcrypt() draws random DES salts from it and _to64() uses it to encode
# MD5-crypt digests (index = 6-bit value).
DES_SALT = list('./'
                + '0123456789'
                + 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
                + 'abcdefghijklmnopqrstuvwxyz')


def getpass(prompt="Password: "):
    """Prompt for a string with terminal echo turned off.

    Adapted from the termios example in the Python Library Reference.
    Only available on POSIX systems (uses termios on sys.stdin).
    XXX use getpass.getpass() instead.

    prompt -- text shown before reading the input.

    Returns the string read by raw_input().  The terminal attributes are
    restored afterwards, even if reading raises (e.g. Ctrl-C or EOF).
    """
    fd = sys.stdin.fileno()
    old = termios.tcgetattr(fd)
    new = termios.tcgetattr(fd)
    new[3] = new[3] & ~termios.ECHO          # lflags: switch off echo
    try:
        termios.tcsetattr(fd, termios.TCSADRAIN, new)
        passwd = raw_input(prompt)
        # With ECHO off the typed newline is not echoed either, so move
        # to the next line explicitly.
        sys.stdout.write('\n')
    finally:
        termios.tcsetattr(fd, termios.TCSADRAIN, old)
    return passwd


def getpwnam(name, pwfile=None):
    """Return the password database entry for the given user name.

    Adapted from an example in the Python Library Reference.

    name   -- login name, compared against the first field of each line.
    pwfile -- path of a passwd-format file; '/etc/passwd' if not given.

    Returns a tuple of the colon-separated fields of the first matching
    line (split at most 6 times, so the last field keeps any extra colons).
    Raises KeyError(name) if no line matches; errors from opening or
    reading the file propagate to the caller.
    """
    if not pwfile:
        pwfile = '/etc/passwd'

    f = open(pwfile)
    try:
        for line in f:
            line = line.strip()
            if not line:
                # A blank line would otherwise "match" an empty name and
                # yield a 1-field entry.
                continue
            entry = tuple(line.split(':', 6))
            if entry[0] == name:
                return entry
    finally:
        # Close the file on every exit path, including read errors.
        f.close()
    raise KeyError(name)


def check_passwd(name, passwd, pwfile=None):
    """Validate given user, passwd pair against password database.

    name   -- login name to look up.
    passwd -- clear-text password to verify.
    pwfile -- None or a path string (the entry is read with getpwnam(),
              '/etc/passwd' by default), or any object providing a
              get_passwd(name) method that returns the encrypted password.

    Returns a true value if passwd matches the stored hash, otherwise a
    false value: unknown user, unreadable database, malformed entry,
    empty stored hash, or a stored value too short to be a DES hash.
    Stored values starting with '$1$' are checked with MD5-crypt,
    everything else with traditional DES crypt.
    """
    if not pwfile or isinstance(pwfile, str):
        getuser = lambda x, pwfile=pwfile: getpwnam(x, pwfile)[1]
    else:
        getuser = pwfile.get_passwd

    try:
        enc_passwd = getuser(name)
    except (KeyError, IOError, IndexError):
        # IndexError: the entry has no password field (malformed line).
        return 0
    if not enc_passwd:
        return 0
    if enc_passwd.startswith('$1$'):
        # Salt is the text between the '$1$' magic and the next '$' (or the
        # end of the string if there is no further '$').
        salt = enc_passwd[3:].split('$')[0]
        return enc_passwd == passcrypt(passwd, salt=salt, method='md5')
    if len(enc_passwd) < 2:
        # Too short to hold a 2-character DES salt (e.g. 'x' or '*' for
        # shadowed/locked accounts); no password can match it.
        return 0
    return enc_passwd == passcrypt(passwd, enc_passwd[:2])


def login(name=None, pwdb=None, user_prompt=None, pass_prompt=None,
  max_tries=3):
    """Generate a login screen and validate the login.

    name        - login name; if given, the user is not asked for one.
    pwdb        - user database object providing a check_passwd(user,
                  password) method that returns true/false. If not given
                  or pwdb is a string, the implementation from this module
                  is used (a string is taken as the password file path).
    user_prompt - prompt for the login name; if None, pwdb.user_prompt is
                  used when available, otherwise 'Login: '.
    pass_prompt - prompt for the password; if None, pwdb.pass_prompt is
                  used when available, otherwise 'Password: '.
    max_tries   - number of password attempts before giving up.

    An empty login name is not accepted; Ctrl-C at the name prompt asks
    again, Ctrl-C at the password prompt counts as a failed attempt, and
    EOF at either prompt aborts. A name entered once is reused for the
    remaining attempts.

    Returns user name or None on failure.
    """
    if not pwdb or isinstance(pwdb, str):
        default_user_prompt = 'Login: '
        default_pass_prompt = 'Password: '
        check = lambda x, y, pwdb=pwdb: check_passwd(x, y, pwdb)
    else:
        default_user_prompt = getattr(pwdb, 'user_prompt', 'Login: ')
        default_pass_prompt = getattr(pwdb, 'pass_prompt', 'Password: ')
        check = pwdb.check_passwd

    # Explicit arguments take precedence over the defaults chosen above.
    if user_prompt is None:
        user_prompt = default_user_prompt
    if pass_prompt is None:
        pass_prompt = default_pass_prompt

    tries = 0
    while tries < max_tries:
        tries += 1
        while not name:
            try:
                name = raw_input(user_prompt)
            except KeyboardInterrupt:
                sys.stdout.write('\n')
            except EOFError:
                return None
        try:
            passwd = getpass(pass_prompt)
        except KeyboardInterrupt:
            sys.stdout.write('\n')
            continue
        except EOFError:
            return None
        if check(name, passwd):
            return name
    return None


def passcrypt(passwd, salt=None, method='des', magic='$1$'):
    """Encrypt a string according to rules in crypt(3).

    passwd -- clear-text password.
    salt   -- salt to use; if empty/None a default is chosen (a random
              2-character salt from DES_SALT for 'des', passcrypt_md5()'s
              own default for 'md5').
    method -- case-insensitive: 'des' (crypt.crypt), 'md5'
              (passcrypt_md5) or 'clear' (passwd returned unchanged).
    magic  -- prefix passed through to passcrypt_md5().

    Raises ValueError for any other method.
    """
    kind = method.lower()
    if kind == 'des':
        if not salt:
            salt = whrandom.choice(DES_SALT) + whrandom.choice(DES_SALT)
        return crypt.crypt(passwd, salt)
    elif kind == 'md5':
        return passcrypt_md5(passwd, salt, magic)
    elif kind == 'clear':
        return passwd
    raise ValueError("unknown encryption method: %r" % (method,))


def _to64(v, n):
    """Encode integer v as n characters of the crypt(3) alphabet.

    The lowest 6 bits are emitted first (each 6-bit group indexes
    DES_SALT); a non-positive n yields ''.
    """
    chars = []
    for _ in range(n):
        chars.append(DES_SALT[v & 0x3F])
        v = v >> 6
    return ''.join(chars)

def passcrypt_md5(passwd, salt=None, magic='$1$'):
    """Encrypt passwd with MD5 algorithm (MD5-based crypt).

    passwd -- clear-text password string.
    salt   -- salt to use. A leading `magic` prefix is removed, the salt is
              cut at its first '$' and limited to 8 characters. If empty or
              None, the last 8 digits of the current Unix time are used.
    magic  -- prefix that is hashed in and prepended to the result.

    Returns magic + salt + '$' + 22 characters of encoded digest.
    """
    if not salt:
        # NOTE(review): a time-based salt is predictable; it only varies
        # the hash, it is not a secret.
        salt = repr(int(time.time()))[-8:]
    elif salt[:len(magic)] == magic:
        # remove magic from salt if present
        salt = salt[len(magic):]

    # salt only goes up to first '$'
    salt = salt.split('$')[0]
    # limit length of salt to 8
    salt = salt[:8]

    ctx = md5.new(passwd)
    ctx.update(magic)
    ctx.update(salt)

    # Alternate digest of passwd + salt + passwd.
    ctx1 = md5.new(passwd)
    ctx1.update(salt)
    ctx1.update(passwd)
    final = ctx1.digest()

    # Feed len(passwd) bytes of the alternate digest, 16 at a time.
    for i in range(len(passwd), 0, -16):
        if i > 16:
            ctx.update(final)
        else:
            ctx.update(final[:i])

    # For each bit of len(passwd), low bit first: a NUL byte if the bit is
    # set, otherwise the first character of passwd.
    i = len(passwd)
    while i:
        if i & 1:
            ctx.update('\0')
        else:
            ctx.update(passwd[:1])
        i = i >> 1
    final = ctx.digest()

    # 1000 rounds of re-hashing; final is recomputed on every round.
    for i in range(1000):
        ctx1 = md5.new()
        if i & 1:
            ctx1.update(passwd)
        else:
            ctx1.update(final)
        if i % 3:
            ctx1.update(salt)
        if i % 7:
            ctx1.update(passwd)
        if i & 1:
            ctx1.update(final)
        else:
            ctx1.update(passwd)
        final = ctx1.digest()

    # Encode the 16 digest bytes in the fixed crypt(3) byte order:
    # five 3-byte groups as 4 characters each, then byte 11 as 2 characters.
    digest = [ord(ch) for ch in final]
    parts = [magic, salt, '$']
    for hi, mid, lo in ((0, 6, 12), (1, 7, 13), (2, 8, 14),
                        (3, 9, 15), (4, 10, 5)):
        parts.append(_to64((digest[hi] << 16) + (digest[mid] << 8)
                           + digest[lo], 4))
    parts.append(_to64(digest[11], 2))
    return ''.join(parts)