#!/usr/bin/env python

# Jacob Joseph
# 30 December 2008

# Classes for organizing statistics for sets of networks taken from
# DurandDB.  Jobs run in parallel using pp (stat_set_pp), or serially
# for debugging (stat_set_debug).

import pp, time, sys, os, traceback, logging
import stathelper

# Don't pass class instances as arguments?  They break.
def ppjob(run_id, stype, min_score, set_id_filter, symmetric,
          cacheq, omit_spl, omit_cc, verbose):
    """Compute statistics for a single network; the unit of work run by pp.

    Builds a stathelper.nxstat_duranddb for the given run parameters and
    calls its calc_statistics().  Any Exception raised while doing so is
    captured and reported in the return value instead of propagating, so
    the caller always receives a result tuple.

    All arguments are plain values because this function and its
    arguments are shipped to pp workers by job_server.submit().

    Returns (retval, errormsg):
      retval   -- result of calc_statistics(), or None on failure
      errormsg -- None on success, otherwise a string holding the exception
                  type, the exception value and the formatted traceback
    """
    errormsg = None
    retval = None

    try:
        # A pp worker may run many jobs in one process; only extend
        # sys.path once instead of adding a duplicate entry per job.
        # NOTE(review): this appends .../JJnetstat but then imports the
        # package JJnetstat from the path -- confirm the intended layout.
        libdir = os.path.expandvars('$HOME/Durand/pylib/JJnetstat')
        if libdir not in sys.path:
            sys.path.append(libdir)
        from JJnetstat import stathelper

        logging.debug("imported stathelper")
        g_stat = stathelper.nxstat_duranddb(run_id=run_id, stype=stype,
                                            min_score=min_score,
                                            set_id_filter=set_id_filter,
                                            symmetric=symmetric,
                                            cacheq=cacheq)
        retval = g_stat.calc_statistics(omit_spl=omit_spl, omit_cc=omit_cc,
                                        verbose=verbose)

    except Exception:
        # Report the failure as text so the submitter can print it.
        exc_type, exc_value, exc_tb = sys.exc_info()
        errormsg = (str(exc_type) + '\n'
                    + str(exc_value) + '\n'
                    + ''.join(traceback.format_tb(exc_tb)))
        del exc_tb  # drop the frame <-> traceback reference cycle

    return (retval, errormsg)

class stat_set_base:
    """Network statistics calculation: common setup for all runners.

    networks   -- sequence of (stype, run_id, set_id_filter, symmetric, 'descr')
    thresholds -- mapping stype -> iterable of thresholds,
                  e.g. { stype: (0.05, ...), stype: (29, ...) }

    Construction registers every (network, threshold) pair in a fresh
    netstatclass stored as ``self.net_stat``.
    """
    networks = None
    thresholds = None
    net_stat = None

    def __init__(self, networks, thresholds):
        self.networks, self.thresholds = networks, thresholds
        self.build_netstats()

    def build_netstats(self):
        """(Re)create self.net_stat with one entry per network and threshold."""
        self.net_stat = netstatclass()
        register = self.net_stat.add_set

        for net in self.networks:
            stype, run_id, set_id_filter, symmetric, descr = net
            for thresh in self.thresholds[stype]:
                register(stype, run_id, thresh, set_id_filter, symmetric, descr)


class stat_set_debug(stat_set_base):
    """Non-parallel network statistics calculation.  Primarily for debugging"""

    def __init__(self, networks, thresholds):
        stat_set_base.__init__(self, networks, thresholds)


    # It's slow to rebuild the graph for every threshold.  Instead,
    # initialize a weighted graph that has a loose threshold, and
    # progressively remove edges.  The weighted graph takes very
    # little additional space, and shouldn't affect the algorithms for
    # CC or density.

    def calculate_statistics(self, omit_spl=False, omit_cc=False,
                             cacheq=False, verbose=False):

        for (stype, run_id, set_id_filter, symmetric, descr) in self.networks:

            # order the thresholds from least stringent to most
            # stringent, to allow edges to be removed from the
            # weighted network
            threshs = sorted(self.thresholds[ stype])

            print time.time(), "Starting run:", stype, run_id, set_id_filter, symmetric, descr

            g_stat = stathelper.nxstat_duranddb( run_id = run_id, stype = stype,
                                                 min_score = threshs[0],
                                                 set_id_filter = set_id_filter,
                                                 symmetric = symmetric,
                                                 cacheq = cacheq)
            for i,thresh in enumerate(threshs):
                print time.time(), "Starting thresh:", thresh
                
                if i > 0:
                    # drop edges below the threshold
                    g_stat.prune_thresh_edges( min_score = thresh)
                    print time.time(), "Done pruning"
                
                stats = g_stat.calc_statistics(omit_spl=omit_spl, omit_cc=omit_cc,
                                               verbose=verbose)

                self.net_stat.set_stat( stats, stype=stype, run_id = run_id,
                                        min_score=thresh, set_id_filter=set_id_filter,
                                        symmetric=symmetric)


    def calculate_statistics_ppjob(self, omit_spl=False, cacheq=False, verbose=False):
        runs = self.net_stat.get_runs()
        while len(runs) > 0:
            run = runs.pop(0)

            print time.time(), "Starting job:", run
            (retval, errormsg) = ppjob( run.run_id, run.stype,
                                        run.min_score, run.set_id_filter,
                                        run.symmetric,
                                        cacheq, omit_spl, verbose)

            if errormsg is None:
                print time.time(), "Job complete"
                self.net_stat.set_stat(run = run, stat_set = retval)
            else:
                print time.time()
                #print "retval:", retval
                print "errormsg:", errormsg
                print "Run failed.  Queuing", run
                #runs.append(run)


        

