#!/usr/bin/env python
# Jacob Joseph
# 3/13/2007
# Utilities to calculate graph statistics

import os, logging
import networkx as NX
from DurandDB import blastq
from JJutil import pickler, rate

class nxstat:
    """Utilities to compute summary statistics of a NetworkX graph (self.G).

    The expensive calculations can optionally read from / write to an
    on-disk cache through ``pickler.cachefn``, using ``self.parameters`` as
    the key.  Caching is forced off when no ``unique_parameters`` are given.

    NOTE(review): several methods pass the same ``args`` to pickler.cachefn;
    presumably cachefn also distinguishes entries by calling function --
    confirm against JJutil.pickler.
    """
    cacheq = None       # True if results are looked up in / stored to the pickle cache
    pickledir = None    # directory handed to pickler.cachefn

    parameters = None   # cache key identifying the graph (unique_parameters)
    G = None            # the NetworkX graph under analysis

    def __init__(self, G=None, cacheq=False, unique_parameters=None):
        """
        G -- NetworkX graph to analyse (may also be assigned later via self.G)
        cacheq -- enable the pickle cache; ignored when unique_parameters is None
        unique_parameters -- value identifying this graph, used as the cache key
        """
        self.pickledir = os.path.expandvars("$HOME/tmp/pickles")

        assert NX.__version__ == "1.6", "Networkx version %s not tested" % NX.__version__

        self.G = G

        if unique_parameters is None:
            # Without a key there is nothing to cache under.
            self.cacheq = False
            self.parameters = None
        else:
            self.cacheq = cacheq
            self.parameters = unique_parameters

    ####################################
    # Graph Calculations
    ####################################
    def degree_distribution(self):
        """Return a histogram {degree: number of nodes having that degree}."""
        hist = {}
        # degree_iter() yields (node, degree) pairs; only the degree is binned.
        for _node, deg in self.G.degree_iter():
            hist[deg] = hist.get(deg, 0) + 1
        return hist

    def total_path_length(self, G):
        """Return the sum, over every node n of G, of the shortest-path
        lengths from n to every node reachable from n.

        Progress is printed every 100 source nodes.
        """
        totallength = 0

        shortest_path_len = NX.single_source_shortest_path_length

        n_rate = rate.rate(totcount=G.number_of_nodes())
        for n in G:
            n_rate.increment()
            if n_rate.count % 100 == 0:
                print("path: %s" % (n_rate,))
            totallength += sum(shortest_path_len(G, n).values())
        return totallength

    def components_sizedensitypath(self, omit_spl=False):
        """Return:
        the number of components (size>1),
        the number of single components (size==1),
        a dictionary of densities {size: [densities]},
        a dictionary of total all pair path lengths {size: [path lengths]}
        (left empty when omit_spl is true)"""
        # omit_spl changes the result (no path lengths), so it must be part of
        # the cache key.  The default key is kept unchanged so existing caches
        # for omit_spl=False stay valid.
        if omit_spl:
            cache_args = "%s_omit_spl" % (self.parameters,)
        else:
            cache_args = self.parameters

        if self.cacheq:
            retval = pickler.cachefn(pickledir=self.pickledir,
                                     args=cache_args)
            if retval: return retval

        num = 0
        num_single = 0
        sizeden = {}
        sizepathl = {}
        for ccomp in NX.connected_components(self.G):
            C_subgraph = self.G.subgraph(ccomp)
            size = len(C_subgraph)

            # components of size 1 are only counted
            if size <= 1:
                num_single += 1
                continue

            num += 1
            # density by size
            sizeden.setdefault(size, []).append(NX.density(C_subgraph))

            # sum of all shortest path lengths by size
            if not omit_spl:
                sizepathl.setdefault(size, []).append(
                    self.total_path_length(C_subgraph))

        retval = (num, num_single, sizeden, sizepathl)
        if self.cacheq:
            pickler.cachefn(retval=retval,
                            pickledir=self.pickledir,
                            args=cache_args)
        return retval

    def clustering_coefficient(self, G=None):
        """Return (average clustering coefficient, {k: mean C(k)}).

        Only nodes of degree > 1 contribute.  If there are none, a message is
        printed and the average is reported as -1.  When G is None, self.G is
        used and the result may be cached; an explicit G is never cached.
        """
        use_cache = self.cacheq and G is None
        if use_cache:
            retval = pickler.cachefn(pickledir=self.pickledir,
                                     args=self.parameters)
            if retval: return retval

        G_arg = self.G if G is None else G

        clusterc = NX.cluster.clustering(G_arg)
        degree = G_arg.degree()
        avgsum = 0
        num_ccnodes = 0
        ck = {}
        for (i, c_i) in clusterc.items():
            k = degree[i]
            # only consider nodes with degree > 1
            if k <= 1:
                continue
            avgsum += c_i
            num_ccnodes += 1
            ck.setdefault(k, []).append(c_i)

        # average cc for each degree (overwrites existing keys only)
        for (k, ck_arr) in list(ck.items()):
            ck[k] = float(sum(ck_arr)) / len(ck_arr)

        if num_ccnodes == 0:
            print("No nodes with degree > 1. Cannot calculate clustering coefficient")
            retval = (-1, ck)
        else:
            retval = (float(avgsum) / num_ccnodes, ck)

        if use_cache:
            pickler.cachefn(retval=retval,
                            pickledir=self.pickledir,
                            args=self.parameters)
        return retval

    def calc_statistics(self, omit_spl=False, omit_cc=False, verbose=False):
        """Return a dict of summary statistics for self.G.

        Keys: 'parameters', 'graph_density', 'degree_hist', 'graph_mean_cc',
        'cc_hist', 'ccomp_num', 'ccomp_num_single', 'ccomp_sizedensities',
        'ccomp_sizepaths', 'ccomp_avg_size', 'ccomp_avg_density_w' and,
        unless omit_spl is true, 'ccomp_avg_spl'.  Averages are 0 when the
        graph has no component of size > 1.
        """
        # dictionary of all statistics
        stats = {}

        stats['parameters'] = self.parameters

        assert self.G is not None, "Graph not yet initialized"

        if verbose:
            print(" ** Calculating Graph Statistics %s **" % stats['parameters'])

        stats['graph_density'] = NX.density(self.G)
        stats['degree_hist'] = self.degree_distribution()

        if omit_cc:
            stats['graph_mean_cc'] = 0
            stats['cc_hist'] = {}
        else:
            (stats['graph_mean_cc'],
             stats['cc_hist']) = self.clustering_coefficient()

        # connected components
        (ccomp_num,
         ccomp_num_single,
         ccomp_sizedensities,
         ccomp_sizepaths
         ) = self.components_sizedensitypath(omit_spl=omit_spl)

        stats['ccomp_num'] = ccomp_num
        stats['ccomp_num_single'] = ccomp_num_single
        stats['ccomp_sizedensities'] = ccomp_sizedensities
        stats['ccomp_sizepaths'] = ccomp_sizepaths

        # average component size and pair-weighted density
        if ccomp_num == 0:
            stats['ccomp_avg_size'] = 0
            stats['ccomp_avg_density_w'] = 0
        else:
            sum_cc_size = 0
            sum_cc_edges = 0
            sum_cc_pairs = 0
            for (cc_size, cc_densities) in ccomp_sizedensities.items():
                # weight each component's density by its n*(n-1) node pairs
                npairs = cc_size * (cc_size - 1)
                sum_cc_size += cc_size * len(cc_densities)
                sum_cc_edges += npairs * sum(cc_densities)
                sum_cc_pairs += npairs * len(cc_densities)

            stats['ccomp_avg_size'] = float(sum_cc_size) / ccomp_num
            stats['ccomp_avg_density_w'] = float(sum_cc_edges) / sum_cc_pairs

        # mean shortest path length over all ordered node pairs in components
        if not omit_spl:
            sum_cc_spls = 0
            sum_cc_pairs = 0
            for (cc_size, cc_pathls) in ccomp_sizepaths.items():
                sum_cc_spls += sum(cc_pathls)
                sum_cc_pairs += (cc_size * (cc_size - 1)) * len(cc_pathls)

            if sum_cc_pairs == 0:
                # no component of size > 1 (or no path data available)
                stats['ccomp_avg_spl'] = 0
            else:
                stats['ccomp_avg_spl'] = float(sum_cc_spls) / sum_cc_pairs

        # TODO: clique statistics (sketched with NX.cliques.find_cliques) were
        # never enabled; searching only within components would be cheaper.
        return stats

