#!/usr/bin/env python

import commands, ftplib, getopt, glob, os, shutil, string, sys, tempfile

# Use a fixed PATH for the external tools this script runs
# (which, file, gzip, tar, cp).
os.environ['PATH'] = '/usr/local/sbin:/usr/sbin:/sbin:' + \
  '/usr/local/bin:/usr/bin:/bin'

# set to 1 (or pass option -d) to print debug messages
DEBUG = 0

# Default uvscan installation directory. If the uvscan found in PATH is a
# symlink, the directory containing the link target is used instead.
lib_dir = '/usr/local/lib/uvscan'
uvscan_bin = commands.getoutput("which uvscan 2>/dev/null")
if uvscan_bin and os.path.islink(uvscan_bin):
    try:
        # A relative link target is relative to the directory holding the
        # link, not to the current working directory; join() keeps an
        # absolute target unchanged.
        link_target = os.path.join(os.path.dirname(uvscan_bin),
                                   os.readlink(uvscan_bin))
        lib_dir = os.path.normpath(os.path.dirname(link_target))
    except OSError:
        pass
        
# adjust the following lines to your needs
cache_dir = '/home/ms/public/nai/linux'     # where downloaded archives are cached
ftp_host = 'ftpde.nai.com'                  # FTP server providing DAT files
ftp_path = 'pub/antivirus/datfiles/4.x'     # directory of the DAT files on it

dat_file = None     # local DAT archive to install instead (option -u)
passive_ftp = 1     # use passive mode for FTP connections

__progname__ = "uvscan_update"

__author__ = "Christopher Arndt <chris.arndt@web.de>"
__version__ = "1.2"
__date__ = "27.03.2002"

# Help text; the %(...)s placeholders are filled in by usage().
__usage__ = """Usage: %(__progname__)s [OPTIONS]
Version: %(__version__)s

Checks for new DAT files on the NAI FTP server and downloads and installs them

-l --libdir <dir>    Directory where uvscan is installed [%(lib_dir)s]
-f --ftphost <host>  FTP Server where updates are available [%(ftp_host)s]
-u --dat-file <file> Use local DAT file <file>. Don't check for newer versions
-d                   Print debug messages
-h --help            Show this help
"""

class error(Exception):
    """Failure raised by the version-lookup and FTP helpers of this script;
    main() catches it and reports the message to the user."""

class myFile:
    """Wrapper for a file object that writes hash marks to stderr on writing.

    One '#' is written to stderr for every ``_chunk_size`` bytes (10240 by
    default) passed to write(); ``tot_written`` holds the total byte count.
    Attributes not defined here are looked up on the wrapped file object.
    """

    def __init__(self, fn, mode='r'):
        self.fp = open(fn, mode)
        self.tot_written = 0        # bytes passed to write() so far
        self._bytes_out = 0         # bytes already represented by '#' marks
        self._chunk_size = 10240    # bytes per '#' mark

    def write(self, s):
        """Write *s* to the file and emit one '#' per completed chunk."""
        # Write first, so tot_written only counts data the file accepted.
        self.fp.write(s)
        self.tot_written = self.tot_written + len(s)
        pending = self.tot_written - self._bytes_out
        if pending >= self._chunk_size:
            chunks = int(pending / self._chunk_size)
            sys.stderr.write("#" * chunks)
            sys.stderr.flush()
            self._bytes_out = self._bytes_out + chunks * self._chunk_size

    def close(self):
        """Close the wrapped file."""
        self.fp.close()

    def __del__(self):
        # __init__ may have failed before self.fp was set (open() raised);
        # read it from __dict__ so __getattr__ is not involved here.
        fp = self.__dict__.get('fp')
        if fp is not None:
            fp.close()

    def __getattr__(self, name):
        # Only called when normal lookup fails. 'fp' is read from __dict__
        # to avoid endless recursion on a half-initialized instance, and
        # AttributeError (not NameError) keeps hasattr()/getattr() working.
        fp = self.__dict__.get('fp')
        if fp is None or not hasattr(fp, name):
            raise AttributeError(name)
        return getattr(fp, name)


def usage(d = vars()):
    """Write the help text to stderr.

    The %(...)s placeholders of __usage__ are filled from *d*; the default
    is the mapping vars() returned at module level, i.e. the module globals.
    """
    text = __usage__ % d
    sys.stderr.write(text)

