#!/usr/bin/env python

# Jacob Joseph
# 28 Dec 2010

# Plot clustering entropy and summed mutual information (against a Pfam
# domain annotation) vs. cut threshold for a (e.g., hierarchical)
# clustering

import numpy, time
from math import log
from DurandDB import familyq, pfamq
from JJcluster import cluster_sql
from information import mutualinfo
from matplotlib import pyplot

def _threshold_sweep(thresholds, to_distance, annotation, csql, set_id):
    """Cut the tree at each threshold and score the resulting clustering.

    Shared implementation of the two public ``threshold_range_entropy*``
    functions, which differ only in how a threshold is turned into the
    ``distance`` passed to ``csql.cut_tree``.

    Parameters
    ----------
    thresholds : iterable
        Threshold values to evaluate; each becomes a key in the results.
    to_distance : callable
        Maps one threshold to the ``distance`` argument of ``cut_tree``.
    annotation
        Reference annotation given to ``mutualinfo`` together with each
        clustering.
    csql
        Object providing ``cut_tree(distance=..., set_id_filter=...)``
        (the script below uses a ``cluster_sql.hcluster``).
    set_id
        Passed through to ``cut_tree`` as ``set_id_filter``.

    Returns
    -------
    (entropy, mi_sum) : tuple of dict
        Both keyed by threshold: ``clustering_entropy()`` of the cut, and
        ``numpy.sum`` of the first array returned by
        ``mutual_information_specific_array()``.
    """
    entropy = {}
    mi_sum = {}
    for threshold in thresholds:
        clustering = csql.cut_tree(distance=to_distance(threshold),
                                   set_id_filter=set_id)
        miclass = mutualinfo(clustering, annotation)
        entropy[threshold] = miclass.clustering_entropy()

        # Only the MI array (first element) is summed; the remaining
        # return values (pos/neg MI and index maps) are not needed here.
        mi_array = miclass.mutual_information_specific_array()[0]
        mi_sum[threshold] = numpy.sum(mi_array)

    return entropy, mi_sum


def threshold_range_entropy( nc_range, annotation, csql, set_id=None):
    """Score clusterings cut at each NC score in ``nc_range``.

    An NC score is a similarity, so the tree is cut at distance
    ``1 - nc_score``.  Returns ``(entropy, mi_sum)``, two dicts keyed by
    NC score holding the clustering entropy and the summed MI array.
    """
    return _threshold_sweep(nc_range, lambda nc_score: 1 - nc_score,
                            annotation, csql, set_id)


def threshold_range_entropy_e_value( e_range, annotation, csql, set_id=None):
    """Score clusterings cut at each E-value in ``e_range``.

    The E-value is used directly as the ``cut_tree`` distance.  Returns
    ``(entropy, mi_sum)``, two dicts keyed by E-value holding the
    clustering entropy and the summed MI array.
    """
    return _threshold_sweep(e_range, lambda e_value: e_value,
                            annotation, csql, set_id)