class nxstat_duranddb(nxstat):
    """nxstat for a graph built through the DurandDB ``blastq`` interface.

    Nodes come from ``fetch_seq_set`` and weighted edges (weight=score) from
    ``fetch_hits_direct``, selected by br_id or nc_id, stype, min_score and
    an optional set_id_filter.
    """

    def __init__(self, run_id=None, br_id=None, nc_id=None, stype=None,
                 min_score=None, set_id_filter=None, cacheq=True,
                 symmetric=True, build_graph=True):
        """
        run_id -- shorthand routed to nc_id when stype == 'nc_score', or to
                  br_id when stype is 'bit_score' or 'e_value' (only if the
                  corresponding id was not given explicitly)
        br_id, nc_id -- at least one must end up non-None
        build_graph -- if true, load self.G immediately via build_graph_sql
        """
        if run_id is not None:
            if nc_id is None and stype == 'nc_score':
                nc_id = run_id
                br_id = None
            elif br_id is None and stype in ('bit_score', 'e_value'):
                br_id = run_id
                nc_id = None

        self.br_id = br_id
        self.nc_id = nc_id
        self.stype = stype
        self.set_id_filter = set_id_filter
        self.min_score = min_score

        assert (br_id is not None or nc_id is not None), "One of br_id or nc_id must be specified"

        # using a string for pickleargs avoids different hashing
        # between architectures.  The pickles from both architectures
        # seem to be identical.  The base class stores this as self.parameters.
        nxstat.__init__(self, cacheq=cacheq,
                        unique_parameters=self._parameter_string())

        if build_graph:
            self.build_graph_sql(br_id=br_id, nc_id=nc_id, stype=stype,
                                 min_score=min_score, set_id_filter=set_id_filter,
                                 symmetric=symmetric)

    def _parameter_string(self):
        """Return the cache key built from the graph-defining attributes."""
        return '_'.join(str(value) for value in (self.br_id, self.nc_id,
                                                 self.stype, self.set_id_filter,
                                                 self.min_score))

    def build_graph_sql(self, br_id, nc_id, stype, min_score, set_id_filter,
                        symmetric=True):
        """Construct self.G by querying the database (or the pickle cache).

        NOTE(review): the cache lookup is keyed on self.parameters, not on the
        arguments passed here; they agree when called from __init__ -- confirm
        for any other caller.
        """
        if self.cacheq:
            try:
                retval = pickler.cachefn(pickledir=self.pickledir,
                                         args=self.parameters)
            except Exception as e:
                # An explicit raise (unlike `assert False`) survives python -O,
                # which would otherwise fall through with retval unbound.
                raise AssertionError("exception caught: %s" % e)

            if retval:
                logging.debug("using cached graph for %s", self.parameters)
                self.G = retval
                return

        # Initialize a database connection
        dbh = blastq.blastq(cacheq=self.cacheq, pickledir=self.pickledir,
                            debug=False)

        # Historical notes: unpickling once appeared to allocate a new integer
        # per node; since switching to postgres this no longer seems true, and
        # pickling was measured to be no faster than hitting the database
        # (296.77 vs 294.81 seconds for NC id 750).
        try:
            if set_id_filter is not None:
                nodes = dbh.fetch_seq_set(set_id=set_id_filter)
            else:
                nodes = dbh.fetch_seq_set(br_id=br_id, nc_id=nc_id)

            hits = dbh.fetch_hits_direct(br_id=br_id,
                                         nc_id=nc_id,
                                         stype=stype,
                                         thresh=min_score,
                                         set_id_filter=set_id_filter,
                                         symmetric=symmetric,
                                         self_hits=False)

            # NX.Graph_list was too slow for the clustering coefficient;
            # NX.Graph uses about 3/2 as much memory.
            G = NX.Graph()

            # Store each node object only once, rather than a new reference
            # for each edge endpoint.
            vertmap = {}
            for n in nodes:
                vertmap[n] = n
                G.add_node(n)

            # Per-edge sanity asserts (no self hits, endpoints already present)
            # were removed: they slowed this loop from ~21041/s to ~1118/s.
            # A hit whose endpoint is not in `nodes` raises KeyError here.
            for (n1, n2, score) in hits:
                G.add_edge(vertmap[n1], vertmap[n2], weight=score)

            print("Done reading. Closing connection")

            print("Graph: %s %s %s" % (G.number_of_nodes(),
                                       G.number_of_edges(),
                                       G.number_of_selfloops()))
        finally:
            # always release the connection, even if a query fails
            dbh.close()

        self.G = G

        if self.cacheq:
            pickler.cachefn(retval=G, pickledir=self.pickledir,
                            args=self.parameters)
        return

    def prune_thresh_edges(self, min_score):
        """Remove, in place, every edge of self.G whose weight is not greater
        than min_score; record min_score and refresh the cache key.

        NOTE(review): this keeps only edges with weight > min_score; for an
        'e_value' graph (where smaller may be better) the direction may need
        to be inverted -- confirm.
        """
        assert self.G is not None, "self.G is None.  Has graph been defined?"

        self.min_score = min_score
        self.parameters = self._parameter_string()

        G = self.G

        # Collect first: we don't want to modify the graph while iterating.
        edge_set = [(n1, n2)
                    for (n1, n2, edgedata) in G.edges_iter(data=True)
                    if not edgedata['weight'] > min_score]

        G.remove_edges_from(edge_set)

        return

        
