#!/usr/bin/env python

# Jacob Joseph
# 10 August 2010

# Plot comparisons of distributions of scores within and outside of
# clusters or families.

import sys, time, copy
from DurandDB import familyq
from JJcluster import cluster_sql
from matplotlib import pyplot


def sum_histograms(a, b):
    """Merge histogram dict ``b`` into a shallow copy of ``a`` and return it.

    Values in ``b`` are ``(min, max, count)`` tuples keyed by bin.  For a
    key already present in the copy of ``a`` the merged entry keeps the
    smaller min, the larger max and the summed count; any other key from
    ``b`` is copied over as-is.  Neither argument is modified.
    """
    merged = copy.copy(a)
    for bin_key, (lo, hi, n) in b.items():
        if bin_key in merged:
            prev = merged[bin_key]
            merged[bin_key] = (min(prev[0], lo),
                               max(prev[1], hi),
                               prev[2] + n)
        else:
            merged[bin_key] = (lo, hi, n)
    return merged

class compare_distr(object):
    """Compare score distributions of sequence pairs inside and outside
    of families (or clusters) and plot them as overlaid bar histograms.

    Score histograms are fetched through a DurandDB ``familyq`` object.
    An ``hcluster`` object for cluster run ``cr_id`` is also created,
    but clusters are only used if ``cluster_cut`` has been assigned.
    """

    # Mapping cluster id -> sequence ids, used by
    # compare_inner_outer(cluster_id=...).  Nothing in this class assigns
    # it (the cut_tree call in __init__ is disabled), so a caller must set
    # it before plotting by cluster id.
    cluster_cut = None

    def __init__(self, br_id=None, nc_id=None, stype=None,
                 symmetric=True, set_id_filter_annotated = None,
                 set_id_filter_all = None,
                 family_set_name = None,
                 cr_id = None):
        """Store query settings and load the sequence sets.

        br_id, nc_id, stype, symmetric
            Forwarded unchanged to ``familyq.fetch_score_histogram``.
            ``stype`` also selects the x-axis label and scaling in
            compare_inner_outer ('nc_score', 'bit_score' or 'e_value').
        set_id_filter_annotated
            Set id whose sequences are treated as annotated (family
            members); loaded into ``self.seq_set_annotated``.
        set_id_filter_all
            Set id of the wider sequence population; loaded into
            ``self.seq_set_all``.
        family_set_name
            Passed to ``familyq.familyq`` to choose the family set.
        cr_id
            Cluster run id passed to ``cluster_sql.hcluster``.
        """
        self.br_id = br_id
        self.nc_id = nc_id
        self.stype = stype
        self.set_id_filter_annotated = set_id_filter_annotated
        self.set_id_filter_all = set_id_filter_all
        self.symmetric = symmetric

        self.fq = familyq.familyq( family_set_name = family_set_name, cacheq=True)

        self.seq_set_annotated = self.fq.fetch_seq_set( self.set_id_filter_annotated)
        self.seq_set_all = self.fq.fetch_seq_set( self.set_id_filter_all)

        self.csql = cluster_sql.hcluster( cluster_run_id = cr_id)
        # Disabled; see the note on cluster_cut above.
        #self.cluster_cut = self.csql.cut_tree( 0.175, self.set_id_filter)

    def build_histograms( self, seq_ids_list):
        """Sum score histograms over a list of sequence-id sets.

        This is useful, for example, for counting all hits within a
        number of families but not between them.  Each set is
        histogrammed on its own, and the per-set histograms are merged
        with sum_histograms.

        For each set S in ``seq_ids_list``, three histograms are fetched,
        always with S as ``query_seq_id``:

        * inner     -- ``seq_set`` = S, fetched with
                       ``set_id_filter = self.set_id_filter_annotated``;
        * outer_all -- ``seq_set`` = self.seq_set_all minus every
                       annotated sequence (hits to any annotated family,
                       S included, are left out), fetched with
                       ``set_id_filter = self.set_id_filter_all``;
        * outer_ann -- ``seq_set`` = self.seq_set_annotated minus S,
                       fetched with
                       ``set_id_filter = self.set_id_filter_annotated``.

        Returns a 3-tuple ``(inner, outer_ann, outer_all)``, in that
        order.  Each element is the collection of ``(min, max, count)``
        bins of the merged histogram.  If the annotated and all filters
        coincide, the two outer histograms describe related populations.
        """
        hist_dict_inner = {}
        hist_dict_outer_all = {}
        hist_dict_outer_ann = {}

        # Loop-invariant: the "all" comparison set never contains any
        # annotated sequence.  Earlier alternatives were
        # ``self.seq_set_all - seq_set_inner`` (keeping known
        # family-family hits) and plain ``self.seq_set_all``.
        # Open question from the original author: subtract the annotated
        # hits here, or just overlay them on the plot?
        seq_set_outer_all = self.seq_set_all - self.seq_set_annotated

        def fetch(query, set_id_filter, seq_set):
            # One histogram of scores for `query` against `seq_set`.
            return self.fq.fetch_score_histogram(
                stype = self.stype, br_id = self.br_id, nc_id = self.nc_id,
                set_id_filter = set_id_filter,
                symmetric = self.symmetric,
                query_seq_id = query,
                seq_set = seq_set,
                return_key_dict = True)

        for seq_ids in seq_ids_list:
            seq_set_inner = set(seq_ids)
            seq_set_outer_ann = self.seq_set_annotated - seq_set_inner

            # Fetch order kept as inner, outer_all, outer_ann.
            inner = fetch(seq_set_inner, self.set_id_filter_annotated,
                          seq_set_inner)
            outer_all = fetch(seq_set_inner, self.set_id_filter_all,
                              seq_set_outer_all)
            outer_ann = fetch(seq_set_inner, self.set_id_filter_annotated,
                              seq_set_outer_ann)

            hist_dict_inner = sum_histograms( hist_dict_inner, inner)
            hist_dict_outer_all = sum_histograms( hist_dict_outer_all, outer_all)
            hist_dict_outer_ann = sum_histograms( hist_dict_outer_ann, outer_ann)

        return (hist_dict_inner.values(),
                hist_dict_outer_ann.values(),
                hist_dict_outer_all.values(),
                )

    def build_barplot(self, inner_axes, outer_axes, seq_ids_list, ylog=False,
                      draw_legend=True, omit_outer_all=False):
        """Draw the histograms from build_histograms as bar plots.

        The inner histogram is drawn on ``inner_axes`` (blue, zorder 2).
        The annotated-outer histogram is drawn on ``outer_axes`` (red,
        zorder 1).  Unless ``omit_outer_all`` is set, the all-outer
        histogram is also drawn on ``outer_axes`` (green, zorder 0).
        Each bar spans [min, max) of its bin with height = count.
        ``ylog`` is passed as ``log`` to ``Axes.bar``.  If
        ``draw_legend`` is set, a legend for the drawn series is added
        via ``pyplot.legend`` (i.e. on the current axes).

        Returns None.
        """
        inner, outer_ann, outer_all = self.build_histograms(seq_ids_list)

        # (histogram, axes, colour, legend label, zorder)
        series = [
            (inner, inner_axes, 'blue', 'M,H: homologs', 2),
            (outer_ann, outer_axes, 'red', 'M,H: not homologs', 1),
        ]
        if not omit_outer_all:
            series.append((outer_all, outer_axes, 'green',
                           'Unannotated genomes', 0))

        plots = []
        labels = []
        for hist, axes, color, label, zorder in series:
            left = [ a[0] for a in hist ]
            width = [ a[1] - a[0] for a in hist ]
            height = [ a[2] for a in hist ]
            plots.append( axes.bar( left, height, width,
                                    linewidth=0,
                                    color = color, edgecolor = color,
                                    log = ylog,
                                    alpha = 0.6,
                                    zorder = zorder))
            labels.append(label)

        # FIXME: legend drawing fails in matplotlib 1.1.0 when a plot has no data
        if draw_legend:
            legend = pyplot.legend(plots, labels)
            legend.legendPatch.set(fill=True, facecolor='gray', edgecolor='gray', alpha=0.5)

    def _seq_ids_list_for(self, family_name=None, cluster_id=None,
                          seq_ids=None):
        """Return the list of sequence-id collections to histogram.

        The first argument that is not None wins:

        * ``seq_ids``     -- one explicit collection of sequence ids;
        * ``family_name`` -- 'ALL' (every family in fq.family_sets),
          'ALL-kin' (every family except 'Kinase'), or a single family;
        * ``cluster_id``  -- looked up in ``self.cluster_cut``.

        Raises ValueError if all three are None.
        """
        if seq_ids is not None:
            # Previously this case never assigned seq_ids_list (NameError).
            return [seq_ids]

        if family_name is not None:
            # Families are histogrammed one at a time and summed in
            # build_histograms.  Hits between two annotated families
            # therefore land in the annotated-outer histogram, not the
            # inner one.
            if family_name == 'ALL':
                return self.fq.family_sets.values()
            if family_name == 'ALL-kin':
                families = [k for k in self.fq.family_sets.keys() if k != 'Kinase']
                print("families: %s" % (families,))
                return [self.fq.fetch_family_seqs(fam) for fam in families]
            return [self.fq.fetch_family_seqs( family_name)]

        if cluster_id is not None:
            # Requires cluster_cut to have been assigned by the caller.
            return [self.cluster_cut[cluster_id]]

        raise ValueError("Must specify seq_ids, a cluster_id or a family_name")

    def compare_inner_outer(self, family_name = None,
                            cluster_id = None,
                            seq_ids = None,
                            ylog = False, xlog = False,
                            fprefix = None, fext = 'png',
                            draw_legend = True,
                            figsize = None,
                            omit_outer_all = False,
                            plot_adjust = None):
        """Plot within-set vs. outside-set score histograms on twin axes.

        The sequences forming the "inner" set come from the first of
        ``seq_ids``, ``family_name``, ``cluster_id`` that is given (see
        _seq_ids_list_for).  The inner histogram is drawn on a twin y
        axis (``ax2``, "Family Pairs").  The outer histograms are drawn
        on the main axes (``ax1``, "Other pairs").

        Side effects: changes global matplotlib rc settings (figure size,
        serif font, usetex=True) and turns interactive mode off.  If
        ``fprefix`` is given, the figure is saved to
        ``figures/<YYYYMMDD>_<fprefix>_<stype>_..._<symmetric>.<fext>``
        relative to the working directory and then closed.  The
        ``figures`` directory must already exist.  Otherwise
        ``pyplot.show()`` is called.

        ``plot_adjust`` is passed as keyword arguments to
        ``pyplot.subplots_adjust``.  ``figsize`` defaults to
        (3.25, 2.25).

        Returns ``(fig, ax1, ax2)``.

        Raises ValueError if no sequence selection is given.
        """
        # Resolve the sequence sets first, so that bad arguments fail
        # before global matplotlib state is changed or a figure is created.
        seq_ids_list = self._seq_ids_list_for(family_name, cluster_id, seq_ids)

        pyplot.ioff() # turn off figure updates
        if figsize is None:
            figsize = (3.25,2.25)
        pyplot.rc('figure', figsize=figsize)
        pyplot.rc('font', size=10, family = 'serif',
                  serif = ['Computer Modern Roman']
                  )
        pyplot.rc('legend', fontsize = 10)
        pyplot.rc('text', usetex=True)

        fig = pyplot.figure()
        ax1 = fig.add_subplot(111)
        if xlog: ax1.set_xscale('log')
        if ylog: ax1.set_yscale('log')

        # Only a selection by cluster gets a title.
        if seq_ids is None and family_name is None and cluster_id is not None:
            ax1.set_title(cluster_id)

        if self.stype=='nc_score':
            ax1.set_xlabel("NC")
        elif self.stype=='bit_score':
            ax1.set_xlabel("BIT score")
        elif self.stype=='e_value':
            ax1.set_xlabel("E-value")

        ax1.grid(color='gray', alpha=0.5)

        ax2 = ax1.twinx()
        if xlog: ax2.set_xscale('log')

        ax1.set_ylabel('Other pairs')
        ax2.set_ylabel('Family Pairs')

        self.build_barplot(ax2, ax1, seq_ids_list, ylog = ylog,
                           draw_legend = draw_legend,
                           omit_outer_all = omit_outer_all)

        # Bar patches do not always update the axes' data limits, so fixed
        # y limits are applied for the log-scale bit-score aggregate plots.
        if (self.stype == 'bit_score' and ylog
                and family_name in ('ALL', 'ALL-kin')):
            ax1.set_ylim(0.9, 10**6)
            ax2.set_ylim(0.9, 10**6)

        # Smaller E-values are better hits, so show them on the right.
        if self.stype == 'e_value':
            ax1.invert_xaxis()
            ax2.invert_xaxis()

        if plot_adjust is not None:
            pyplot.subplots_adjust(**plot_adjust)

        if fprefix is not None:
            date = time.strftime('%Y%m%d')
            fname = date + '_' + fprefix + '_%s_%s_%s_%s_%s_%s_%s.%s' % (
                self.stype, self.br_id, self.nc_id,
                family_name,
                self.set_id_filter_all,
                self.set_id_filter_annotated,
                self.symmetric,
                fext)

            pyplot.savefig( 'figures/%s' % fname)
            pyplot.close()
        else:
            pyplot.show()

        return (fig, ax1, ax2)


