#!/usr/bin/env python

# Jacob Joseph
# 2 December 2010

# Tools to consider the correspondence of domains and identified
# clusters:
#
# * Mutual information of a clustering with each domain

import numpy
from math import log
#from DurandDB import familyq, pfamq
#from JJcluster import cluster_sql

class mutualinfo(object):

    """Mutual information between a clustering and a set of annotations.

    clustering: dictionary of clusters mapping to a list/set of
    seq_ids.  Each seq_id must occur in exactly one cluster.

    annotation: dictionary of annotations (e.g., domains) mapping
    to a list/set/tuple of seq_ids.  Here, a seq_id may be found in
    several annotation keys.

    Mutual information will be calculated between the clusters and
    each annotation key.  All entropies and mutual informations are in
    bits (base-2 logarithms).  Terms with a zero count contribute
    nothing, following the usual convention 0 * log(0) = 0.
    """

    def __init__(self, clustering, annotation):
        self.cluster_to_seq = clustering
        self.annotation_to_seq = annotation

        # Invert the clustering: each seq_id -> its (single) cluster.
        self.seq_to_cluster = {}
        for cluster_id, seq_ids in self.cluster_to_seq.items():
            for seq_id in seq_ids:
                self.seq_to_cluster[seq_id] = cluster_id

        # Invert the annotation: each seq_id -> set of annotation keys
        # that list it.
        self.seq_to_annotation = {}
        for ann, seq_ids in self.annotation_to_seq.items():
            for seq_id in seq_ids:
                self.seq_to_annotation.setdefault(seq_id, set()).add(ann)


    def annotations_in_cluster(self, cluster):
        """Return the set of annotation keys carried by at least one
        sequence of `cluster`."""
        anns = set()
        for seq_id in self.cluster_to_seq[cluster]:
            anns.update(self.seq_to_annotation.get(seq_id, ()))
        return anns

    def clustering_entropy_old(self):
        """Entropy (bits) of the cluster-size distribution, computed
        with a plain Python loop.  Reference implementation for
        clustering_entropy(); empty clusters contribute nothing."""

        N = len(self.seq_to_cluster)

        H = 0
        for cluster in self.cluster_to_seq:
            ck = float(len(self.cluster_to_seq[cluster]))
            if ck == 0:
                continue    # 0 * log(0) = 0; math.log(0) would raise

            H += ck / N * log(ck / N, 2)

        return -H

    def clustering_entropy(self):
        """Entropy (bits) of the cluster-size distribution (vectorized).
        Empty clusters contribute nothing."""
        sizes = [len(seq_ids) for seq_ids in self.cluster_to_seq.values()]
        A = numpy.array(sizes, dtype=numpy.double)
        # Drop empty clusters: numpy gives 0 * log2(0) = nan, but their
        # contribution to the entropy is 0.
        A = A[A > 0]

        A = A / numpy.sum(A)
        A = A * numpy.log2(A)

        return -numpy.sum(A) # entropy

    def annotation_entropy(self, ann):
        """Entropy (bits) of the binary partition of all clustered
        sequences into those with and those without annotation `ann`.

        Returns 0 when `ann` covers no sequence or every sequence."""

        N = len(self.seq_to_cluster)
        dl = float(len(self.annotation_to_seq[ann]))

        H = 0
        for cnt in (dl, N - dl):
            if cnt == 0:
                continue    # 0 * log(0) = 0; math.log(0) would raise

            H += cnt / N * log(cnt / N, 2)

        return -H

    def joint_cluster_ann_count(self, ann):
        """Return a dictionary by cluster of how many sequences have
        the given annotation.  Only non-zero counts are included.

        Raises KeyError if an annotated seq_id is in no cluster."""
        joint_cnt = {}

        for seq_id in self.annotation_to_seq[ann]:
            cluster = self.seq_to_cluster[seq_id]
            joint_cnt[cluster] = joint_cnt.get(cluster, 0) + 1
        return joint_cnt

    def joint_ann_count(self, annotations, prepend_clustering=False):
        """Return a 2x2 contingency table of sequence counts for a pair
        of annotations.

        annotations: a sequence of exactly two annotation keys
        (ann0, ann1).  Axis 0 is ann0 and axis 1 is ann1; index 1 means
        "has the annotation", index 0 means "lacks it".
        prepend_clustering: unused; accepted for interface compatibility.

        The seq_id containers of the annotations may be lists, tuples or
        sets; they are converted to sets here.  Cell [0,0] is derived
        from the number of clustered sequences, so it presumably
        assumes every annotated seq_id is clustered -- TODO confirm.
        """

        # currently works for 2 dimensions only
        assert len(annotations)==2, "!= 2 dimensions unimplemented"

        ann0, ann1 = annotations
        seqs0 = set(self.annotation_to_seq[ann0])
        seqs1 = set(self.annotation_to_seq[ann1])

        # populate the contingency table
        cont = numpy.zeros([2]*len(annotations), dtype=int)

        # ann0=t, ann1=t
        cont[1,1] = len(seqs0 & seqs1)
        # ann0=t, ann1=f  and  ann0=f, ann1=t
        cont[1,0] = len(seqs0) - cont[1,1]
        cont[0,1] = len(seqs1) - cont[1,1]
        # ann0=f, ann1=f: all remaining clustered sequences
        cont[0,0] = len(self.seq_to_cluster) - numpy.sum(cont)

        return cont

    def mutual_information_ann(self, annotations, prepend_clustering=False):
        """Mutual information (bits) between two annotations, each
        treated as a binary has/lacks variable over all clustered
        sequences.  prepend_clustering is unused."""

        cont = self.joint_ann_count(annotations)
        N = len(self.seq_to_cluster)

        assert N == numpy.sum(cont), "N != sum(cont)"

        norm_cont = numpy.array(cont, dtype=float) / N

        # marginals over each axis (works only for 2 dimensions):
        # marginal[0][j] = P(ann1=j), marginal[1][i] = P(ann0=i)
        marginal = [numpy.sum(norm_cont, axis=i) for i in range(len(annotations))]

        # Divide each joint cell by the product of its marginals:
        # outer(...).transpose()[i,j] = P(ann0=i) * P(ann1=j).
        # Zero cells give log2(0) = -inf and 0 * -inf = nan (or 0/0 = nan
        # for an empty marginal); nansum drops them (0 log 0 = 0), so the
        # corresponding numpy warnings are silenced.
        with numpy.errstate(divide='ignore', invalid='ignore'):
            tmp = norm_cont / numpy.outer(marginal[0], marginal[1]).transpose()
            tmp = numpy.log2(tmp)
            mi = numpy.nansum(tmp * norm_cont)

        return mi


    def max_cluster_mi_single(self, cluster):
        """Return the maximum possible MI with a cluster.  This is the
        MI of the partitioning of the cluster against a single set of
        all else.

        Disabled: always fails its assertion; use max_cluster_mi."""

        assert False, "Use max_cluster_mi instead"

        # FIXME: check that this is actually correct

        N = float(len(self.seq_to_cluster))
        cnt = float(len(self.cluster_to_seq[cluster]))

        mi = cnt / N * log( cnt / N
                            / (cnt / N)
                            / (cnt / N), 2)
        mi += (N-cnt) / N * log( (N-cnt) / N
                                 / ((N-cnt) / N)
                                 / ((N-cnt) / N), 2)
        return mi

    def max_cluster_mi(self, cluster):
        """Return the maximum possible MI (bits) of an annotation with
        the clustering, reached when the annotation covers exactly the
        sequences of `cluster`.  The sum runs over the sizes of all
        other clusters; empty clusters contribute nothing.

        NOTE: when every seq_id is in exactly one cluster, the other
        clusters' sizes sum to N - cnt, so the result equals the binary
        entropy of cnt / N."""

        N = float(len(self.seq_to_cluster))
        cnt = float(len(self.cluster_to_seq[cluster]))

        mi = 0.0
        # only this cluster contains the domain:
        # P(cluster, domain) = P(cluster) = P(domain) = cnt / N
        if cnt > 0:
            mi += cnt / N * log( cnt / N
                                 / (cnt / N)
                                 / (cnt / N), 2)

        # every other cluster lacks the domain entirely:
        # P(cluster1, no domain) = P(cluster1) = ck / N
        for cluster1 in self.cluster_to_seq:
            if cluster == cluster1: continue

            ck = float(len(self.cluster_to_seq[cluster1]))
            if ck == 0:
                continue    # empty cluster: its term is 0 (and 0/0 would raise)

            mi += ck / N * log( ck / N
                                / ((N-cnt) / N)     # P(no domain)
                                / ( ck / N),        # P(cluster1)
                                2)
        return mi


    def mutual_information(self, ann):
        """Mutual information (bits) between the clustering and the
        binary has/lacks variable of annotation `ann`."""

        N = len(self.seq_to_cluster)
        ann_cnt = float(len(self.annotation_to_seq[ann]))

        joint_cnt = self.joint_cluster_ann_count(ann)

        # sum over all clusters, and over with/without the annotation
        mi = 0
        for cluster in self.cluster_to_seq:
            ck = float(len(self.cluster_to_seq[cluster]))

            jointk = float(joint_cnt.get(cluster, 0))

            for (cnt, ann_marginal) in ((ck - jointk, N - ann_cnt), # without
                                        (jointk, ann_cnt)):         # with

                if cnt == 0: continue   # 0 * log(0) = 0

                mi += cnt / N * log( cnt / N
                                     / (ck / N)
                                     / (ann_marginal / N), 2)

        return mi

    def mutual_information_specific_array(self):
        """Return a tuple of:
(1) a length |annotations| array Mutual Information of each domain,

(2) a |clusters|X|domains| matrix of the domain-containing
    contribution to SMI for each cluster and domain,

(3) a |clusters|X|domains| matrix of the domain-absence
    contribution to SMI for each cluster and domain,

(4) A dictionary mapping clusters to indices in these arrays,

(5) A dictionary mapping domains to indices

Rows follow the sorted cluster keys and columns the sorted annotation
keys.  Cells with a zero count (including empty clusters) contribute 0."""

        # FIXME: This uses quite a lot of memory...
        # (dense |clusters| x |annotations| count and float matrices)

        N = len(self.seq_to_cluster)
        JOINT = numpy.zeros((len(self.cluster_to_seq), len(self.annotation_to_seq)),
                            dtype=int)
        CK = numpy.zeros((len(self.cluster_to_seq), 1), dtype=int)

        # array indices
        ann_index = dict(zip(sorted(self.annotation_to_seq.keys()),
                             range(len(self.annotation_to_seq))))
        cl_index = dict(zip(sorted(self.cluster_to_seq.keys()),
                            range(len(self.cluster_to_seq))))

        # cluster sizes; a column vector so it broadcasts across annotations
        for cluster in self.cluster_to_seq:
            CK[cl_index[cluster]] = len(self.cluster_to_seq[cluster])

        # joint counts
        for ann in self.annotation_to_seq:
            ann_i = ann_index[ann]
            for cluster, cnt in self.joint_cluster_ann_count(ann).items():
                JOINT[cl_index[cluster], ann_i] = cnt

        # count of sequences without domain in each cluster
        # probably not needed, but convenient for thinking about
        nJOINT = CK - JOINT

        # domain counts
        DK = numpy.sum(JOINT, axis=0)
        nDK = numpy.sum(nJOINT, axis=0)

        # Zero cells give 0 * log2(0) or 0/0, i.e. nan; nan_to_num maps
        # those to 0.  Silence the corresponding numpy warnings.
        with numpy.errstate(divide='ignore', invalid='ignore'):
            # domain containing specific MI
            posMI = numpy.array(JOINT, dtype=numpy.double)
            posMI = posMI / N * numpy.log2(posMI / CK / DK * N)
            posMI = numpy.nan_to_num(posMI)

            # not domain containing specific MI
            negMI = numpy.array(nJOINT, dtype=numpy.double)
            negMI = negMI / N * numpy.log2(negMI / CK / nDK * N)
            negMI = numpy.nan_to_num(negMI)

        # mutual information of each domain.  Sum over all SMI of all clusters
        MI = numpy.sum(posMI + negMI, axis=0)

        return MI, posMI, negMI, cl_index, ann_index


    def mutual_information_allann(self):
        """Return a dict mapping every annotation key to its mutual
        information (bits) with the clustering."""
        return dict((ann, self.mutual_information(ann))
                    for ann in self.annotation_to_seq)

    def format_table(self, annotation_label=None):
        """Return a text table of per-annotation statistics, sorted by
        decreasing mutual information with the clustering.

        annotation_label: optional dict mapping annotation keys to a
        descriptive label, printed in the second column.

        Row columns: annotation key, label, #sequences with the
        annotation, #clusters containing it, total size of those
        clusters, annotation entropy, mutual information, and the MI
        normalized by the annotation entropy (nan when the entropy is 0).
        """

        all_entropy = {}
        all_mi = {}
        for domain in self.annotation_to_seq:
            all_entropy[domain] = self.annotation_entropy(domain)
            all_mi[domain] = self.mutual_information(domain)


        s = "Sequences: %d, Clusters: %d, Domains: %d\n" % (
            len(self.seq_to_cluster),
            len(self.cluster_to_seq),
            len(self.annotation_to_seq))
        s += "Clustering entropy: %g\n" % self.clustering_entropy()
        s += "      log(#seqs,2): %g\n" % log(len(self.seq_to_cluster),2)
        s += "\n"

        s += "Domain                      #dom #clust TotSize    Entropy Mutual Inf.\n"
        s += "----------------------------------------------------------------------\n"


        for (domain, minf) in sorted(all_mi.items(), key=lambda a: a[1], reverse=True):
            joint_cnt = self.joint_cluster_ann_count(domain)
            total_cluster_cnt = sum(len(self.cluster_to_seq[cluster])
                                    for cluster in joint_cnt)

            entropy = all_entropy[domain]
            # MI <= entropy, so a zero-entropy annotation has MI 0 and
            # the normalized MI is undefined.
            norm_mi = minf / entropy if entropy > 0 else float('nan')

            s += "%-11s %-16s%4d%7d%8d %0.4e %0.4e %0.3g\n" % (
                domain,
                annotation_label[domain] if annotation_label is not None else "",
                len(self.annotation_to_seq[domain]),
                len(joint_cnt),
                total_cluster_cnt,
                entropy,
                minf,
                norm_mi)

        return s