class stat_set_pp(stat_set_base):
    """Parallel (with pp) network statistics calculation
Servers are hard-coded.  Start pp on each machine with
'ppserver.py -ssjdfl3wq893267yujrf7948o3w7ufds'"""

    # networks = [ (stype, run_id, 'descr')
    # thresholds = { stype: (0.05, ...), stype: (29, ...) }
    def __init__( self, networks, thresholds, ppservers=None):
        stat_set_base.__init__(self, networks, thresholds)

        if ppservers is None: ppservers = ('goby.compbio.cs.cmu.edu',
                                           'ciona.compbio.cs.cmu.edu',
                                           'hagfish.compbio.cs.cmu.edu',
                                           'mordax.compbio.cs.cmu.edu',
                                           'fugu.compbio.cs.cmu.edu',
                                           'cynops.compbio.cs.cmu.edu',
                                           #'diatom.compbio.cs.cmu.edu'
                                           )
        self.job_server = pp.Server(ppservers = ppservers,
                                    secret='sjdfl3wq893267yujrf7948o3w7ufds',
                                    #ncpus = 0
                                    )

        print "Starting pp with", self.job_server.get_ncpus(), "workers."


        #self.queue_jobs(True, True, True)


    def queue_jobs(self, omit_spl, omit_cc, cacheq, verbose):
        self.jobs = {}
        self.job_attempts = {}
        
        for run in self.net_stat.get_runs():
            self.jobs[run] = self.job_server.submit( func=ppjob,
                                                args=(run.run_id, run.stype,
                                                      run.min_score, run.set_id_filter,
                                                      run.symmetric,
                                                      cacheq, omit_spl, omit_cc, verbose),
                                                modules=('sys','os','traceback','logging'))

            self.job_attempts[run] = 1
        

    def calculate_statistics(self, omit_spl=False, omit_cc=False,
                             cacheq=False, verbose=False):

        self.queue_jobs(omit_spl, omit_cc, cacheq, verbose)

        print "waiting.", time.asctime()
        cnt = 0
        last_cnt = -1
        while len(self.jobs) > 0:
            for (run, job) in self.jobs.items():
                if job.finished:
                    retup = job()
                    #print run
                    #print job
                    #print retup
                    if retup is None or len(retup) != 2:
                        print retup
                    (retval, errormsg) = retup
                    if errormsg is None:
                        self.net_stat.set_stat(run = run, stat_set = retval)
                        del self.jobs[run]
                        cnt += 1
                    else:
                        # We probably failed with a temporary error.
                        # Just resubmit the job a few times
                        print "-----WARNING: Job failed after %d attempts----" % self.job_attempts[run]
                        print run
                        print errormsg
                        if self.job_attempts[run] < 3:
                            print "Retrying"
                            self.jobs[run] = self.job_server.submit(
                                func=ppjob,
                                args=(run.run_id, run.stype,
                                      run.min_score, run.set_id_filter,
                                      run.symmetric,
                                      cacheq, omit_spl, omit_cc, verbose),
                                modules=('sys','os'))

                            self.job_attempts[run] += 1
                        else:
                            print "Too many retries; setting return val to None"
                            self.net_stat.set_stat(run=run, stat_set=None)
                            del self.jobs[run]
                            cnt += 1
                        print "--------------------------------------------"
                    
            if last_cnt != cnt:
                print "%d jobs complete, %d remaining" % (cnt, len(self.jobs))
                self.job_server.print_stats()
                
            last_cnt = cnt
            time.sleep(10)

        print "jobs done", time.time()
        self.job_server.print_stats()
        
        return


