#!/usr/bin/env python

# Jacob Joseph

from matplotlib import pyplot
from JJcluster.cluster_plot import plot
from DurandDB import familyq
import time, numpy

def build_family_sets(family_set_name):
    """Map each family-set key to the tuple of family names it covers.

    The two aggregate keys 'ALL' and 'ALL-kin' are always present, followed
    by one entry per family returned by ``familyq.fetch_families()`` for
    ``family_set_name``; each family maps to a one-element tuple of itself.
    """
    query = familyq.familyq(family_set_name=family_set_name)
    lookup, _members = query.fetch_families()

    sets = {'ALL': ('ALL',),
            'ALL-kin': ('ALL-kin',)}
    # Aggregates are inserted first so per-family entries come after them.
    sets.update((name, (name,)) for name in lookup.keys())
    return sets

if __name__ == "__main__":
    # Experiment driver: evaluate the configured clusterings against the
    # PPOD family sets and write whichever figures/tables are switched on by
    # the flags below.  Alternative configurations are kept as comments so
    # they can be re-enabled quickly.

    # Output sections.  Set a flag to True to produce that output.
    draw_summary_plotsets = False   # per-family Precision/Recall/F plot sets
    draw_single_plot = False        # one standalone Precision/Recall/F plot
    print_max_F_tables = False      # print p.max_F_table(), raw and adjusted
    draw_pfam_scatters = False      # identity/promiscuity scatter plots
    draw_F_comparison = False       # F compared across families
    draw_bar_charts = False         # F bar charts per family group
    draw_heatmaps = True            # Recall/Precision/F heatmaps

    family_set_name = 'ppod_20_cleanqfo2'
    set_id = 109

    family_sets = build_family_sets(family_set_name)
    #family_sets = {'ALL': ('ALL',),
    #               'ALL-kin': ('ALL-kin',),
    #               'ACSL': ('ACSL',),
    #               'ADAM': ('ADAM',),
    #               'DVL': ('DVL',),
    #               'FGF': ('FGF',),
    #               'FOX': ('FOX',),
    #               'GATA': ('GATA',),
    #               'Kinase': ('Kinase',),
    #               'Kinesin': ('Kinesin',),
    #               'KIR': ('KIR',),
    #               'Laminin': ('Laminin',),
    #               'Myosin': ('Myosin',),
    #               'Notch': ('Notch',),
    #               'PDE': ('PDE',),
    #               'SEMA': ('SEMA',),
    #               'Tbox': ('Tbox',),
    #               'TNF': ('TNF',),
    #               'TNFR': ('TNFR',),
    #               'TRAF': ('TRAF',),
    #               'USP': ('USP',),
    #               'WNT': ('WNT',)
    #               }

    # These ranges always specify distances, as stored in the tree in
    # the database

    e_range = (10, 1e-5, 1e-10, 1e-20, 1e-25, 1e-30, 1e-35, 1e-40, 1e-45, 1e-50,
               1e-55, 1e-60, 1e-65, 1e-70, 1e-75, 1e-80, 1e-85, 1e-90, 1e-95, 1e-100,
               1e-105, 1e-110, 1e-115, 1e-120, 1e-125, 1e-130, 1e-135, 1e-140, 1e-145,
               1e-150, 1e-155, 1e-160, 1e-165, 1e-170, 1e-175, 1e-180, 1e-185, 1e-190,
               1e-195, 1e-200, 0)

    # Maximum symmetric bit scores.  The axis formatter below converts a
    # stored distance d back to an approximate bit score as (max - d).
    # Human and mouse tree:
    #   bq.fetch_max_bit_score(102, symmetric=True, seq_id_0_set_id=109)
    #   43735.900000000001
    max_bit_score_hm = 43735.9
    # Complete tree:
    #   bq.fetch_max_bit_score(102, symmetric=True)
    #   45410.400000000001
    max_bit_score = 45410.4

    #bit_range = range( 0, 10000, 1000) + range(10000, 40000, 500) + range(40000, 44000, 100)

    # Thresholds appropriate for the mouse and human tree.  Not referenced
    # by the current cluster_variants; kept so those runs can be re-enabled.
    # list(range(...)) keeps the concatenation valid on Python 2 and 3.
    bit_range_hm = (#list(range(0, 40001, 5000)) +
                    #list(range(40000, 43000, 200)) +
                    #list(range(43000, 43400, 100)) +
                    list(range(43400, 43500, 50)) +
                    list(range(43500, 43690, 10)) +
                    list(range(43690, 43736, 5))
                    )

    # Thresholds appropriate for the complete tree (used below).
    bit_range = (list(range(45100, 45300, 50)) +
                 list(range(45300, 45350, 10)) +
                 list(range(45350, 45411, 5))
                 )

    #bit_range = range( 0, 44000, 5000) + [41000, 41500, 42000,
    #                                      42500, 43000, 43500, 44000]
    #bit_range = range( 20000, 44000, 5000)

    nc_range = [float(a) / 40 for a in range(1, 40)]
    #nc_range = [float(a) / 10 for a in range(2,10)]

    # Tick labels: turn the stored distance d back into the score it came from.
    xtick_formatters = {#'bit_score' : lambda d: "%s" % int(round(max_bit_score_hm - d)), # approximate
                        'bit_score' : lambda d: "%s" % int(round(max_bit_score - d)), # approximate
                        'nc_score' : lambda d: "%4.3f" % (1.0 - d)
                        }

    # cluster method -> score type -> {cluster run id: thresholds}
    cluster_variants = {
        #'single_linkage': {'e_value': {420: e_range},
        #                   'nc_score': {469: nc_range},
        #                   },
        # 'complete_linkage': {'e_value':  { 61: e_range},
        #                     'nc_score': { 60: nc_range }
        #                     },
        #'average_linkage': {#'nc_score_200k': {470: nc_range},  # full 200k set
                            #'nc_score_600k': {472: nc_range},  # full 600k set
                            #'nc_score_200k': {470: (0.150, 0.175, 0.2)},  # full 200k set
                            #'nc_score_600k_nolog': {472: nc_range},  # full 600k set
                            #'nc_score_600k': {475: nc_range},  # full 600k set
                            #'nc_score_600k_bit100': {476: nc_range},  # full 600k set, calculated from blast results above bit_score 100
                            #'nc_score_600k_bit60': {477: nc_range},  # full 600k set, calculated from blast results above bit_score 60
                            #'nc_score_600k_limit50': {481: nc_range},  # full 600k set, calculated from the best 50 blast hits
                            #'nc_score_600k_limit100': {482: nc_range},  # full 600k set, calculated from the best 100 blast hits
                            #'nc_score_600k_limit500': {483: nc_range},  # full 600k set, calculated from the best 500 blast hits
                            #'nc_score_hm': {113: nc_range},  # human, mouse
                            #'nc_score_hm': {113: (0.125, 0.150, 0.175)},
                            #'nc_score_200knotsymm': {76: nc_range},  # human, mouse
                            #'nc_score_200knotsymm': {76: (0.150, 0.175, 0.2)},
                            #'e_value': {384: e_range},
                            #},

        # These are clusterings that are restricted only to human and
        # mouse.  All are symmetric, and with the correct NC log
        # calculation

        #'complete_linkage': { 'bit_score': { 487: bit_range_hm},
        #                      'nc_score': { 488: nc_range }
        #                      },
        #'average_linkage': {'bit_score':   { 486: bit_range_hm},
        #                    'nc_score': { 485: nc_range }
        #                     },
        'average_linkage': {'bit_score': { 490: bit_range},
                            'nc_score':  { 475: nc_range }
                            },
        #'single_linkage': { 'bit_score':   { 489 : bit_range_hm},
        #                    'nc_score':  { 484: nc_range }
        #                    },

        #'spici-g': { 'support=density, nc_score': dict( ((k,None) for k in (
        #122,  # 0.1
        #127,
        #116,
        #121,  # 0.4
        #147,
        #115,
        #125,  # 0.7
        #119,
        #117,  # 0.9
        #169,   # d=0.3, g=0.4
        #))),
        #             'support=density, bit_score': dict( ((k,None) for k in (
        #386, # 0.1
        #387,
        #388,
        #389,
        #390,
        #391,
        #393,
        #394,
        #395, # 0.9
        #))),
        #           },
        #'reference': { 'ppod': {383:None}},
        }

    # Statistic groups: each entry is (statistic name, flag) as consumed by
    # the plot object's drawing methods.
    stat_variants = {'summary': (('Precision', False), ('Recall', False), ('F', False)),
                     'f_only': (('F', False), ('F', True))}

    #p = plot( cluster_variants, stat_variants, family_sets, use_pickle=False,
    #          set_id=10,8
    #          set_id_filter = 109,       # human and mouse
    #          family_set_name='ppod_20'
    #          )

    p = plot(cluster_variants, stat_variants, family_sets,
             use_pickle=True,
             set_id=set_id,           # the set used to build the contingency table
             family_set_name=family_set_name
             )

    date = time.strftime('%Y%m%d')

    pyplot.ioff()  # turn off figure updates
    pyplot.rc('font', size=10, family='serif',
              serif=['Computer Modern Roman']
              )
    pyplot.rc('text', usetex=True)
    pyplot.rc('legend', fontsize=10)

    if draw_summary_plotsets:
        # NOTE(review): ('spici', 'ncthresh') has no entry in cluster_variants
        # above; presumably one must be added before enabling this section.
        pyplot.rc('figure', figsize=(12, 7), dpi=600)
        for family_set in family_sets:
            p.draw_plotset(plot_list=[#('average_linkage', 'nc_777_full_comp_nonsym'),
                                      #('average_linkage', 'nc_777_hm_comp_nonsym'),
                                      #('average_linkage', 'nc_779_full_comp_sym'),
                                      #('average_linkage', 'nc_780_full_nocomp_nonsym'),

                                      ('spici', 'ncthresh'),
                                      #('single_linkage', 'e_value'),
                                      #('average_linkage', 'e_value'),
                                      #('complete_linkage', 'e_value'),
                                      #('mcl', 'e_value'),
                                      #('single_linkage', 'nc_score'),
                                      #('average_linkage', 'nc_score'),
                                      #('complete_linkage', 'nc_score'),
                                      #('mcl', 'nc_score'),
                                      ],
                           stats=stat_variants['summary'], family_set=family_set,
                           filename=date + '-%s-summary.png' % family_set,
                           dpi=100)

    if draw_single_plot:
        family_set = 'ALL'
        stype = 'bit_score'
        method = 'average_linkage'
        draw_legend = False
        pyplot.rc('figure', figsize=(6, 2.5))
        subplots_adjust = {'hspace': 0.05,
                           'wspace': 0.05,
                           'bottom': 0.22,
                           'left': 0.04,
                           'top': 0.98,
                           'right': 0.99}
        #p.draw_plot( cl_method='average_linkage', stype='e_value',
        #             stats=stat_variants['summary'],
        #             family_set=family_set, draw_legend=True,
        #             standalone=True, dpi=150,
        #             filename = date + '_%s_average_e_value.eps' % family_set)
        p.draw_plot(cl_method=method, stype=stype,
                    stats=stat_variants['summary'],
                    family_set=family_set,
                    draw_legend=draw_legend,
                    standalone=True,
                    dpi=150,
                    filename=date + '_%s_%s_%s%s_48genome.pdf' % (family_set, method, stype,
                                                                  '_legend' if draw_legend else ''),
                    xtick_formatters=xtick_formatters,
                    subplots_adjust=subplots_adjust
                    )

    if print_max_F_tables:
        print(p.max_F_table())
        print(p.max_F_table(adjust=True))

    if draw_pfam_scatters:
        pyplot.rc('figure', figsize=(7, 6))
        for (cr_id, thresh, ptype) in [#(27, 0.05, 'identity'), (27, 0.05, 'promiscuity'),
                                       #(27, 0.4, 'identity'), (27, 0.4, 'promiscuity'),
                                       (27, 0.6, 'identity'), (27, 0.6, 'promiscuity'),
                                       (27, 0.8, 'identity'), (27, 0.8, 'promiscuity'),
                                       (27, 0.9, 'identity'), (27, 0.9, 'promiscuity'),
                                       #(27, 0.98, 'identity'), (27, 0.98, 'promiscuity'),
                                       (28, 1e-25, 'identity'), (28, 1e-25, 'promiscuity'),
                                       (28, 1e-60, 'identity'), (28, 1e-60, 'promiscuity'),
                                       (30, None, 'identity'), (30, None, 'promiscuity'),
                                       (31, None, 'identity'), (31, None, 'promiscuity')
                                       ]:
            # Find the (cluster method, score type) that lists run cr_id; it
            # is used in the output filename.  Take the first match, and fail
            # with a clear message when cr_id is not configured instead of
            # hitting an unbound name or reusing the previous run's labels.
            matches = [(cl_method, stype)
                       for cl_method in p.cluster_variants
                       for stype in p.cluster_variants[cl_method]
                       if cr_id in p.cluster_variants[cl_method][stype]]
            if not matches:
                raise KeyError("cluster run %s is not listed in cluster_variants"
                               % cr_id)
            plot_cl_method, plot_stype = matches[0]

            p.draw_pfam_scatter(cr_id, ptype=ptype, thresh=thresh,
                                filename=date + '-%s-%s-%s-%s.png' % (plot_cl_method,
                                                                      plot_stype, thresh,
                                                                      ptype))

    # Comparison of F across families
    # NOTE(review): the plot_list below names methods/score types that are not
    # in the current cluster_variants; update one or the other before enabling.
    if draw_F_comparison:
        pyplot.rc('figure', figsize=(16, 9))

        p.draw_comp_plotset(plot_list=[('single_linkage', 'e_value'),
                                       ('complete_linkage', 'e_value'),
                                       ('mcl', 'e_value'),
                                       ('single_linkage', 'nc_score'),
                                       ('complete_linkage', 'nc_score'),
                                       ('mcl', 'nc_score')],
                            stats=stat_variants['f_only'],
                            family_keys=sorted(family_sets.keys()),
                            filename=date + '-F_comp.png',
                            dpi=100, draw_legend=True)

    # Bar charts
    # NOTE(review): as above, most plot_list entries are not in the current
    # cluster_variants.
    if draw_bar_charts:
        pyplot.rc('figure', figsize=(9, 9))
        plot_list = [('single_linkage', 'e_value'),
                     ('single_linkage', 'nc_score'),
                     ('average_linkage', 'e_value'),
                     ('average_linkage', 'nc_score'),
                     ('complete_linkage', 'e_value'),
                     ('complete_linkage', 'nc_score'),
                     ('mcl', 'e_value'),
                     ('mcl', 'nc_score')
                     ]

        for (fprefix, family_keys) in (('sd', ('ACSL', 'FGF', 'FOX', 'Tbox', 'TNF', 'USP', 'WNT')),
                                       ('md_conserved', ('DVL', 'GATA', 'KIR', 'Notch', 'TRAF')),
                                       ('md_variable', ('ADAM', 'Kinase', 'Kinesin',
                                                        'Laminin', 'Myosin', 'PDE', 'SEMA', 'TNFR'))):
            p.draw_comp_barplot(plot_list=plot_list,
                                stat_key='F',
                                family_keys=family_keys,
                                # separator after the date, like every other output file
                                filename=date + '_bar_' + fprefix + '.png',
                                dpi=100, draw_legend=True)

    # Heatmap
    if draw_heatmaps:
        pyplot.rc('figure', figsize=(9, 6))
        plot_list = [#('single_linkage', 'e_value'),
                     #('average_linkage', 'e_value'),
                     #('spici-g', 'support=density, bit_score'),

                     #('single_linkage', 'nc_score'),
                     #('average_linkage', 'nc_score_hm'),
                     #('average_linkage', 'nc_score_200k'),
                     #('average_linkage', 'nc_score_600k_nolog'),
                     #('average_linkage', 'nc_score_600k'),
                     #('average_linkage', 'nc_score_600k_bit100'),
                     #('average_linkage', 'nc_score_600k_bit60'),
                     #('average_linkage', 'nc_score_600k_limit50'),
                     #('average_linkage', 'nc_score_600k_limit100'),
                     #('average_linkage', 'nc_score_600k_limit500'),
                     #('average_linkage', 'nc_score_200knotsymm'),
                     #('spici-g', 'support=density, nc_score'),
                     #('reference', 'ppod')

                     ('average_linkage', 'bit_score'),
                     #('single_linkage', 'nc_score'),
                     #('average_linkage', 'nc_score'),
                     #('complete_linkage', 'nc_score'),
                     #('single_linkage', 'bit_score'),
                     #('average_linkage', 'bit_score'),
                     #('complete_linkage', 'bit_score'),
                     ]

        fname_tag = 'bit_48genome'
        #fname_tag = 'bit'
        subplots_adjust = {'hspace': 0.05,
                           'wspace': 0.05,
                           'bottom': 0.1,
                           'left': 0.07,
                           'top': 0.96,
                           'right': 0.99}

        # (file prefix, family keys) groups; one heatmap set is drawn per
        # group and statistic.  Other groupings used previously:
        # ('sd', ('ACSL', 'FGF', 'FOX', 'Tbox', 'TNF', 'USP', 'WNT')),
        # ('md_conserved', ('DVL', 'GATA', 'KIR', 'Notch', 'TRAF')),
        # ('md_variable', ('ADAM', 'Kinase', 'Kinesin',
        #                 'Laminin', 'Myosin', 'PDE', 'SEMA', 'TNFR'))
        heatmap_family_groups = [
            ('all', ('ACSL', 'FGF', 'FOX', 'Tbox', 'TNF', 'WNT',
                     'DVL', 'GATA', 'KIR', 'Notch', 'TRAF', 'USP',
                     'ADAM', 'Kinase', 'Kinesin',
                     'Laminin', 'Myosin', 'PDE', 'SEMA', 'TNFR',
                     'ALL', 'ALL-kin')),
            ]

        for stat_key in ("Recall", "Precision", "F"):
            for (fprefix, family_keys) in heatmap_family_groups:
                p.draw_heatmap_set(plot_list=plot_list,
                                   stat_key=stat_key,
                                   family_keys=family_keys,
                                   filename=date + '_heatmap_' + fname_tag + '_' + stat_key + '_' + fprefix + '.pdf',
                                   #filename = date + 'heatmap_bestspici_'+stat_key+'_'+fprefix+'.eps',
                                   #filename = None,
                                   dpi=100,
                                   num_rows=1,
                                   num_cols=3,
                                   draw_colorbar=False,
                                   subplots_adjust=subplots_adjust,
                                   xtick_formatters=xtick_formatters)