if __name__ == "__main__":

    # NOTE: the "if False:" configurations below need the DurandDB /
    # JJcluster imports at the top of the file and the pfamq handle
    # below re-enabled before they can run.
    #pq = pfamq.pfamq()

    #annotation, seq_map = pq.fetch_domain_map(set_id = set_id)

    # Just the human and mouse families
    if False:
        set_id = 109 # ppod human, mouse
        family_set_name = 'ppod_20_clean'
        #family_set_name = 'hgnc_apr10_selected'

        print("family_set_name: %s, set_id: %d\n" % (family_set_name, set_id))

        fq = familyq.familyq(family_set_name=family_set_name)
        clustering = fq.fetch_families()[0]

        # seq_ids of all families, concatenated in family order
        annotation, seq_map = pq.fetch_domain_map(
            set_id = set_id,
            seq_ids = tuple(seq_id for fam in clustering
                            for seq_id in clustering[fam]))


    # Full mouse and human set
    if False:
        set_id = 109   # ppod human, mouse
        cr_id = 472    # clustering using the full 600k set
        nc_score = 0.2 # near the optimal

        # nc_score is fractional: %g, not %d (which would print 0)
        print("set_id: %(set_id)d, cr_id: %(cr_id)d, nc_score: %(nc_score)g" % locals())

        clusclass = cluster_sql.hcluster( cluster_run_id = cr_id,
                                          cacheq=True)
        clustering = clusclass.cut_tree( distance = 1-nc_score,
                                         set_id_filter = set_id)
        annotation, seq_map = pq.fetch_domain_map(set_id = set_id)


    if False:
        set_id = 109   # ppod human, mouse
        #cr_id = 381    # jaccard clustering, all
        cr_id = 383    # jaccard clustering, human, mouse

        print("jaccard clustering. set_id: %(set_id)d, cr_id %(cr_id)d" % locals())

        clusclass = cluster_sql.flatcluster( cluster_run_id = cr_id)
        clustering = clusclass.fetch_clusters()
        annotation, seq_map = pq.fetch_domain_map(set_id = set_id)

    # test: a small worked example with two clusters and five annotations
    clustering = {0: (0,1,2,3,4,5,6,7),
                  1: (8, 9, 10, 11, 12, 13)}
    annotation = {'a': (0,1,2,3),
                  'b': (4,5),
                  'c': (8, 9, 10, 11),
                  'd': (8, 9, 10, 11),
                  'e': (6,7, 12, 13)}

    mi = mutualinfo( clustering, annotation)
    print(mi.format_table())

    # With the database handle available, label domains by name:
    #annotation_label = {}
    #for domain in annotation:
    #    annotation_label[domain] = pq.lookup_domain_name(domain)

    #print(mi.format_table(annotation_label))


    