if __name__ == "__main__":
    # Usage: <script> STYPE FAMILY_NAME [DRAW_LEGEND]
    #   STYPE        'nc_score', 'bit_score' or 'e_value'
    #   FAMILY_NAME  a family name, or 'ALL' / 'ALL-kin'
    #   DRAW_LEGEND  the literal string 'False' suppresses the legend;
    #                anything else (or omitting it) draws the legend
    # Validate before any database work, instead of failing later with
    # an IndexError.
    if len(sys.argv) < 3:
        sys.exit("usage: %s STYPE FAMILY_NAME [DRAW_LEGEND]" % sys.argv[0])

    # 200k
    br_id = 102   # all 48 genomes
    nc_id = 808
    stype = sys.argv[1]
    symmetric = True
    set_id_filter_all = 108       # 12 genomes (200k)
    #set_id_filter_all = 109
    set_id_filter_annotated = 109 # human, mouse
    family_set_name = "ppod_20_cleanqfo2"
    # Other figure variants used the prefixes 'legend' and 'legend_mhonly'.
    fprefix = 'mhonly'
    figsize = (3.25, 2)
    plot_adjust = {'left' : 0.22, 'right': 0.79, 'top': 0.97, 'bottom': 0.17}
    # Larger variant: figsize = (6, 3) with
    # plot_adjust = {'left': 0.21, 'right': 0.80, 'top': 0.97, 'bottom': 0.17}

    cr_id = 475 # average_linkage: NC, with correct log score transformation

    draw_legend = sys.argv[3] != 'False' if len(sys.argv) > 3 else True

    cd = compare_distr( br_id = br_id, nc_id = nc_id, stype = stype,
                        symmetric = symmetric,
                        set_id_filter_annotated = set_id_filter_annotated,
                        set_id_filter_all = set_id_filter_all,
                        cr_id = cr_id, family_set_name = family_set_name)

    cd.compare_inner_outer( family_name = sys.argv[2],
                            fprefix = fprefix,
                            # log x axis for the unbounded score types
                            xlog = (stype in ('bit_score','e_value')),
                            fext = 'pdf',
                            draw_legend = draw_legend,
                            figsize = figsize,
                            omit_outer_all = True,
                            plot_adjust = plot_adjust)


    

    
    
    
    
