#!/local/sparc/bin/python
# $Id: cprofile,v 1.16 1998/09/15 18:28:06 ron Exp $
#
# Ron Klatchko
# UCSF Library and Center for Knowledge Management
# Copyright (c) 1997-1998 UC Regents. All rights reserved.
#


import selectserver
import socket
import string
import soundex
import sys
import signal
import glob
import fnmatch
import getopt
import os
import traceback

import multishelve
import stringex

#
# Database-reload bookkeeping.
# A reload requested by SIGUSR1 can run at once when no request is
# being handled; otherwise it must wait until the current request is
# done.
#
#   handlingrequest -- nonzero while CPSocketConnection.handle_request runs
#   needreload      -- set by signalhandler when it had to defer a reload
#
handlingrequest = needreload = 0


#
# Pick the server framework.  Only the select()-based round-robin
# server is supported, because the database-reload code (loaddb) does
# not handle multithreading.  Once that is fixed, the threaded variant
# would bind masterserver/socketconnectionparent/run to
# server.ThreadingServer, server.SocketConnection and server.loop.
#
if sys.modules.has_key( "server" ):
    raise AssertionError( "multithreading does not properly handle reloading database" )

if sys.modules.has_key( "selectserver" ):
    # called in __main__ as masterserver( connclass, family, address )
    masterserver = selectserver.server
    # base class of CPSocketConnection
    socketconnectionparent = selectserver.channel
    # event-loop entry point started at the end of __main__
    run = selectserver.asyncore.loop
    

_debuglevel = 0     # verbosity threshold; overridden by --debuglevel in __main__

def debuglevel( n ):
    """Return true when output of verbosity *n* should be printed,
    i.e. when *n* does not exceed the configured _debuglevel."""
    return _debuglevel >= n


def intersect( a, b ):
    """Return a new list of the elements of *a* that are also in *b*.

    The order and any duplicates of *a* are kept.  Membership is tested
    with ``in`` against *b*, so it is linear per element when *b* is a list.
    """
    return [ item for item in a if item in b ]

def subtract( a, b ):
    """Return a new list of the elements of *a* that are not in *b*.

    The order and any duplicates of *a* are kept.  Membership is tested
    with ``in`` against *b*.
    """
    return [ item for item in a if item not in b ]
	    

def toalphanum( str ):
    new = []
    for c in str:
	if ( c in string.letters ) or ( c in string.digits ):
	    new.append( c )
    return string.join( new, '' )
	    
	    

class CPSocketConnection( socketconnectionparent ):
    skipfuncs = [ "__init__", "handle_finish" ]

    def __init__( self, _socket ):
	socketconnectionparent.__init__( self, _socket )
	self.handler = CPHandler( self )

    def handle_request( self, line ):
	args = stringex.split( line )
	if not args:
	    return

        global handlingrequest
        global needreload

        handlingrequest = 1

	if hasattr( self.handler, args[0] ) and \
	   not args[0] in CPSocketConnection.skipfuncs:
	    try:
		getattr( self.handler, args[0] )( args[1:], self )
	    except:
		self.writeln( "402 Other %s %s" % (sys.exc_type,sys.exc_value) )
		traceback.print_exc()
	else:
	    self.writeln( "403 Bad request: %s" % (args[0],) )

        handlingrequest = 0
        if needreload:
            needreload = 0
            if debuglevel(2):
                print "reloading databases..."
            loaddb()
            if debuglevel(1):
                print "server continuing..."


    def handle_finish( self ):
	self.handler.handle_finish( self )

class CPHandler:
    def __init__( self, channel ):
	pass

    def lookup( self, args, channel ):
	if len(args) < 3:
	    channel.writeln( "404 Not enough arguments" )
	    return

	fname = args[0]
	lname = string.join(string.split(args[1]))
	secrets = args[2:]

	#
	# Get records that are similiar to the first
	# name
	#
	fname_soundex = soundex.get_soundex( fname )
	if fnameidx.has_key( fname_soundex ):
	    fname_possibilities = fnameidx[ fname_soundex ]
	else:
	    fname_possibilities = []
	if debuglevel(3):
	    print "fname_possibilities: %s" % fname_possibilities

	#
	# Get records the exactly match the last name
	#
	lname_lower = string.lower( lname )
	if lnameidx.has_key( lname_lower ):
	    lname_possibilities = lnameidx[ lname_lower ]
	else:
	    lname_possibilities = []
	if debuglevel(3):
	    print "lname_possibilities: %s" % lname_possibilities

	#
	# Get all the records that exactly match the last name and
	# that are similiar to the first name
	#
	primary_possibilities = intersect( lname_possibilities, fname_possibilities )

	#
	# See if any of those records match the secret
	#
	good = []
	if debuglevel(3):
	    print "Checking primary_possibilities: %s" % primary_possibilities
	for key in primary_possibilities:
	    try:
		s = db[ key ]
	    except ImportError:
		print "ImportError while accessing %s" % (key,)
		continue
	    except KeyError:
		print "KeyError while accessing %s" % (key,)
		continue
	    except:
		print "An error occured while accessing %s" % (key,)
		continue

	    try:
		if s.match( secrets ):
		    good.append( (key,s) )
	    except:
		print "An error occured while doing a match with class %s" % (s.__class__.__name__,)

	#
	# If we didn't find any matches, check all the records that exactly
	# match the last name but don't sound similiar to the first name.
	#
	if len(good) == 0:
	    secondary_possibilities = subtract( lname_possibilities, primary_possibilities )
	    if debuglevel(3):
		print "Checking secondary_possibilities: %s" % secondary_possibilities
	    for key in secondary_possibilities:
		try:
		    s = db[ key ]
		except ImportError:
		    print "ImportError while accessing %s" % (key,)
		    continue
		except KeyError:
		    print "KeyError while accessing %s" % (key,)
		    continue
		except:
		    print "An error occured while accessing %s" % (key,)
		    continue

		try:
		    if s.match( secrets ):
			good.append( (key,s) )
		except:
		    print "An error occured while doing a match with class %s" % (s.__class__.__name__,)

	if len(good) == 0:
	    channel.writeln( "400 Bad info" )
	elif len(good) == 1:
	    key,s = good[0]
	    channel.writeln( "200 Authenticated %s %s %s %s" %
			     (s.description(), key, s.firstname, s.lastname) )
	else:
	    channel.writeln( "401 Ambiquous" )

    def descriptions( self, args, channel ):
	for module in modules.values():
	    channel.writeln( str( module.description() ) )
	channel.writeln( "." )
	    

    def handle_finish( self, channel ):
	pass