if __name__ == "__main__":
    import os

    date = time.strftime('%Y%m%d')

    def save_figure(filename):
        """Write the current pyplot figure to figures/<filename> and close it.

        The figures/ directory is created on demand so that a long
        threshold sweep does not fail only at the final save step.
        """
        if not os.path.isdir('figures'):
            os.makedirs('figures')
        pyplot.savefig(os.path.join('figures', filename))
        pyplot.close()

    pq = pfamq.pfamq()

    # Each analysis below is switched on or off by its literal
    # ``if True`` / ``if False`` guard; only the enabled ones run.

    if False:
        set_id = 109 # ppod human, mouse
        family_set_name = 'ppod_20_cleanqfo2'
        #family_set_name = 'hgnc_apr10_selected'

        # Single-argument print(...) behaves the same on Python 2 and 3.
        print("family_set_name: %s, set_id: %d\n" % (family_set_name, set_id))

        fq = familyq.familyq(family_set_name=family_set_name)
        clustering = fq.fetch_families()[0]

    # Full mouse and human set NC using 600k
    if True:
        set_id = 109   # ppod human, mouse
        cr_id = 475    # clustering using the full 600k set
        #nc_score = 0.2 # near the optimal

        csql = cluster_sql.hcluster( cluster_run_id = cr_id,
                                     cacheq=True)

        # NC thresholds: steps of 0.01 over [0.00, 0.09] and [0.90, 0.99],
        # steps of 0.1 over [0.1, 0.8].  The coarse range stops before 90
        # so 0.9 is not cut and scored twice (range(90, 100) includes it).
        # list() keeps the concatenation valid on both Python 2 and 3.
        nc_range = [float(d)/100 for d in
                    list(range(10)) + list(range(10, 90, 10)) +
                    list(range(90, 100))]
        #nc_range = [0.1, 0.2, 0.3, 0.4, 0.5]
        #nc_range = [0.6, 0.7, 0.8, 0.9]

        annotation, seq_map = pq.fetch_domain_map(set_id = set_id)

        entropy, mi_sum = threshold_range_entropy(
            nc_range, annotation, csql, set_id)

        # Both dicts are filled for the same thresholds, so x is identical.
        x, y = zip(*sorted(entropy.items()))
        x, y2 = zip(*sorted(mi_sum.items()))

        pyplot.ioff() # turn off figure updates
        pyplot.rc('font', size=10, family = 'serif',
                  serif = ['Computer Modern Roman']
                  )
        pyplot.rc('text', usetex=True)
        pyplot.rc('legend', fontsize = 10)
        pyplot.rc('figure', figsize=(6,3))

        pyplot.plot(x, y, 'o-',
                    label='Entropy',
                    alpha=0.8,
                    markeredgewidth=0,
                    color='blue',
                    zorder=10)
        pyplot.plot(x, y2, 's-',
                    label='MI sum',
                    alpha=0.8,
                    markeredgewidth=0,
                    color='red',
                    zorder=11)
        pyplot.xlabel('NC score')
        pyplot.ylabel('(bits)')

        #pyplot.legend(['Entropy', 'MI sum'])
        leg = pyplot.legend(fancybox=True, loc=4)
        leg.get_frame().set_alpha(0.6)
        leg.get_frame().set_facecolor('lightgrey')
        leg.set_zorder(13)

        pyplot.grid(color='grey', zorder=0)

        pyplot.subplots_adjust(hspace=0.1, wspace=0.05, bottom=0.12,
                               left=0.07, top=0.97, right=0.98)

        save_figure('%s_clustering_entropy_cr_id_%s_set_id_%s.pdf'
                    % (date, cr_id, set_id))

    if False:
        set_id = 109   # ppod human, mouse
        #cr_id = 381    # jaccard clustering, all
        cr_id = 383    # jaccard clustering, human, mouse

        print("jaccard clustering. set_id: %(set_id)d, cr_id %(cr_id)d" % locals())

        clusclass = cluster_sql.flatcluster( cluster_run_id = cr_id)
        clustering = clusclass.fetch_clusters()
        #annotation, seq_map = pq.fetch_domain_map(set_id = set_id)

    # Full mouse and human set BLAST using 600k
    if False:
        set_id = 109 # ppod human, mouse
        cr_id = 474  # Clustering using e_value

        csql = cluster_sql.hcluster( cluster_run_id = cr_id,
                                     cacheq=True)

        # E-values 1e7 .. 1 (decades), then 1e-5 .. 1e-200 (steps of 1e-5).
        #e_range = [1E-200, 1E-100, 1E-50, 1E-25, 1E-10, 1E-5, 1E-3]
        e_range = [10**x for x in reversed(range(8))] + \
                  [10**(-5*x) for x in range(1,41)]

        annotation, seq_map = pq.fetch_domain_map(set_id = set_id)

        entropy, mi_sum = threshold_range_entropy_e_value(
            e_range, annotation, csql, set_id)

        x, y = zip(*sorted(entropy.items()))
        x, y2 = zip(*sorted(mi_sum.items()))
        pyplot.plot(x, y, 'o-', alpha=0.5)
        pyplot.plot(x, y2, 's-', alpha=0.5)
        # Reversed axis: larger (less stringent) E-values on the left.
        pyplot.xlim(max(x), min(x))
        pyplot.xscale('log')
        pyplot.xlabel('E-value')
        pyplot.ylabel('(bits)')
        pyplot.legend(['Entropy', 'MI sum'])
        pyplot.grid(color='grey')

        # Write the figure out; otherwise it is discarded at script exit.
        save_figure('%s_clustering_entropy_e_value_cr_id_%s_set_id_%s.pdf'
                    % (date, cr_id, set_id))