def warn(*args):
    """Write *args* to stderr, converted with str(), separated by single
    spaces and terminated by a newline."""
    sys.stderr.write(" ".join(map(str, args)) + '\n')

def debug(*args):
    """Pass *args* on to warn(), but only when the global DEBUG flag is set."""
    if DEBUG:
        warn(*args)

def die(msg, exit_code=1):
    """Report *msg* on stderr via warn() and terminate with *exit_code*."""
    warn(msg)
    raise SystemExit(exit_code)

def get_current_version(lib_dir):
    """Get version of installed DAT file.

    Runs '<lib_dir>/uvscan --version' and parses the first output line that
    starts with 'Virus data file': the fourth word minus its first character
    is taken as the number (assumed to look like 'v4193' -- TODO confirm
    against current uvscan output).

    Raises error if uvscan fails or its output cannot be parsed.
    """
    try:
        cmd = os.path.join(lib_dir, 'uvscan') + ' --version'
        status, output = commands.getstatusoutput(cmd)
        if status == 0:
            for line in output.split('\n'):
                if line.startswith('Virus data file'):
                    return int(line.split()[3][1:])
    except (OSError, IndexError, ValueError):
        # Only running/parsing failures are mapped to error; anything else
        # (e.g. KeyboardInterrupt) propagates normally.
        pass
    raise error("Error getting version information")
        
def get_new_version(ftp_host, ftp_path):
    """Get version number of DAT file available on ftp site.

    Logs in anonymously to *ftp_host*, changes to *ftp_path* and returns the
    highest number N among the files named 'dat-N.tar'. The global
    passive_ftp selects passive mode. The connection is always closed.

    Raises error if the server cannot be used or no DAT file is found.
    """
    try:
        debug("Connecting to host:", ftp_host)
        ftp = ftplib.FTP(ftp_host)
    except ftplib.all_errors:
        raise error("Error in ftp transaction. Use option -f <host> ?")
    try:
        try:
            ftp.set_pasv(passive_ftp)
            ftp.login()
            debug("CWD to", ftp_path)
            ftp.cwd(ftp_path)
        except ftplib.all_errors:
            raise error("Error in ftp transaction. Use option -f <host> ?")
        try:
            names = ftp.nlst('dat-*.tar')
        except ftplib.all_errors:
            # servers typically answer an empty listing with an error reply
            names = []
    finally:
        try:
            ftp.quit()
        except ftplib.all_errors:
            ftp.close()

    # Compare numerically: a string sort would rank 'dat-999' above
    # 'dat-1000'. Names without a numeric version are ignored.
    versions = []
    for name in names:
        number = os.path.basename(name)[4:].split('.')[0]
        if number.isdigit():
            versions.append(int(number))
    if not versions:
        raise error("No DAT file found")
    return max(versions)

def get_dat_file(version, filename):
    """Retrieve DAT file archive from ftp server.

    Connects to the global ftp_host, changes to ftp_path and first requests
    'dat-<version>.tar.gz' (served by servers that support on-the-fly
    compression); if that is refused, 'dat-<version>.tar' is fetched. The
    data is written to *filename* while '#' marks on stderr show progress.
    The FTP connection and the output file are always closed.

    Raises error if connecting or downloading fails.
    """
    try:
        debug("Connecting to host:", ftp_host)
        ftp = ftplib.FTP(ftp_host)
    except ftplib.all_errors:
        raise error("Error in ftp transaction. Use option -f <host> ?")
    try:
        try:
            ftp.set_pasv(passive_ftp)
            ftp.login()
            debug("CWD to", ftp_path)
            ftp.cwd(ftp_path)
        except ftplib.all_errors:
            raise error("Error in ftp transaction. Use option -f <host> ?")
        fp = myFile(filename, 'wb')
        try:
            try:
                try:
                    ftp.retrbinary('RETR dat-%i.tar.gz' % (version,),
                                   fp.write, 10240)
                    debug("Server supports on-the-fly compression")
                except ftplib.error_perm:
                    ftp.retrbinary('RETR dat-%i.tar' % (version,),
                                   fp.write, 10240)
            except ftplib.all_errors:
                raise error("Error during FTP download")
            debug("Saved DAT file to", filename)
            print("\nTotal bytes written: %s" % fp.tot_written)
        finally:
            fp.close()
    finally:
        # release the control connection: quit() politely, or close() if
        # the server does not answer any more
        try:
            ftp.quit()
        except ftplib.all_errors:
            ftp.close()
    