class run_param( object):
    """Hashable identifier for one statistics run.

    Equality, ordering and hashing use only the run parameters
    (stype, run_id, min_score, set_id_filter, symmetric).  descr is a
    free-text label and is ignored, so runs differing only in descr are
    the same dictionary key.
    """
    stype = None
    run_id = None
    min_score = None
    set_id_filter = None
    symmetric = None
    descr = None

    def __init__(self, stype, run_id, min_score, set_id_filter,
                 symmetric, descr):
        self.stype = stype
        self.run_id = run_id
        self.min_score = min_score
        self.set_id_filter = set_id_filter
        self.symmetric = symmetric
        self.descr = descr

    def _key(self):
        """Tuple of the parameters that identify this run (descr excluded)."""
        return (self.stype, self.run_id, self.min_score, self.set_id_filter,
                self.symmetric)

    def __str__(self):
        s = "<run_param: "
        s += str(self._key())
        s += ", '%s'>" % self.descr
        return s

    def __repr__(self):
        return self.__str__()

    # Rich comparisons make equality/ordering work without relying on
    # __cmp__ (ignored by Python 3) and return NotImplemented for foreign
    # types instead of raising AttributeError.
    def __eq__(self, other):
        if not isinstance(other, run_param):
            return NotImplemented
        return self._key() == other._key()

    def __ne__(self, other):
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def __lt__(self, other):
        if not isinstance(other, run_param):
            return NotImplemented
        return self._key() < other._key()

    def __cmp__(self, other):
        """Compare only the run parameters, not the description (Python 2)."""
        if not isinstance(other, run_param):
            return NotImplemented
        return cmp(self._key(), other._key())

    def __hash__(self):
        """Hash only the run parameters, not the description"""
        return hash(self._key())
        

class netstatclass( object):
    """Registry of statistics runs and their results.

    run_dict -- run_param -> stat_set (None until set, and set to None by
                callers when a run fails)
    run_keys -- (stype, run_id, min_score, set_id_filter, symmetric)
                -> run_param, for lookup by raw parameters
    """
    run_dict = None
    run_keys = None

    def __init__(self):
        self.run_dict = {}
        self.run_keys = {}

    def add_set(self, stype, run_id, min_score, set_id_filter, symmetric,
                descr, stat_set = None):
        """Register a run (replacing any equal one) with an initial stat_set."""
        run = run_param( stype, run_id, min_score, set_id_filter, symmetric, descr)
        # Drop an equal existing key first so run_dict's key object is the
        # same instance (same descr) that run_keys points to.
        self.run_dict.pop(run, None)
        self.run_dict[ run ] = stat_set
        self.run_keys[ (stype, run_id, min_score, set_id_filter, symmetric) ] = run
        return

    def _require_run(self, stype, run_id, min_score, set_id_filter, symmetric):
        """Return the registered run_param for these parameters or raise KeyError."""
        run = self.get_run(stype, run_id, min_score, set_id_filter, symmetric)
        if run is None:
            raise KeyError((stype, run_id, min_score, set_id_filter, symmetric))
        return run

    def set_stat(self, stat_set, run=None, stype=None, run_id=None, min_score=None,
                 set_id_filter=None, symmetric=None):
        """Store stat_set for a run given directly or by its parameters.

        Raises KeyError for unregistered parameters; previously the value was
        silently stored under a None key, which broke later iteration.
        """
        if run is None:
            run = self._require_run(stype, run_id, min_score, set_id_filter, symmetric)
        self.run_dict[ run ] = stat_set

    def get_run(self, stype, run_id, min_score, set_id_filter,
                symmetric, descr = None):
        """Return the registered run_param for these parameters, or None.

        descr is accepted for call compatibility but not used in the lookup.
        """
        return self.run_keys.get((stype, run_id, min_score, set_id_filter, symmetric))

    def get_stat(self, stat_name, run=None, stype=None, run_id=None, min_score=None,
                 set_id_filter = None, symmetric=None):
        """Return one named statistic for a run, or None.

        None is returned when the statistic is absent or when the run has no
        stat_set (not computed yet, or failed).  Raises KeyError for an
        unregistered run.
        """
        if run is None:
            run = self._require_run(stype, run_id, min_score, set_id_filter, symmetric)
        stat_dict = self.run_dict[ run]
        if stat_dict is None:
            return None
        if stat_name in stat_dict: return stat_dict[stat_name]
        else: return None

    def get_stypes(self, runs=None):
        """Return the set of stypes among runs (default: all registered runs)."""
        if runs is None: runs = self.run_dict
        return set(run.stype for run in runs)

    def get_runs(self, stype_run_id=None):
        """Return a new list of registered runs.

        If stype_run_id is given it must be a
        (stype, run_id, set_id_filter, symmetric) tuple; only runs matching
        it (at any min_score) are returned.
        """
        if stype_run_id is None:
            return list(self.run_dict)
        return [run for run in self.run_dict
                if (run.stype, run.run_id, run.set_id_filter,
                    run.symmetric) == stype_run_id]

