#!/usr/bin/env python

# Jacob Joseph
# 11 June 2008

from DurandDB import pfamq
import cluster_sql
import numpy
from cluster_stat_brohee import brohee_stat

#import cPickle
#from IPython.Shell import IPShellEmbed

class stat(brohee_stat):
    """Extended network clustering statistics.  While often similar to
    Brohee's, these account for limited annotation of the set of
    proteins."""

    def __init__(self, **args):
        brohee_stat.__init__(self, **args)

    def overall_recall(self, family_set=None):
        """Weighted (by family size and cluster representation)
        average of recall"""

        T = self.cont
        row_sums = numpy.sum(T, axis=1)
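        # row_sums[i] is the number of the family's annotated members that
        # appear in any cluster; each family's recall is weighted by it below,
        # while the normalization uses the full annotated family size N[i].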

        if family_set is None:
            family_set = self.family_ord.keys()

        recsum = 0.0
        norm_sum = 0
        for family_abbrev in family_set:
            i = self.family_ord[family_abbrev]
            recsum += row_sums[i] * self.family_wise_recall(
                family_abbrev, weighted_avg=True)
            norm_sum += self.N[i]
            
        # normalize by the size of all families
        rec_wt = recsum / norm_sum
        return rec_wt

    def family_wise_precision(self, family_abbrev):
        """Precision of a family weighted over all clusters"""

        T = self.cont
        i = self.family_ord[family_abbrev]

        nonzero_clusters = numpy.nonzero( T[i,:])[0]
        col_sums = self.M
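        # Weighted precision of family i over its clusters:
        #     prec = (1 / N[i]) * sum_j  T[i,j]**2 / M[j]
        # i.e. each cluster's precision T[i,j]/M[j] is weighted by the number
        # of family members it contains, T[i,j], and the sum is normalized by
        # the family's annotated size N[i].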

        prec = numpy.sum( numpy.array( T[i, nonzero_clusters]**2, dtype=float)
                          / col_sums[nonzero_clusters]) / self.N[i]
        return prec

    def overall_precision(self, family_set=None):
        """Weighted (by family size and cluster representation)
        average of precision"""

        T = self.cont
        row_sums = numpy.sum(T, axis=1)

        if family_set is None:
            family_set = self.family_ord.keys()
        
        precsum = 0.0
        norm_sum = 0
        for family_abbrev in family_set:
            i = self.family_ord[family_abbrev]
            precsum += row_sums[i] * self.family_wise_precision( family_abbrev)
            norm_sum += self.N[i]

        # normalize by the size of all families
        prec_wt = precsum / norm_sum
        return prec_wt

    ######################################
    ## F - measure
    ######################################
    def F(self, family_abbrev, cluster_id):
        """F-statistic: harmonic mean of precision (ppv) and recall (sn)"""

        rec = self.sensitivity( family_abbrev, cluster_id)
        prec = self.precision( family_abbrev, cluster_id, use_true_csize=True)
        F = 2 * prec * rec / (prec + rec)
        return F

    def family_wise_F( self, family_abbrev):
        """Weighted (by cluster representation) average of F statistic over all
        clusters that represent a family"""

        i = self.family_ord[ family_abbrev]
        T = self.cont

        row = T[i,:]
        nonzero_clusters = numpy.nonzero( row)[0]

        F_sum = 0.0
        for j in nonzero_clusters:
            cluster_id = self.cluster_ord_rev[j]
            F_sum += row[j] * self.F(family_abbrev, cluster_id)

        #print family_abbrev, F_sum, numpy.sum(row)

        F_wt = F_sum / numpy.sum(row)
        return F_wt
   
    def overall_F(self, family_set=None):
        """Weighted (by family size and cluster representation)
        average of F statistic"""

        T = self.cont
        row_sums = numpy.sum(T, axis=1)

        if family_set is None:
            family_set = self.family_ord.keys()

        F_sum = 0.0
        norm_sum = 0
        for family_abbrev in family_set:
            i = self.family_ord[family_abbrev]
            F_sum += row_sums[i] * self.family_wise_F( family_abbrev)
            norm_sum += self.N[i]

        # normalize by the size of all families
        F_wt = F_sum / norm_sum
        return F_wt


    ######################################
    ## Accuracy
    ######################################    
    def accuracy( self, family_abbrev, cluster_id):
        """Geometric mean of precision and recall using weighted
        Precision and Recall"""

        rec = self.sensitivity( family_abbrev, cluster_id)
        prec = self.precision( family_abbrev, cluster_id, use_true_csize=True)
        acc = numpy.sqrt( prec * rec)
        return acc

    def family_wise_accuracy( self, family_abbrev):
        """Average of accuracy weighted by family size"""
        
        i = self.family_ord[ family_abbrev]
        T = self.cont

        row = T[i,:]
        nonzero_clusters = numpy.nonzero( row)[0]

        acc_sum = 0.0
        for j in nonzero_clusters:
            cluster_id = self.cluster_ord_rev[j]
            acc_sum += row[j] * self.accuracy(family_abbrev, cluster_id)

        acc_wt = acc_sum / numpy.sum(row)
        return acc_wt

    def overall_accuracy(self, family_set=None):
        """Weighted average of accuracy"""

        T = self.cont
        row_sums = numpy.sum(T, axis=1)

        if family_set is None:
            family_set = self.family_ord.keys()

        acc_sum = 0.0
        norm_sum = 0
        for family_abbrev in family_set:
            i = self.family_ord[family_abbrev]
            acc_sum += row_sums[i] * self.family_wise_accuracy( family_abbrev)
            norm_sum += self.N[i]

        # normalize by the size of all families
        acc_wt = acc_sum / norm_sum
        return acc_wt
    

    ######################################
    ## "Unity"
    ######################################
    def family_unity(self, family_abbrev):
        """Family unity: weighted average of fraction across
        the clusters that represent a family"""

        i = self.family_ord[ family_abbrev]
        T = self.cont
        row = numpy.array(T[i,:], dtype=float)
        row_fraction = row / numpy.sum( row)
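        # e.g. a family whose clustered members split 6/4 between two clusters
        # gives row_fraction = [0.6, 0.4] and a unity of
        # 0.6**2 + 0.4**2 = 0.52; a family confined to a single cluster
        # scores 1.0.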

        frag_wt = numpy.sum( row_fraction**2)
        return frag_wt


    def overall_unity(self, family_set=None):
        """Average of family unity weighted by family_size"""
        T = self.cont

        if family_set is None:
            family_set = self.family_ord.keys()

        # row fractions
        #row_sums_t = numpy.array( [numpy.sum(T, axis=1)]).transpose()
        #row_fractions = numpy.array(T, dtype=float) / row_sums_t

        #row_wt = numpy.sum(numpy.array(T, dtype=float)**2 / row_sums_t
        #                   ) / numpy.sum(row_sums_t)
        
        frag_sum = 0.0
        norm_sum = 0
        for family_abbrev in family_set:
            i = self.family_ord[ family_abbrev]
            frag_sum += self.N[i] * self.family_unity(family_abbrev)
            norm_sum += self.N[i]
            
        frac = frag_sum / norm_sum

        #if frac != row_wt:
        #    print "inequal row_wt and frac:", row_wt, frac
        return frac
    

    def family_fraction(self, family_abbrev):
        """Family size / size of union of containing clusters"""

        T = self.cont
        i = self.family_ord[ family_abbrev]
        
        nonzero_clusters = numpy.nonzero( T[i,:])[0]
        col_sums = self.M

        frac = float(numpy.sum( T[i,:])) / numpy.sum(col_sums[ nonzero_clusters])
        return frac


    def overall_fraction(self, family_set=None):
        """Weighted (by family size and cluster representation)
        average of family fraction"""

        T = self.cont
        row_sums = numpy.sum(T, axis=1)

        if family_set is None:
            family_set = self.family_ord.keys()

        precsum = 0.0
        norm_sum = 0
        for family_abbrev in family_set:
            i = self.family_ord[ family_abbrev]
            precsum += row_sums[i] * self.family_fraction( family_abbrev)
            norm_sum += row_sums[i]

        # normalize by the size of all families
        prec_wt = precsum / norm_sum
        return prec_wt

    ####################################
    # PFAM statistics
    ####################################
    def build_pfam_maps(self, set_id=79):
        """Return clust_to_pfam, pfam_to_clust, ident_map, promisc_map dictionaries"""
        pfq = pfamq.pfamq( cacheq=True)

        clust_to_pfam = {}
        pfam_to_clust = {}
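        # clust_to_pfam: cluster_id -> set of Pfam domains found among the
        #                cluster's sequences
        # pfam_to_clust: Pfam domain -> set of cluster_ids containing it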

        for (cluster_id, seq_id_list) in self.cluster_sets.iteritems():
            domains = pfq.get_domains( seq_id_list)
            clust_to_pfam[ cluster_id] = set(domains)
            
            for domain in domains:
                if domain not in pfam_to_clust:
                    pfam_to_clust[domain] = set()
                pfam_to_clust[domain].add(cluster_id)

        ident_map = pfq.get_identity_map()
        promis_map = pfq.get_promiscuity_map(set_id=set_id)
        
        return (clust_to_pfam, pfam_to_clust, ident_map, promis_map)
                
    def pfam_clust_cnt(self):
        """Return an ordered list of domains, and a 3 column numpy array:
        (# clusters containing domain, domain identity, domain promiscuity)"""

        (clust_to_pfam, pfam_to_clust,
         ident_map, promis_map ) = self.build_pfam_maps()

        # We can only consider domains that are in both Uniprot and Pfam22
        pfam22_set = set(ident_map.keys())
        uniprot_set = set(pfam_to_clust.keys())
        domain_list = sorted( pfam22_set.intersection( uniprot_set))

        #print "Ignoring PFAM domains:", pfam22_set.symmetric_difference( uniprot_set)

        cnt_array = numpy.zeros( (len(domain_list), 3), dtype=float)
        for (i,domain) in enumerate(domain_list):
            cnt_array[i] = ( len(pfam_to_clust[domain]),
                             ident_map[domain],
                             promis_map[domain])
        return (domain_list, cnt_array)

                    
    def get_overall_stats(self, family_set=None):
        if not self.cont_init:
            self.build_contingency()
        
        sdict = {}

        if family_set is None:
            pass
        elif isinstance(family_set, basestring):
            family_set = (family_set,)
        elif isinstance(family_set, (set, tuple, list)):
            pass
        else:
            raise ValueError("Unrecognized family_set input: %s" % family_set)
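        # Accepted family_set forms (the family abbreviations other than
        # 'Kinase' are illustrative):
        #   None                   -> every annotated family
        #   'Kinase'               -> a single family abbreviation
        #   ('Fam_A', 'Fam_B')     -> an explicit subset (set/tuple/list)
        #   'ALL' / 'ALL-kin'      -> all families / all except 'Kinase'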

        if family_set==('ALL',):
            family_set = self.family_ord.keys()
        elif family_set==('ALL-kin',):
            family_set = set( self.family_ord.keys()) - set(('Kinase',))


        # Brohee class doesn't accept family_set parameter
        # All 'overall_*' functions do.
        #if family_set is None:
            # sdict['Clustering SN'] = self.clustering_wise_sn()
            # sdict['Clustering PPV'] = self.clustering_wise_ppv()
            #sdict['Clustering PPVt'] = self.clustering_wise_ppv(
            #    use_true_csize=True)
            # sdict['Accuracy'] = self.brohee_accuracy()
            #sdict['Accuracyt'] = self.brohee_accuracy(use_true_csize=True)
            # sdict['Cluster SEP'] = self.avg_cluster_wise_sep()
            # sdict['Cluster SEPt'] = self.avg_cluster_wise_sep(
            #    use_true_csize=True)
            # sdict['Family SEP'] = self.avg_complex_wise_sep()
            # sdict['Family SEPt'] = self.avg_complex_wise_sep(
            #    use_true_csize=True)
            # sdict['Family SEPst'] = self.avg_complex_wise_sep(
            #    use_true_csize=True, max_single_sep=True)
            # sdict['Clustering SEP'] = self.clustering_wise_sep()
            #sdict['Clustering SEPt'] = self.clustering_wise_sep(
            #    use_true_csize=True)
            # sdict['Clustering SEPst'] = self.clustering_wise_sep(
            #    use_true_csize=True, max_single_sep=True)

        #sdict['Unity'] = self.overall_unity(family_set=family_set)
        sdict['F'] = self.overall_F(family_set=family_set)
        sdict['Precision'] = self.overall_precision(family_set=family_set)
        sdict['Recall'] = self.overall_recall(family_set=family_set)
        #sdict['Fraction'] = self.overall_fraction(family_set=family_set)
        #sdict['Accuracy'] = self.overall_accuracy(family_set=family_set)
        return sdict

    def __str__( self):
        s = ""
        s += "======================================================\n"
        s += "Comment: " + str(self.csql.cluster_comment) + "\n"
        s += "Parameters: " + str(self.csql.cluster_params) + "\n"
        s += "======================================================\n"
        #s += "    Clustering-wise sensitivity: %0.4f\n" % self.clustering_wise_sn()
        #s += "            Clustering-wise PPV: %0.4f\n" % self.clustering_wise_ppv()
        #s += "                       Accuracy: %0.4f\n" % self.brohee_accuracy()
        #s += "Average cluster-wise separation: %0.4f\n" % self.avg_cluster_wise_sep()
        #s += "Average complex-wise separation: %0.4f\n" % self.avg_complex_wise_sep()
        #s += "     Clustering-wise separation: %0.4f\n" % self.clustering_wise_sep()
        #s += "                       Overall F: %0.4f\n" % self.overall_F()
        #s += "-----------------------------------------------------\n"
        s += "Using true cluster size (not column sum):\n"
        s += "            Clustering-wise PPV: %0.4f\n" % self.clustering_wise_ppv(use_true_csize=True)
        s += "                Brohee Accuracy: %0.4f\n" % self.brohee_accuracy(use_true_csize=True)
        s += "              Weighted Accuracy: %0.4f\n" % self.overall_accuracy()
        #s += "Average cluster-wise separation: %0.4f\n" % self.avg_cluster_wise_sep(use_true_csize=True)
        #s += "Average complex-wise separation: %0.4f\n" % self.avg_complex_wise_sep(use_true_csize=True)
        #s += "     Clustering-wise separation: %0.4f\n" % self.clustering_wise_sep(use_true_csize=True)
        s += "             Weighted Precision: %0.4f\n" % self.overall_precision()
        s += "                Weighted Recall: %0.4f\n" % self.overall_recall()
        s += "                     Weighted F: %0.4f\n" % self.overall_F()
        s += "                  Overall Unity: %0.4f\n" % self.overall_unity()
        s += "        Overall Family Fraction: %0.4f\n" % self.overall_fraction()
        #s += "-----------------------------------------------------\n"
        #s += "Using true cluster size and max single cluster separation:\n"
        #s += "Average complex-wise separation: %0.4f\n" % self.avg_complex_wise_sep(use_true_csize=True,
        #                                                                            max_single_sep=True)
        #s += "     Clustering-wise separation: %0.4f\n" % self.clustering_wise_sep(use_true_csize=True)
        s += "-----------------------------------------------------\n"
        s += "                 Total Clusters: %d\n" % len(self.cluster_ord)
        s += " Annotated (non-empty) clusters: %d\n" % self.num_annotated_clusters()
        s += "\n"
        s += "-----------------------------------------------------\n"
        s += "Family-specific statistics:\n"
        s += "  Family  Size  Row Sum  Clusters  Clusters Size     Frac   Unity  PRECwt   RECwt  RECmax       F\n"
        template = "%8s  %4d     %4d      %4d          %5d   %0.4f  %0.4f  %0.4f  %0.4f  %0.4f  %0.4f\n"
        for (family_abbrev,i) in sorted(self.family_ord.items(), key=lambda a:a[0]):
            # sum of lengths of all clusters which represent this family
            clusters_size = 0
            for j in numpy.nonzero( self.cont[i,:])[0]:
                clusters_size += len( self.cluster_sets[ self.cluster_ord_rev[j] ])
            clusters_size2 = numpy.sum( self.M[ numpy.nonzero(self.cont[i,:]) ])
            if clusters_size != clusters_size2:
                print "clusters size incorrect ", family_abbrev, clusters_size, clusters_size2
            
            s += template % (family_abbrev,
                             self.N[i],
                             numpy.sum(self.cont[i,:]),
                             numpy.size(numpy.nonzero(self.cont[i,:])),
                             clusters_size,
                             self.family_fraction( family_abbrev),
                             self.family_unity( family_abbrev),
                             self.family_wise_precision( family_abbrev),
                             self.family_wise_recall( family_abbrev, weighted_avg=True),
                             self.family_wise_recall( family_abbrev),
                             self.family_wise_F( family_abbrev))

            #s += str(numpy.nonzero( self.cont[i,:])) + "\n"
            
        s += "\n"
        s += "-----------------------------------------------------\n"
        s += "Cluster-specific statistics (non-empty clusters only):\n"
        s += "  Cluster   Size  Col Sum    Families    PPVmax\n"
        template = "%9s  %5d    %5d        %4d    %0.4f %s\n"

        # list of (cluster_id, j, c_size) to sort by size
        clusters = [ (cluster_id, j, self.M[j]
                      ) for (cluster_id, j) in self.cluster_ord.items() ]
        clusters.sort(key=lambda a: a[2], reverse=True)

        family_ord_inv = dict( (v,k) for (k,v) in self.family_ord.items())

        row_cnt = 0
        for (cluster_id, j, c_size) in clusters:            
            # only display clusters with at least one annotated member
            if numpy.sum(self.cont[:,j]) < 1:
                continue

            # abbreviate the output
            #row_cnt += 1
            #if row_cnt > 30:
            #    s+= "... (pruned at 30 clusters)\n"
            #    break
            
            s += template % (cluster_id,
                             self.M[j],
                             numpy.sum(self.cont[:,j]),
                             len(numpy.nonzero(self.cont[:,j])[0]),
                             self.cluster_wise_ppv( cluster_id, use_true_csize=True),
                             [family_ord_inv[i] for i in self.cont[:,j].nonzero()[0]]
                             )
        s += "======================================================\n"
                
        return s

        
class hclust_stat(stat):
    """Clustering statistics for hierarchical clustering"""

    def __init__(self, cr_id, **args):
        self.cr_id = cr_id
        self.csql = cluster_sql.hcluster( cluster_run_id = self.cr_id, cacheq=True)
        stat.__init__(self, **args)

    def fetch_cut(self, distance, set_id_filter=None):
        assert set_id_filter is None, "set_id_filter no longer used.  The set_id in cluster_stat_brohee will always be used."

        self.cluster_sets = self.csql.cut_tree( distance,
                                                set_id_filter = self.set_id)
        self.cont_init = False
        # self.build_contingency()

    def fetch_cut_components(self, distance, br_id, nc_id, stype):
        self.cluster_sets = self.csql.cut_tree_components( distance, br_id, nc_id, stype)
        self.cont_init = False
        #self.build_contingency()


class flatclust_stat( stat):
    """Clustering statistics for flat, non-hierarhical clustering"""

    def __init__(self, cr_id, **args):
        self.cr_id = cr_id
        self.csql = cluster_sql.flatcluster( cluster_run_id = self.cr_id, cacheq=True)
        stat.__init__(self, **args)

        self.cluster_sets = self.csql.fetch_clusters()
        self.cont_init = False
        #self.build_contingency()
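
# Typical workflow sketch (an assumed usage, not part of the original module;
# cr_id=25 and the 0.9 distance cut are illustrative values taken from the
# debugging helper below):
#
#     h = hclust_stat( cr_id=25)
#     h.fetch_cut( 0.9)              # flatten the hierarchy at a distance cut
#     h.build_contingency()          # build the family x cluster table
#     print h                        # per-family and per-cluster report
#     print h.get_overall_stats()    # {'F': ..., 'Precision': ..., 'Recall': ...}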
        

class brohee_test(stat):
    """A test of the statistics code, using data from Brohee, Table 4"""

    def __init__(self):
        self.cont = numpy.array( [[7, 0, 0, 0, 0],
                                  [0, 6, 8, 0, 0],
                                  [0, 0, 0, 14, 3],
                                  [0, 0, 0, 4, 5]], dtype=int)

        self.cont_init = True
        self.N = numpy.array([7, 14, 20, 8], dtype=int)
        self.M = numpy.array([7, 6, 8, 18, 8], dtype=int)

        self.family_ord = {'complex_1': 0, 'complex_2': 1,
                           'complex_3': 2, 'complex_4': 3 }
        self.cluster_ord = {'cluster_1': 0, 'cluster_2': 1,
                            'cluster_3': 2, 'cluster_4': 3,
                            'cluster_5': 4}
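
    # Spot-check sketch against Brohee's Table 4 (an assumed use of this
    # fixture, not from the original file).  Only statistics that depend
    # solely on the attributes set above will run -- overall_F(), for
    # instance, would also need cluster_ord_rev, which is not populated here:
    #
    #     bt = brohee_test()
    #     bt.family_wise_precision('complex_1')   # == 1.0  (7**2/7 / 7)
    #     bt.overall_precision()                  # weighted over all families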



_dbg_h = None

def dbg_sets(family_abbrev, init=False, cr_id=25, nc_dist=0.9):
    """Debugging helper: print the clusters that share members with a family.

    The hclust_stat instance is cached at module level, so repeated calls do
    not rebuild the contingency table; pass init=True to force a rebuild."""
    global _dbg_h
    if init or _dbg_h is None:
        _dbg_h = hclust_stat( cr_id)
        _dbg_h.fetch_cut( nc_dist)
        _dbg_h.build_contingency()
    h = _dbg_h

    for (cluster_id, j) in h.cluster_ord.items():
        (l, seq_ids) = h.shared_count(family_abbrev, cluster_id)
        if len(seq_ids) > 0:
            print "cluster_id: %d. Length %d" % (cluster_id, l)
            print seq_ids

    return h
