#!/usr/bin/env python

# Jacob Joseph
# 19 Dec 2010

# Plot various data about mutual information

import numpy
import matplotlib
import time
from matplotlib import pyplot
from DurandDB import pfamq, familyq
from JJcluster import cluster_sql

from information import mutualinfo

class mi_plot(mutualinfo):

    def scatter_mi_entropy(self, max_labels=10,
                           filename=None,
                           dpi=None,
                           label_exclude=None):
        """Scatter plot of mutual information vs. entropy, one point per domain.

        Domains are ranked by decreasing number of instances (the
        length of self.annotation_to_seq[domain]).  The first
        max_labels ranks are labelled with the domain name (the text
        before the first '.') and the instance count.

        max_labels:    number of ranks considered for labelling
        filename:      if not None, the figure is saved as
                       'figures/<filename>' and closed; otherwise it
                       is shown with pyplot.show()
        dpi:           passed to pyplot.savefig
        label_exclude: optional container of 0-based ranks whose label
                       is suppressed.  An excluded rank still counts
                       towards max_labels.

        Returns (domains, harr, miarr, instdict): the sorted list of
        all domains, the entropy and mutual information of each domain
        (in that same order), and {instance count: set(domains)}.
        """

        domains = sorted(self.annotation_to_seq.keys())
        index = {}      # {domain: position in harr / miarr}
        harr = []
        miarr = []
        instdict = {}   # {cnt: set(domains)} number of instances of domain

        for i, domain in enumerate(domains):
            index[domain] = i

            harr.append(self.annotation_entropy(domain))
            miarr.append(self.mutual_information(domain))

            instances = len(self.annotation_to_seq[domain])
            instdict.setdefault(instances, set()).add(domain)

        pyplot.scatter(harr, miarr, color='blue', alpha=0.8, s=5, zorder=11)

        pyplot.axis('tight')

        pyplot.xlabel('Entropy (bits)')
        pyplot.ylabel('Mutual Information (bits)')

        pyplot.grid(color='grey', zorder=10)

        # Label the domains with the largest counts.  Domains sharing a
        # count are visited in set order, so the rank numbers used by
        # label_exclude follow that order.
        rank = 0
        for instances in sorted(instdict, reverse=True):
            if rank >= max_labels:
                break

            # A separate loop variable: rebinding 'domains' here would
            # corrupt the value returned below.
            for domain in instdict[instances]:
                if rank >= max_labels:
                    break

                if label_exclude is None or rank not in label_exclude:
                    # NOTE(review): the original passed both size=5 and
                    # fontsize=8.  They are aliases of one property and
                    # current matplotlib rejects the pair; 8 is kept --
                    # confirm this is the intended label size.
                    pyplot.text(
                        harr[index[domain]],
                        miarr[index[domain]],
                        '%s (%d)' % (domain.split('.')[0], instances),
                        fontsize=8,
                        zorder=12)
                rank += 1

        pyplot.subplots_adjust(hspace=0.1, wspace=0.05, bottom=0.09,
                               left=0.12, top=0.98, right=0.9)

        if filename is not None:
            pyplot.savefig('figures/%s' % filename, dpi=dpi)
            pyplot.close()
        else:
            pyplot.show()

        return domains, harr, miarr, instdict


    def scatter_cluster_mi(self, max_labels=10):
        """Scatter plot of every cluster against the MI of the domains it
        contains.

        Clusters are placed along the x axis in order of decreasing
        size.  For each cluster three kinds of point are drawn: the MI
        of every contained domain (green), the 'max MI' of the cluster
        from max_cluster_mi (red; the MI between a cluster and
        itself), and the value from max_cluster_mi_single (blue).

        A line 'position cluster domain MI' is printed for every
        plotted domain.  The figure is neither saved nor shown here.

        max_labels is accepted but currently unused.

        Could add an upper bound (error bar) up to the entropy of each
        domain.

        Returns (xlabels, harr, maxmi, harr_mi, miarr).
        """

        # {cluster size: [cluster, ...]}
        by_size = {}
        for cluster, seqs in self.cluster_to_seq.iteritems():
            by_size.setdefault(len(seqs), []).append(cluster)

        # MI of all domains
        annotations_mi = self.mutual_information_allann()

        # largest clusters first
        ordered = [(cluster, size)
                   for size in sorted(by_size, reverse=True)
                   for cluster in by_size[size]]

        xlabels, harr = [], []        # one entry per cluster
        maxmi, maxmi_bad = [], []
        harr_mi, miarr = [], []       # one entry per (cluster, domain)

        for pos, (cluster, size) in enumerate(ordered):
            xlabels.append('%s (%d)' % (cluster, size))
            harr.append(pos)
            maxmi.append(self.max_cluster_mi(cluster))
            maxmi_bad.append(self.max_cluster_mi_single(cluster))

            # MI of all contained domains
            for domain in self.annotations_in_cluster(cluster):
                domain_mi = annotations_mi[domain]
                print('%s %s %s %s' % (pos, cluster, domain, domain_mi))
                harr_mi.append(pos)
                miarr.append(domain_mi)

        for xs, ys, colour in ((harr_mi, miarr, 'green'),
                               (harr, maxmi, 'red'),
                               (harr, maxmi_bad, 'blue')):
            pyplot.scatter(xs, ys, color=colour, alpha=0.5)

        pyplot.grid(color='grey')
        pyplot.xticks(harr, xlabels, rotation=270)
        pyplot.xlim((-1, len(harr)))
        upper = pyplot.ylim()[1]
        pyplot.ylim((0, upper))

        pyplot.xlabel('Cluster')
        pyplot.ylabel('Mutual information (bits)')
        pyplot.legend(('MI (domain)', 'MI max (correct)', 'MI max (partition)'))

        return xlabels, harr, maxmi, harr_mi, miarr

    def scatter_cluster_mi_specific(self, filename=None, dpi=None,
                                    xlabel='Cluster',
                                    adj_bottom=0.2,
                                    xlimit=None,
                                    xticks=True,
                                    grid=True,
                                    pointsize=20, # pyplot default
                                    cluster_list=None,
                                    ymin=None,
                                    draw_legend=True
                                    ):
        """Scatter plot of clusters against the MI of the domains they
        contain, drawn on a new figure.

        For every cluster the plot shows the MI of each contained
        domain (green), the 'presence SMI' of each contained domain
        for that cluster (blue '+'), and the cluster maximum from
        max_cluster_mi (hollow red circle; the MI between a cluster
        and itself).

        filename:     if not None, save to 'figures/<filename>' and
                      close the figure; otherwise pyplot.show()
        dpi:          passed to pyplot.savefig
        xlabel:       x axis label
        adj_bottom:   bottom margin passed to subplots_adjust (room
                      for the rotated tick labels)
        xlimit:       x axis limits; defaults to (-1, number of clusters)
        xticks:       label each cluster on the x axis, or draw no ticks
        grid:         draw a grid (yields artifacts with very many
                      clusters)
        pointsize:    marker size for all three series
        cluster_list: clusters to plot, in x order; defaults to all
                      clusters by decreasing size
        ymin:         lower y limit; defaults to 0
        draw_legend:  draw the legend

        Could add an upper bound (error bar) up to the entropy of each
        domain.

        Returns (xlabels, harr, maxmi, harr_mi, miarr).  The presence
        SMI values are plotted but not returned.
        """

        # always draw on a fresh figure
        pyplot.figure()

        if cluster_list is None:
            # default: every cluster, largest first
            cluster_sizes = {}
            for cluster, seqs in self.cluster_to_seq.iteritems():
                cluster_sizes.setdefault(len(seqs), []).append(cluster)

            cluster_list = [cluster
                            for size in sorted(cluster_sizes, reverse=True)
                            for cluster in cluster_sizes[size]]

        # negMI is returned by the helper but not plotted here
        MI, posMI, negMI, cl_index, ann_index = self.mutual_information_specific_array()

        harr = []              # x position of each cluster
        xlabels = []           # 'cluster (size)' tick labels
        maxmi = []             # per-cluster maximum MI

        harr_mi = []           # x position of each (cluster, domain) point
        miarr = []             # MI of the domain
        smiarr_positive = []   # presence SMI of the domain in the cluster

        for i, cluster in enumerate(cluster_list):
            size = len(self.cluster_to_seq[cluster])
            xlabels.append('%s (%d)' % (cluster, size))
            harr.append(i)
            maxmi.append(self.max_cluster_mi(cluster))

            # MI of all contained domains
            for domain in self.annotations_in_cluster(cluster):
                harr_mi.append(i)
                smiarr_positive.append(posMI[cl_index[cluster], ann_index[domain]])
                miarr.append(MI[ann_index[domain]])

        pyplot.scatter(harr_mi, miarr, color='green', alpha=1,
                       label='MI', s=pointsize, zorder=10)

        pyplot.scatter(harr_mi, smiarr_positive, marker='+', color='blue', alpha=0.8,
                       label='presence SMI', s=pointsize, zorder=11)

        # 'none' is the documented way to request an unfilled marker;
        # the empty string used previously is not a valid colour in
        # current matplotlib.
        pyplot.scatter(harr, maxmi, edgecolor='red', facecolor='none', alpha=1,
                       label='MI cluster maximum', s=pointsize,
                       marker='o', zorder=12)

        # a grid yields artifacts with very many clusters
        if grid:
            pyplot.grid(color='grey', zorder=0)

        if xticks:
            pyplot.xticks(harr, xlabels, rotation=270)
        else:
            pyplot.xticks([])

        if xlimit is None:
            pyplot.xlim((-1, len(harr)))
        else:
            pyplot.xlim(xlimit)

        if ymin is None:
            ymin = 0
        pyplot.ylim((ymin, pyplot.ylim()[1]))

        pyplot.xlabel(xlabel)
        pyplot.ylabel('Mutual Information (bits)')

        if draw_legend:
            leg = pyplot.legend(fancybox=True, loc=0)
            leg.get_frame().set_alpha(0.6)
            leg.get_frame().set_facecolor('lightgrey')
            leg.set_zorder(13)   # above all three scatter series

        pyplot.subplots_adjust(hspace=0.1, wspace=0.05, bottom=adj_bottom,
                               left=0.11, top=0.97, right=0.99)

        if filename is not None:
            pyplot.savefig('figures/%s' % filename, dpi=dpi)
            pyplot.close()
        else:
            pyplot.show()

        return xlabels, harr, maxmi, harr_mi, miarr


    def domain_domain_mi_heatmap(self,title_descr=''):
        """Heat map of the mutual information between every pair of domains.

        Domains are ordered by decreasing number of instances.  Only
        the lower triangle (including the diagonal) of the matrix is
        computed; the remaining entries stay 0.  Colours use a
        logarithmic norm on a black axes background.

        title_descr: text appended as a second line of the title.

        The figure is neither saved nor shown here.

        Returns (A, domains): the MI matrix and the ordered domain list.
        """

        # domains, sorted by number of instances
        domains = sorted(self.annotation_to_seq.keys(),
                         key=lambda a: len(self.annotation_to_seq[a]),
                         reverse=True)
        n_domains = len(domains)

        A = numpy.zeros((n_domains, n_domains), dtype=numpy.double)

        # lower triangle only: j runs over 0..i
        for i, domain0 in enumerate(domains):
            for j, domain1 in enumerate(domains[:i+1]):
                A[i,j] = self.mutual_information_ann(annotations=(domain0, domain1))

        pyplot.pcolormesh(A,
                      norm=matplotlib.colors.LogNorm(),
                      cmap = pyplot.cm.jet,
                      edgecolors='None')

        ax = pyplot.gca()
        # Axes.set_axis_bgcolor was replaced by set_facecolor in
        # matplotlib 2.0 and later removed; use whichever is available.
        if hasattr(ax, 'set_facecolor'):
            ax.set_facecolor('black')
        else:
            ax.set_axis_bgcolor('black')

        pyplot.xlim((-1, n_domains+1))
        pyplot.ylim((-1, n_domains+1))

        pyplot.xticks( range(n_domains),
                       ["%s (%d)" % (domain, len(self.annotation_to_seq[domain]))
                        for domain in domains],
                       rotation=270)

        pyplot.yticks( range(n_domains),
                       ["%s" % domain for domain in domains])

        pyplot.title("Domain-Domain Mutual Information\n"+title_descr)

        pyplot.subplots_adjust(hspace=0.1, wspace=0.05, bottom=0.2,
                               left=0.11, top=0.95, right=0.98)
        cb = pyplot.colorbar()
        cb.set_label('Mutual Information (bits)')

        return A,domains

    def plot_histogram(self, hist_dict, title='',
                       filename=None,
                       xlabel='', ylabel='Count',
                       xlog=False, ylog=False,
                       color='blue',
                       xlimit=None, ylimit=None,
                       dpi=None):
        """Draw a histogram from a dictionary of pre-computed counts.

        hist_dict: {value: count}.  The keys are used as the samples
                   and the counts as their weights.  Keys must be
                   numeric; an empty dictionary raises IndexError.
        title:     figure title
        filename:  if not None, save to 'figures/<filename>' and close
                   the figure; otherwise fig.show()
        xlabel, ylabel: axis labels
        xlog:      selects 40 automatic bins for wide data and is
                   passed to pyplot.hist as log=.
                   NOTE(review): per the matplotlib documentation,
                   hist(log=True) log-scales the histogram (count)
                   axis, not the x axis -- confirm this is intended.
        ylog:      set a logarithmic y scale
        color:     bar colour
        xlimit, ylimit: accepted but currently unused
        dpi:       passed to savefig

        Sets the global default figure size to 6x4 inches.

        Returns (n, retbins): the bin counts and bin edges from
        pyplot.hist.
        """

        pyplot.rc('figure', figsize=(6,4))

        fig = pyplot.figure()
        ax = fig.add_subplot(111)

        if ylog: ax.set_yscale('log')

        fig.suptitle('%s' % title)

        keys = sorted(hist_dict.keys())
        weights = [hist_dict[n] for n in keys]

        span = keys[-1] - keys[0]
        # one bin per key, with a closing edge after the last key
        key_bins = keys + [keys[-1]+1]

        if len(keys) < 40 and span < 40:
            bins = key_bins
        elif not xlog:
            step = int(span / 40)
            if step < 1:
                # 40 or more keys spanning fewer than 40 units (e.g.
                # 0..39): a zero step would make range() raise, so
                # fall back to one bin per key.
                bins = key_bins
            else:
                # roughly 40 equal-width bins covering all keys
                bins = list(range(keys[0], keys[-1]+step, step))
        else:
            bins = 40

        (n, retbins, patches) = pyplot.hist( x=keys, weights=weights, color=color,
                                             bins = bins, log=xlog,
                                             align='left')

        ax.grid(color='gray', alpha=0.5)

        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)

        if not xlog:
            # tick only the left edges of non-empty bins
            ax.set_xticks(retbins[n.nonzero()])

        fig.subplots_adjust(bottom=0.1, left=0.1, top=0.9, right=0.98)

        if filename is not None:
            fig.savefig('figures/%s' % filename, dpi=dpi)
            pyplot.close()
        else:
            fig.show()

        return (n, retbins)
        
        

    def hist_clusters_per_domain(self, title_descr=' ', filename=None, **kwargs):
        """Histogram of how many clusters contain each domain.

        title_descr is appended as a second title line; filename and
        any further keyword arguments are forwarded to plot_histogram.
        """
        counts = {}   # {number of containing clusters: number of domains}

        for annotation in self.annotation_to_seq.keys():
            # joint_cluster_ann_count has one entry per cluster that
            # contains the annotation; only the number of entries
            # matters here.
            n_clusters = len(self.joint_cluster_ann_count(annotation))
            counts[n_clusters] = counts.get(n_clusters, 0) + 1

        self.plot_histogram(
            counts,
            title='Clusters per domain\n'+title_descr,
            xlabel='Count of clusters containing domain',
            filename=filename,
            **kwargs)


    def hist_domains_per_cluster(self, title_descr=' ', filename=None, **kwargs):
        """Histogram of the number of domains found in each cluster.

        title_descr is appended as a second title line; filename and
        any further keyword arguments are forwarded to plot_histogram.
        """
        counts = {}   # {number of domains: number of clusters}

        for cluster_id in self.cluster_to_seq.keys():
            n_domains = len(self.annotations_in_cluster(cluster_id))
            counts[n_domains] = counts.get(n_domains, 0) + 1

        self.plot_histogram(
            counts,
            title='Domains per cluster\n'+title_descr,
            xlabel='Count of domains in cluster',
            filename=filename,
            **kwargs)


    def hist_seqs_per_cluster(self, title_descr=' ', filename=None, **kwargs):
        """Histogram of cluster sizes (number of sequences per cluster).

        title_descr is appended as a second title line; filename and
        any further keyword arguments are forwarded to plot_histogram.
        """
        counts = {}   # {cluster size: number of clusters}

        for (cluster_id, seq_ids) in self.cluster_to_seq.iteritems():
            size = len(seq_ids)
            counts[size] = counts.get(size, 0) + 1

        self.plot_histogram(
            counts,
            title='Sequences per cluster\n'+title_descr,
            xlabel='Cluster size',
            filename=filename,
            **kwargs)
                            


    def hist_domains_per_seq(self, title_descr=' ', filename=None, **kwargs):
        """Histogram of the number of annotations attached to each sequence.

        title_descr is appended as a second title line; filename and
        any further keyword arguments are forwarded to plot_histogram.
        """
        counts = {}   # {number of annotations: number of sequences}

        for (seq_id, ann_set) in self.seq_to_annotation.iteritems():
            n_annotations = len(ann_set)
            counts[n_annotations] = counts.get(n_annotations, 0) + 1

        self.plot_histogram(
            counts,
            title='Domains per sequence\n'+title_descr,
            xlabel='Count of domain instances in sequence',
            filename=filename,
            **kwargs)

    def hist_seqs_per_domain(self, title_descr=' ', filename=None, **kwargs):
        """Histogram of the number of sequences that carry each domain.

        title_descr is appended as a second title line; filename and
        any further keyword arguments are forwarded to plot_histogram.
        """
        counts = {}   # {number of sequences: number of domains}

        for (ann, seq_ids) in self.annotation_to_seq.iteritems():
            n_seqs = len(seq_ids)
            counts[n_seqs] = counts.get(n_seqs, 0) + 1

        self.plot_histogram(
            counts,
            title='Sequences per domain\n'+title_descr,
            xlabel='Count of sequences containing domain',
            filename=filename,
            **kwargs)
    