def main(argl):
    global DEBUG, lib_dir, ftp_host, ftp_path, dat_file

    try:
        opts, args = getopt.getopt(argl, 'dhl:f:u:', 
          ['help', 'libdir=', 'ftphost='])
    except getopt.GetoptError, msg:
        warn(msg)
        usage()
        sys.exit(1)
    
    for o,a in opts:
        if o in ['-l', '--libdir']:
            lib_dir = a
        elif o in ['-f', '--ftphost']:
            ftp_host = a
        elif o in ['-h', '--help']:
            usage()
            sys.exit(0)
        elif o in ['-u', '--dat-file']:
            dat_file = a
        elif o == '-d':
            DEBUG = 1
    
    # get version numbers, compare them -> update necessary?
    try:
        cur_ver = get_current_version(lib_dir)
    except error, msg:
        debug(msg)
        die('Could not determine installed version. Use option -l <libdir> ?')
    print 'Installed version of uvscan:', cur_ver

    if dat_file == None:
        try:
            print "Checking for new version..."
            new_ver = get_new_version(ftp_host, ftp_path)
        except error, msg:
            die(msg)
    else:
        try:
            new_ver = int(os.path.basename(dat_file)[4:8])
        except NameError:
            die("Could not determine version number. " + \
                "Please name file dat-????.tar(.gz)")

    debug('Available version: ', new_ver)

    if cur_ver < new_ver:
        print "Newer DAT file detected! Version:", new_ver
        tmpfile = tempfile.mktemp('.tar')
        fl = glob.glob(os.path.join(cache_dir, 'dat-%s.tar*' % (new_ver)))
        if dat_file != None:
            shutil.copy(dat_file, tmpfile)
        elif fl:
            shutil.copy(fl[0], tmpfile)
            # XXX Kludge!
            dat_file = 1
            print "Using cached DAT file."
        else:
            # retrieve DAT archive to temporary file
            try:
                print "Retrieving new DAT file, please wait..."
                get_dat_file(new_ver, tmpfile)
            except error, msg:
                debug(msg)
                die("Error downloading DAT file, quitting...")
    else:
        print "No new DAT file available, nothing to do..."
        sys.exit()

    # unpack to temporary directory
    tmpdir = tempfile.mktemp()
    os.mkdir(tmpdir, 0700)
    ftype = commands.getoutput("file -b %s" % tmpfile)
    if string.find(ftype, 'gzip') != -1:
        os.system("gzip -dc %s | tar xf - -C %s" % (tmpfile, tmpdir))
    else:
        os.system("tar xf %s -C %s" % (tmpfile, tmpdir))

    # install *.dat file to library dir
    print "Installing new DAT files..."
    fl = glob.glob(tmpdir + '/*.[Dd][Aa][Tt]')
    install_error = 0
    for src in fl:
        dst = os.path.join(lib_dir, string.lower(os.path.basename(src)))
        try:
            os.system('cp -f "%s" "%s"' % (src, dst))
            #shutil.copy(src, dst)
            #os.chmod(dst, 0664)
        except (IOError, os.error), msg:
            warn("Error installing '%s'" % dst, msg)
            install_error = 1

    # check if new version is installed correctly
    new_ver2 = get_current_version(lib_dir)
    if new_ver2 == cur_ver:
        warn("Still old version installed!",
        "\nExpected: %i, Reported: %i" % (new_ver, new_ver2),
        "\nInstall *.dat files manually from directory '%s'." % tmpdir)
        install_error = 1

    # clean up
    try:
        if not install_error:
            shutil.rmtree(tmpdir, 1)
        # keep downloaded archive
        if not dat_file:
            try:
                if string.find(ftype, 'gzip') != -1:
                    shutil.copy(tmpfile, os.path.join(cache_dir, 
                      'dat-%s.tar.gz' % new_ver))
                else:
                    os.system('gzip -c "%s" > "%s"' % (tmpfile, 
                      os.path.join(cache_dir, 'dat-%s.tar.gz' % new_ver)))
            except:
                pass
        os.remove(tmpfile)
    except:
        pass

    if install_error:
        sys.exit(1)
    else:
        print "Success."
        sys.exit(0)
        
    
if __name__ == '__main__':
    # Script entry point: hand over the arguments without the program name.
    cli_args = sys.argv[1:]
    main(cli_args)