#-----------------------------------------------------------------
# initdb()
#
# Creates the database containers and loads them, making the
# databases ready for use.  Call exactly once, before any request
# processing begins.
#-----------------------------------------------------------------
def initdb():
    """Create the global record store (db) and its first- and last-name
    indexes (fnameidx, lnameidx), then fill them with loaddb()."""
    global db, fnameidx, lnameidx

    # records, fetched by key in CPHandler.lookup
    db = multishelve.multishelve()
    # soundex(first name) -> candidate keys
    fnameidx = multishelve.multiindex()
    # lower-cased last name -> candidate keys
    lnameidx = multishelve.multiindex()

    loaddb()


#-----------------------------------------------------------------
# loaddb()
#
# Loads all databases into our system.  This function can be
# called multiple times.
#-----------------------------------------------------------------
def loaddb():
    db.closeall()
    fnameidx.closeall()
    lnameidx.closeall()

    for type in glob.glob( os.path.join( dbdir, "*_db"  ) ):
	prefix = os.path.basename(type)[:3]
	if debuglevel(2):
	    print prefix
	try:
	    db.open( prefix, os.path.join( type, "master.db"), "r" )
	    fnameidx.open( prefix, os.path.join( type, "firstname.idx"), "r" )
	    lnameidx.open( prefix, os.path.join( type, "lastname.idx"), "r" )
	    if debuglevel(2):
		print "Loaded database %s..." % (type,)

	except:
	    print "Could not open all files for %s..." % (type,)
	    db.close( prefix, quiet=1 )
	    fnameidx.close( prefix, quiet=1 )
	    lnameidx.close( prefix, quiet=1 )


def signalhandler( signum, stack ):
    if signum == signal.SIGUSR1:
        #
        # Can we handle a database reload right now?
        #
        if handlingrequest:
            global needreload
            needreload = 1
            if debuglevel(2):
                print "deferring reload databases..."
        else:
            if debuglevel(2):
                print "reloading databases..."
            loaddb()
            if debuglevel(1):
                print "server continuing..."
    elif signum == signal.SIGTERM:
	#
	# Make a signal terminate behave just like a control-c
	#
	raise KeyboardInterrupt


if __name__ == "__main__":
    ipport = 8001
    unixport = "/tmp/cp.sock"
    moduledir = "."
    dbdir = "."

    usage = "Usage: cp options\n  --help\n  --ipport=# (default %d)\n  --unixport=path (default %s)\n  --dbdir=path (default %s)\n  --moduledir=path (default %s)\n  --debuglevel=# (default %d)\n" % (ipport,unixport,dbdir,moduledir,_debuglevel)
    try:
	options, args = getopt.getopt( sys.argv[1:], "", [ "help", "ipport=", "unixport=", "moduledir=", "dbdir=", "debuglevel=" ] )
    except getopt.error, err:
	sys.stderr.write( "%s\n" % err )
	sys.stderr.write( usage )
	sys.exit( 1 )

    if len(args) != 0:
	sys.stderr.write( usage )
	sys.exit( 1 )

    for option in options:
	if option[0] == "--help":
	    sys.stderr.write( usage )
	    sys.exit( 1 )
	elif option[0] == "--unixport":
	    unixport = option[1]
	elif option[0] == "--ipport":
	    try:
		ipport = string.atoi( option[1] )
	    except ValueError:
		sys.stderr.write( "ipport must be a number\n" )
		sys.exit( 1 )
	elif option[0] == "--debuglevel":
	    try:
		_debuglevel = string.atoi( option[1] )
	    except ValueError:
		sys.stderr.write( "debuglevel must be a number\n" )
		sys.exit( 1 )
	elif option[0] == "--dbdir":
	    dbdir = option[1]
	elif option[0] == "--moduledir":
	    moduledir = option[1]

    if debuglevel(1):
	print "listening on ipport %d" % (ipport,)
	print "listening on unixport %s" % (unixport,)
	

    masterserver( CPSocketConnection, socket.AF_INET, ('', ipport) )
    masterserver( CPSocketConnection, socket.AF_UNIX, unixport )

    if debuglevel(1):
	print "loading databases..."
    initdb()


    signal.signal( signal.SIGUSR1, signalhandler )
    signal.signal( signal.SIGTERM, signalhandler )

    try:
	if debuglevel(1):
	    print "starting server..."
	run()
    except KeyboardInterrupt:
	print "Exiting..."
