#!/usr/bin/env python

# Jacob Joseph
# 16 July 2008

from DurandDB import familyq
import numpy, random, copy

class brohee_stat(object):
    """Network clustering statistics, from Brohee's paper."""
    set_id = None
    adjust = None       # Randomize the sequence assignment of clusters
                        # before calculating the contingency table.
    seq_map = None      # Dictionary of seq_id to seq_id.  The latter may
                        # be randomized
    family_sets = None  # Dictionary of family: set of seq_ids
    cluster_sets = None # Dictionary of cluster: set of seq_ids

    cont = None         # Contingency table
    cont_init = None    # True if contingency table was initialized
    N = None            # Family sizes. Known at initialization
    M = None            # Cluster sizes.  Not known until a cluster_set is
                        # fetched
    
    def __init__(self, family_set_name=None, randomize_seq_ids=False, set_id=None):
        self.set_id = set_id
        self.adjust = randomize_seq_ids
        
        self.fq = familyq.familyq( family_set_name=family_set_name)
        self.family_sets = self.fq.family_sets

        # Fix a specific ordering of the families.  Clusters are ordered
        # later, in build_contingency().
        self.family_ord = {}
        self.cluster_ord = {}
        for (i,f) in enumerate(sorted(self.family_sets.keys())):
            self.family_ord[f] = i

        # Family (complex) sizes, in family_ord order; these are usually
        # equal to the row sums of the contingency table.
        items = sorted(self.family_ord.items(), key=lambda a: a[1])
        sizes = [ len(self.family_sets[f]) for (f, i) in items]
        self.N = numpy.array( sizes, dtype=int)

        self.cont_init = False

    def fetch_seq_set_dict(self):
        """Build self.seq_map, mapping each seq_id to itself or, if
        self.adjust is set, to a randomly shuffled seq_id."""
        seq_ids = list(self.fq.fetch_seq_set( set_id = self.set_id))

        seq_ids_rand = copy.copy(seq_ids)
        if self.adjust: random.shuffle( seq_ids_rand)
        
        self.seq_map = dict( zip( seq_ids, seq_ids_rand))
        return

    def set_randomize(self, rand):
        """Set whether cluster seq_ids are randomized.  If the setting
        changes, discard the sequence map and contingency table; the
        contingency table is then rebuilt if necessary."""

        if self.adjust != rand:
            self.adjust = rand
            self.seq_map = None
            self.cont = None
            self.cont_init = False

        if not self.cont_init:
            self.build_contingency()

    def build_contingency(self):
        """Construct the numpy contingency table"""

        m = "self.cluster_sets is uninitialized. Call fetch_cut() first"
        assert self.cluster_sets is not None, m

        # A cluster could be empty if the set of sequences is filtered
        # in fetch_cut.  Ignore empty clusters

        self.cluster_ord = {}
        self.cluster_ord_rev = {}

        i = 0
        for c,seq_set in self.cluster_sets.iteritems():
            if len(seq_set) == 0: continue
            
            self.cluster_ord[c] = i
            self.cluster_ord_rev[i] = c
            i += 1

        self.cont = numpy.zeros((len(self.family_sets),
                                 len(self.cluster_ord)),
                                dtype=int)
        
        for (family_abbrev, i) in self.family_ord.items():
            for (cluster_id, j) in self.cluster_ord.items():
                self.cont[i,j] = self.shared_count(
                    family_abbrev, cluster_id)[0]

        # Cluster sizes (col sum + any members not assigned to a family)
        # sorted by index j
        clusters = [a[0] for a in sorted(self.cluster_ord.items(), key=lambda a: a[1])]
        M = [ len(self.cluster_sets[cluster_id]) for cluster_id in clusters]
        self.M = numpy.array(M, dtype=int)

        self.cont_init = True
        return
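
    # Note: after build_contingency(), self.cont[i,j] is the number of seq_ids
    # shared between family i (row order from self.family_ord) and non-empty
    # cluster j (column order from self.cluster_ord).  self.N[i] is the size
    # of family i; self.M[j] is the full size of cluster j, including members
    # that belong to no annotated family.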


    def shared_count( self, family_abbrev, cluster_id):
        """Return (count, set(seq_ids)) for a given cell of the
        contingency table"""
        
        m = "self.cluster_sets is uninitialized. Call fetch_cut() first"
        assert self.cluster_sets is not None, m

        # Build sequence map if it doesn't exist
        if self.seq_map is None: self.fetch_seq_set_dict()

        # Map cluster seq_ids through the (possibly randomized) sequence map
        cluster_seqs_raw = self.cluster_sets[cluster_id]
        cluster_seqs = [ self.seq_map[seq_id] for seq_id in cluster_seqs_raw]

        seq_ids = self.family_sets[family_abbrev].intersection(
            cluster_seqs)
        
        return (len(seq_ids), seq_ids)


    def recall(self, family_abbrev, cluster_id):
        """Recall (sensitivity): fraction of the members of a family
        found in the given cluster."""
        if not self.cont_init: self.build_contingency()

        i = self.family_ord[family_abbrev]
        j = self.cluster_ord[cluster_id]

        SN = float(self.cont[i,j]) / self.N[i]
        return SN
    sensitivity = recall
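    # Illustrative example (made-up numbers): if a family has N[i] == 10
    # members and 6 of them fall in the cluster at column j (cont[i,j] == 6),
    # then recall for that (family, cluster) pair is 6 / 10.0 == 0.6.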


    def family_wise_recall(self, family_abbrev, weighted_avg=False):
        """Maximal fraction of the proteins of a family assigned to a
        single cluster or, with weighted_avg, the per-cluster recall
        averaged with weights equal to each cluster's share of the family."""
        if not self.cont_init: self.build_contingency()
        
        T = self.cont
        i = self.family_ord[family_abbrev]

        row = T[i,:]
        if weighted_avg:
            nonzero_clusters = numpy.nonzero( row)
            complex_SN = float(
                numpy.sum(row[nonzero_clusters]**2)) / self.N[i]**2
        else:
            complex_SN = float(numpy.max( row)) / self.N[i]

        return complex_SN
    complex_wise_sn = family_wise_recall
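    # Note: with weighted_avg=True the denominator is N[i]**2, which treats
    # the row sum of cont[i,:] as equal to N[i], i.e. it assumes every member
    # of the family appears in some cluster.  If some members are unclustered,
    # the weights no longer sum to one.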


    def clustering_wise_sn(self):
        """Clustering-wise sensitivity: the average of the complex-wise
        SN values, weighted by family size."""
        if not self.cont_init: self.build_contingency()
        
        T = self.cont
        SN = float( numpy.sum(numpy.max( T, axis=1))) / numpy.sum(self.N)
        return SN
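
    # This is the clustering-wise sensitivity of Brohee and van Helden:
    # Sn = sum_i max_j(T[i,j]) / sum_i N[i].  Each family contributes its
    # best-covered cluster, weighted by family size.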


    def precision(self, family_abbrev, cluster_id, use_true_csize=False):
        """Precision (positive predictive value): fraction of the members
        of cluster j that belong to family i.  The denominator is the full
        cluster size if use_true_csize, otherwise the number of
        family-annotated members of the cluster."""
        if not self.cont_init: self.build_contingency()
        
        i = self.family_ord[family_abbrev]
        j = self.cluster_ord[cluster_id]

        if use_true_csize:
            cluster_size = self.M[j]
        else:
            cluster_size = numpy.sum(self.cont[:,j])
        
        PPV = float(self.cont[i,j]) / cluster_size
        return PPV
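
    # Note: with use_true_csize the denominator is the full cluster size M[j],
    # so members that belong to no annotated family count against precision;
    # otherwise only the family-annotated members of the cluster (the column
    # sum) are counted.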


    def cluster_wise_ppv(self, cluster_id, use_true_csize=False):
        """Maximal fraction of proteins of cluster j found in the same
        annotated family"""
        if not self.cont_init: self.build_contingency()

        T = self.cont
        j = self.cluster_ord[cluster_id]

        if use_true_csize:
            cluster_size = self.M[j]
        else:
            cluster_size = numpy.sum(T[:,j])

        if self.M[j] < numpy.sum(T[:,j]):
            print "Cluster %s (column %d) has a column sum greater than its cluster size" % (cluster_id, j)

        cluster_PPV = float(numpy.max( T[:,j])) / cluster_size
        return cluster_PPV
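
    # Per-cluster PPV: PPVcl_j = max_i(T[i,j]) / denominator, where the
    # denominator is either the column sum sum_i T[i,j] or the stored
    # cluster size M[j].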


    def clustering_wise_ppv(self, use_true_csize=False):
        """Weighted average of PPVcl_j over all clusters"""
        if not self.cont_init: self.build_contingency()

        T = self.cont

        if use_true_csize:
            clusters_size = numpy.sum(self.M)
        else:
            clusters_size = numpy.sum(self.cont)

        PPV = float( numpy.sum( numpy.max(T, axis=0))) / clusters_size
        return PPV
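
    # Clustering-wise PPV = sum_j max_i(T[i,j]) / sum_j(denominator_j), i.e.
    # the per-cluster PPV values averaged with weights proportional to each
    # cluster's denominator.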


    def brohee_accuracy(self, use_true_csize=False):
        """Geometric accuracy: the geometric mean of the clustering-wise
        SN and PPV (the 'max' statistics from Brohee and van Helden)."""

        acc = numpy.sqrt( self.clustering_wise_sn() *
                          self.clustering_wise_ppv(use_true_csize = use_true_csize))
        return acc
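
    # Illustrative example (made-up numbers): with clustering-wise SN == 0.8
    # and clustering-wise PPV == 0.5, the accuracy is sqrt(0.8 * 0.5) ~= 0.632.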

    def complex_wise_sep( self, family_abbrev, use_true_csize=False,
                          max_single_sep=False):
        """Complex-wise separation for a given family: the sum of the
        per-cluster separation values, skipping empty columns."""
        if not self.cont_init: self.build_contingency()
        
        i = self.family_ord[ family_abbrev]
        T = self.cont

        row_sum = numpy.sum( T[i,:])
        if use_true_csize:
            col_sums = self.M
        else:
            col_sums = numpy.sum( T, axis=0)

        nonzero_columns = numpy.nonzero( col_sums)

        cluster_seps = numpy.array( T[i,nonzero_columns] * T[i,nonzero_columns],
                                    dtype=float) / col_sums[nonzero_columns]
        if max_single_sep:
            sep_co = numpy.max( cluster_seps) / row_sum
        else:
            sep_co = numpy.sum( cluster_seps) / row_sum

        return sep_co
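
    # In Brohee's notation: Sep_co_i = sum_j (T[i,j] / row_sum_i) *
    # (T[i,j] / col_sum_j), restricted to non-empty columns; with
    # max_single_sep only the largest term is kept.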


    def cluster_wise_sep( self, cluster_id, use_true_csize=False):
        """Concentration of one or several complexes within a given
        cluster"""
        if not self.cont_init: self.build_contingency()
        
        j = self.cluster_ord[ cluster_id]
        T = self.cont

        row_sums = numpy.sum( T, axis=1)
        if use_true_csize:
            col_sum = self.M[j]
        else:
            col_sum = numpy.sum( T[:,j])

        complex_seps = numpy.array( T[:,j] * T[:,j], dtype=float) / row_sums

        sep_cl = numpy.sum( complex_seps) / col_sum

        return sep_cl
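
    # Analogously: Sep_cl_j = sum_i (T[i,j] / row_sum_i) * (T[i,j] / col_sum_j).
    # If a family has no members in any cluster, its row sum is zero and the
    # corresponding 0/0 term evaluates to nan, which propagates into sep_cl.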


    def avg_complex_wise_sep(self, use_true_csize=False,
                             max_single_sep=False):
        """Average complex-wise separation over all families"""
        if not self.cont_init: self.build_contingency()
        n = len(self.family_ord)
        
        co_sum = 0.0
        for family_abbrev in self.family_ord:
            co_sum += self.complex_wise_sep( family_abbrev,
                                             use_true_csize = use_true_csize,
                                             max_single_sep = max_single_sep)

        return co_sum / n


    def avg_cluster_wise_sep(self, use_true_csize=False):
        """Average cluster-wise separation calculated over annotated
        clusters"""
        if not self.cont_init: self.build_contingency()

        m = len(self.cluster_ord)

        cl_sum = 0.0
        for (cluster_id, j) in self.cluster_ord.items():
            if numpy.sum(self.cont[:,j]) < 1:
                m -= 1
                continue
            cl_sum += self.cluster_wise_sep( cluster_id, use_true_csize=use_true_csize)

        return cl_sum / m

    def clustering_wise_sep(self, use_true_csize=False, max_single_sep=False):
        """Geometric mean of the average complex-wise and cluster-wise
        separations (Sep_co and Sep_cl)"""

        sep = numpy.sqrt( self.avg_cluster_wise_sep(use_true_csize=use_true_csize) *
                          self.avg_complex_wise_sep(use_true_csize=use_true_csize,
                                                    max_single_sep=max_single_sep))
        return sep
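
    # Overall separation: Sep = sqrt(avg_complex_wise_sep * avg_cluster_wise_sep),
    # the geometric mean of the two averaged separation statistics.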

    def num_annotated_clusters(self):
        """Number of clusters containing at least one family-annotated member"""
        if not self.cont_init: self.build_contingency()
        
        T = self.cont

        # numpy.nonzero returns a tuple of index arrays, one per dimension
        annotated_cols = numpy.nonzero(numpy.sum(T, axis=0))
        return len(annotated_cols[0])
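

# The block below is a small, self-contained sketch of the Brohee statistics
# on a hand-built toy contingency table.  It does not use DurandDB or the
# brohee_stat class; the numbers and names are illustrative only, and serve
# to show how SN, PPV and accuracy relate to the table.
if __name__ == '__main__':
    # Toy contingency table: 3 families (rows) x 2 clusters (columns).
    T = numpy.array([[4, 1],
                     [0, 3],
                     [2, 2]], dtype=int)
    N = numpy.sum(T, axis=1)   # family sizes (assume every member clustered)
    M = numpy.sum(T, axis=0)   # cluster sizes (assume every member annotated)

    # Clustering-wise sensitivity and PPV, as in clustering_wise_sn/_ppv.
    Sn = float(numpy.sum(numpy.max(T, axis=1))) / numpy.sum(N)
    PPV = float(numpy.sum(numpy.max(T, axis=0))) / numpy.sum(T)
    Acc = numpy.sqrt(Sn * PPV)

    print "toy contingency table:"
    print T
    print "Sn  = %.3f" % Sn
    print "PPV = %.3f" % PPV
    print "Acc = %.3f" % Acc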