def escape_latex(s):
    """Return s with every underscore escaped as '\\_' for LaTeX text.

    Only underscores are handled; other LaTeX special characters are
    passed through unchanged.
    """
    # An explicit double backslash: the bare '\_' form is an invalid
    # escape sequence that newer interpreters warn about.
    return s.replace('_', '\\_')

if __name__ == "__main__":
    # Driver script.  Each 'if True/False:' block below is a manually
    # toggled experiment; the commented-out calls are alternative
    # figures for the same data set.  Figures are written under
    # 'figures/' with today's date as a filename prefix.

    pyplot.ioff() # turn off figure updates
    pyplot.rc('font', size=10, family = 'serif',
              serif = ['Computer Modern Roman']
              )
    pyplot.rc('text', usetex=True)
    pyplot.rc('legend', fontsize = 10)

    pq = pfamq.pfamq()

    date = time.strftime('%Y%m%d')
    #date = "20110601"

    # Just the human and mouse families
    if True:
        set_id = 109 # ppod human, mouse
        family_set_name = 'ppod_20_cleanqfo2'
        #family_set_name = 'hgnc_apr10_selected'
        
        title_descr= "family_set_name: %s, set_id: %d" % (family_set_name,set_id)
        print(title_descr)
        
        fq = familyq.familyq(family_set_name=family_set_name)
        clustering = fq.fetch_families()[0]
        
        # Flatten the members of every family into one tuple.  A single
        # pass replaces the repeated list concatenation through
        # reduce(), and also works for an empty clustering.
        annotation, seq_map = pq.fetch_domain_map(
            set_id = set_id,
            seq_ids = tuple(seq_id
                            for fam in clustering
                            for seq_id in clustering[fam]))

        mi = mi_plot( clustering, annotation)

        pyplot.rc('figure', figsize=(6,4))
        mi.scatter_cluster_mi_specific(
            dpi=400,
            xlabel='Family',
            adj_bottom=0.28,
            # use pdf for transparency
            filename=date + '_scatter_cluster_mi_specific_' + family_set_name + '.pdf')

        fsuffix='_%s_set_id_%d.pdf' % (family_set_name, set_id)
        #mi.hist_seqs_per_cluster(
        #    filename=date + '_seqs_per_cluster' + fsuffix,
        #    title_descr=escape_latex(title_descr),
        #    xlog=True,
        #    dpi=200)
        #mi.hist_domains_per_cluster(
        #    filename=date + '_domains_per_cluster' + fsuffix,
        #    title_descr=escape_latex(title_descr),
        #    xlog=True,
        #    dpi=200)
        #mi.hist_clusters_per_domain(
        #    filename=date + '_clusters_per_domain' + fsuffix,
        #    title_descr=escape_latex(title_descr),
        #    dpi=200)
        #mi.hist_domains_per_seq(
        #    filename=date + '_domains_per_seq' + fsuffix,
        #    title_descr=escape_latex(title_descr),
        #    dpi=200)
        #mi.hist_seqs_per_domain(
        #    filename=date + '_seqs_per_domain' + fsuffix,
        #    title_descr=escape_latex(title_descr),
        #    ylog=True,
        #    dpi=200)

        # square, accounting for labels
        # pyplot.rc('figure', figsize=(4.5,4))
        #         mi.scatter_mi_entropy(
        #             dpi=150,
        #             filename=date + '_scatter_mi_entropy' + fsuffix,
        #             max_labels=33,
        #             label_exclude=(7,10,11,12,13,14,16,17,18,21,22,25,26,27,29,30,31,32))

    # Full mouse and human set
    if False:
        set_id = 109   # ppod human, mouse
        #set_id = 106   # all 48
        cr_id = 475    # clustering using the full 600k set.  Correct log calculation
        nc_score = 0.425 # near the optimal

        title_descr = "set_id: %(set_id)d, cr_id: %(cr_id)d, nc_score: %(nc_score)g" % locals()
        print(title_descr)



        clusclass = cluster_sql.hcluster( cluster_run_id = cr_id,
                                          cacheq=True)
        clustering = clusclass.cut_tree( distance = 1-nc_score,
                                         set_id_filter = set_id)
        annotation, seq_map = pq.fetch_domain_map(set_id = set_id)

        mi = mi_plot( clustering, annotation)

        if False:
            fsuffix='_cr_id_%d_nc_%g_set_id_%d_nolegend.png' % (cr_id, nc_score, set_id)
            pyplot.rc('figure', figsize=(6.5,3))
            mi.scatter_cluster_mi_specific(
                dpi=600, # png
                #dpi=150, # pdf
                adj_bottom=0.07,
                ymin=-0.01,
                grid=False,
                xticks=False,
                draw_legend=False,
                # 9996 clusters
                xlimit=(-100, 10100), # fit the points well inside the axis frame
                pointsize=0.5,
                # use pdf for transparency
                filename=date + '_scatter_cluster_mi_specific' + fsuffix)


        if False:
            fsuffix='_cr_id_%d_nc_%g_set_id_%d_zoomorig.pdf' % (cr_id, nc_score, set_id)
            pyplot.rc('figure', figsize=(6.5,4))
            mi.scatter_cluster_mi_specific(
                #dpi=600,
                dpi=150,
                adj_bottom=0.31,
                grid=True,
                pointsize=10,
                xlimit=(-0.5,44.5),
                ymin=-0.01,
                # use pdf for transparency
                filename=date + '_scatter_cluster_mi_specific' + fsuffix)

        #fsuffix='_cr_id_%d_nc_%g_set_id_%d.png' % (cr_id, nc_score, set_id)
        #mi.hist_seqs_per_cluster(
        #    filename=date + '_seqs_per_cluster' + fsuffix,
        #    title_descr=escape_latex(title_descr),
        #    ylog=True,xlog=True,
        #    dpi=300)
        #mi.hist_domains_per_cluster(
        #    filename=date + '_domains_per_cluster' + fsuffix,
        #    title_descr=escape_latex(title_descr),
        #    ylog=True, xlog=True,
        #    dpi=300)
        #mi.hist_clusters_per_domain(
        #    filename=date + '_clusters_per_domain' + fsuffix,
        #    title_descr=escape_latex(title_descr),
        #    ylog=True,
        #    dpi=300)
        #mi.hist_domains_per_seq(
        #    filename=date + '_domains_per_seq' + fsuffix,
        #   title_descr=escape_latex(title_descr),
        #    dpi=300)
        #mi.hist_seqs_per_domain(
        #    filename=date + '_seqs_per_domain' + fsuffix,
        #    title_descr=escape_latex(title_descr),
        #    ylog=True,xlog=True,
        #    dpi=300)

        fsuffix='_cr_id_%d_nc_%g_set_id_%d.png' % (cr_id, nc_score, set_id)
        pyplot.rc('figure', figsize=(4.5,4))
        mi.scatter_mi_entropy(
            dpi=600,
            filename=date + '_scatter_mi_entropy' + fsuffix,
            max_labels=20,
            label_exclude=(5,8,11,12,13,16,17,18))


    # Small hand-made example: two clusters, five annotations
    if False:
        clustering = {'x': (0,1,2,3,4,5,6,7),
                      'y': (8, 9, 10, 11, 12, 13)}
        annotation = {'a': (0,1,2,3),
                      'b': (4,5),
                      'c': (8, 9, 10, 11),
                      'd': (8, 9, 10, 11),
                      'e': (6,7, 12, 13)}

        mi = mi_plot( clustering, annotation)
        pyplot.rc('figure', figsize=(4.5,3))
        fsuffix='_mi_specific_example.pdf'

        mi.scatter_cluster_mi_specific(
            dpi=300,
            # use pdf for transparency
            filename=date + '_scatter_cluster' + fsuffix,
            adj_bottom = 0.18)

        #mi.hist_seqs_per_cluster(
        #    filename=date + '_seqs_per_cluster' + fsuffix,
        #    dpi=300)
        #mi.hist_domains_per_cluster(
        #    filename=date + '_domains_per_cluster' + fsuffix,
        #    dpi=300)
        #mi.hist_clusters_per_domain(
        #    filename=date + '_clusters_per_domain' + fsuffix,
        #    dpi=300)
        #mi.hist_domains_per_seq(
        #    filename=date + '_domains_per_seq' + fsuffix,
        #    dpi=300)
        #mi.hist_seqs_per_domain(
        #    filename=date + '_seqs_per_domain' + fsuffix,
        #    title_descr=escape_latex(title_descr),
        #    dpi=300)

    #A, domains = mi.domain_domain_mi_heatmap(title_descr=title_descr)


    # Full mouse and human set BLAST using 600k
    if False:
        set_id = 109 # ppod human, mouse
        cr_id = 474  # Clustering using e_value

        # NOTE(review): csql duplicates clusclass below and is not used
        # afterwards.
        csql = cluster_sql.hcluster( cluster_run_id = cr_id,
                                     cacheq=True)

        clusclass = cluster_sql.hcluster( cluster_run_id = cr_id,
                                          cacheq=True)
        # NOTE(review): nc_score is only assigned in the "Full mouse and
        # human set" block above.  Enabling this block on its own raises
        # NameError; e_range below suggests an e-value cutoff was
        # intended here -- confirm before enabling.
        clustering = clusclass.cut_tree( distance = 1-nc_score,
                                         set_id_filter = set_id)
        annotation, seq_map = pq.fetch_domain_map(set_id = set_id)

        mi = mi_plot( clustering, annotation)

        e_range = [1E-200, 1E-100, 1E-50, 1E-25, 1E-10, 1E-5, 1E-3]

        #pyplot.rc('figure', figsize=(3.5,3.21))
        #mi.scatter_mi_entropy(
        #    dpi=400,
        #    filename=date + '_scatter_mi_entropy_cr_id_472_nc02_set_id_109.pdf',
        #    max_labels=5)