#!/usr/bin/env python

# Jacob Joseph
# 12 July 2008

# Plot cluster statistics

import cluster_stat
from matplotlib import pyplot
import cPickle, time, numpy, os

from matplotlib import scale as mscale
from JJutil.mpl_customscale import ExpScale



class plot(object):
    """Build and plot clustering-quality statistics.

    ``stat_dict`` is nested as
    ``{family_set: {cl_method: {stype: {param: {adjust: stats}}}}}`` where
    ``param`` is the x-axis value of a clustering run (cut distance for
    hierarchical methods, score threshold for 'spici', the '-d'/'-g' flag
    value for 'spici-d'/'spici-g', ``None`` for 'reference'), ``adjust`` is
    the flag passed to ``set_randomize`` and ``stats`` is the mapping
    returned by ``get_overall_stats``, indexed by statistic name (e.g. 'F').
    """

    # Class-level defaults, replaced per instance.  ``clusterings`` stays
    # None when the statistics are loaded from the pickle cache.
    stat_dict = None
    clusterings = None

    def __init__(self, cluster_variants, stat_variants, family_sets, set_id=None,
                 use_pickle=False, set_id_filter=None, family_set_name=None):
        """Load the statistics dictionary from cache, or compute and cache it.

        cluster_variants: {cl_method: {stype: {cr_id: params}}}; for
            hierarchical methods ``params`` is the iterable of cut distances.
        stat_variants, set_id_filter: stored on the instance.
        family_sets: {name: family_set}; each value is passed to
            ``get_overall_stats(family_set=...)``.
        set_id, family_set_name: forwarded to the cluster_stat constructors
            and used to name the pickle cache file.
        use_pickle: when True and the cache file exists, load it instead of
            recomputing (``self.clusterings`` then stays None).
        """
        self.cluster_variants = cluster_variants
        self.stat_variants = stat_variants
        self.family_sets = family_sets
        self.set_id = set_id
        self.set_id_filter = set_id_filter
        self.family_set_name = family_set_name

        # %s (not %d) so the default set_id=None does not raise; integer ids
        # produce the same file name as before.
        fname = os.path.expandvars(
            '$HOME/tmp/pickles/cluster_plot_statdict_%s_%s.pickle'
            % (self.set_id, self.family_set_name))

        if use_pickle and os.path.exists(fname):
            # NOTE: unpickling can execute arbitrary code; only load caches
            # written by this class.
            f = open(fname, 'rb')
            try:
                self.stat_dict = cPickle.load(f)
            finally:
                f.close()
        else:
            self.build_stat_dict()
            f = open(fname, 'wb')
            try:
                cPickle.dump(self.stat_dict, f)
            finally:
                f.close()

        mscale.register_scale(ExpScale)

    def build_stat_dict(self):
        """Compute overall statistics for every clustering run and family set.

        Populates ``self.stat_dict`` (layout in the class docstring) and keeps
        one statistics object per cluster-run id in ``self.clusterings``.

        Raises NotImplementedError for 'mcl' variants, whose parameter
        handling has not been ported to list-valued cluster_params.
        """
        self.stat_dict = {}
        self.clusterings = {}
        sd = self.stat_dict

        # Only unrandomized statistics are computed: the randomized
        # (adjust=True) baseline must not be used with a set_filter.
        adjust_values = (False,)
        flag_for_method = {'spici-d': '-d', 'spici-g': '-g'}

        print("***Building plot statistics dictionary")
        print("   %s" % time.strftime('%H:%M:%S (%d %h %Y)'))

        # Pre-create the nested dicts so every (family, method, stype) exists.
        for family_set in self.family_sets:
            sd[family_set] = {}
            for cl_method in self.cluster_variants:
                sd[family_set][cl_method] = {}
                for stype in self.cluster_variants[cl_method]:
                    sd[family_set][cl_method][stype] = {}

        for cl_method in self.cluster_variants:
            for stype in self.cluster_variants[cl_method]:
                print("   %s %s" % (cl_method, stype))
                runs = self.cluster_variants[cl_method][stype]
                for (cr_id, params) in runs.items():

                    if cl_method == 'mcl':
                        # The inflation used to be read from a dict-valued
                        # cluster_params ('main_inflation'); params are now a
                        # list and this branch has not been ported.
                        raise NotImplementedError(
                            "params are now stored as a list. fix this code")

                    elif cl_method == 'spici':
                        clust = self._clustering(cr_id, cluster_stat.flatclust_stat)
                        for adjust in adjust_values:
                            clust.set_randomize(adjust)
                            score_threshold = clust.csql.score_threshold()
                            if score_threshold is None:
                                score_threshold = 0
                            self._store_overall_stats(cl_method, stype,
                                                      score_threshold, adjust, clust)

                    elif cl_method in flag_for_method:
                        clust = self._clustering(cr_id, cluster_stat.flatclust_stat)
                        for adjust in adjust_values:
                            clust.set_randomize(adjust)
                            param = self._flag_param(clust.csql.cluster_params,
                                                     flag_for_method[cl_method])
                            self._store_overall_stats(cl_method, stype,
                                                      param, adjust, clust)

                    elif cl_method == 'reference':
                        # Reference clusterings (e.g. PPOD) have no parameter.
                        clust = self._clustering(cr_id, cluster_stat.flatclust_stat)
                        self._store_overall_stats(cl_method, stype, None, False, clust)

                    else:
                        # Hierarchical clusterings: one entry per cut distance.
                        clust = self._clustering(cr_id, cluster_stat.hclust_stat)
                        for dist in params:
                            print("fetching cut (distance) %s %s" % (cr_id, dist))
                            clust.fetch_cut(dist)
                            for adjust in adjust_values:
                                clust.set_randomize(adjust)
                                self._store_overall_stats(cl_method, stype,
                                                          dist, adjust, clust)

        print("   %s" % time.strftime('%H:%M:%S (%d %h %Y)'))

    def _clustering(self, cr_id, factory):
        """Return the cached statistics object for cr_id, creating it with
        factory(cr_id, set_id=..., family_set_name=...) on first use."""
        if cr_id not in self.clusterings:
            self.clusterings[cr_id] = factory(
                cr_id, set_id=self.set_id, family_set_name=self.family_set_name)
        return self.clusterings[cr_id]

    def _store_overall_stats(self, cl_method, stype, param, adjust, clust):
        """Store clust.get_overall_stats() for every family set (in sorted
        order) at stat_dict[family_set][cl_method][stype][param][adjust]."""
        sd = self.stat_dict
        for family_set in sorted(self.family_sets):
            by_adjust = sd[family_set][cl_method][stype].setdefault(param, {})
            by_adjust[adjust] = clust.get_overall_stats(
                family_set=self.family_sets[family_set])

    @staticmethod
    def _flag_param(cluster_params, flag, default=0.5):
        """Return float(value of flag) from a stored parameter list, else default.

        cluster_params is the repr of a flat [flag, value, flag, value, ...]
        list.
        """
        # NOTE(review): eval() of database content; switch to a safe literal
        # parser if cluster_params can ever come from an untrusted source.
        params = eval(cluster_params)
        param_dict = dict(zip(params[::2], params[1::2]))
        if flag in param_dict:
            return float(param_dict[flag])
        return default

    def _series(self, family_set, cl_method, stype, params, stat, adjust):
        """Return the values of statistic `stat` (randomization `adjust`) for
        each parameter in params, in order."""
        by_param = self.stat_dict[family_set][cl_method][stype]
        return [by_param[p][adjust][stat] for p in params]

    @staticmethod
    def _tick_labels(stype, params, xtick_formatters=None):
        """Tick labels: xtick_formatters[stype](p) if provided, else '$p$'."""
        if xtick_formatters is not None and stype in xtick_formatters:
            return [xtick_formatters[stype](p) for p in params]
        return [r'$%s$' % p for p in params]

    @staticmethod
    def _tick_indices(n_params, nticks=20):
        """Indices of at most ~nticks evenly spaced ticks among n_params values."""
        step = max(1, -(-n_params // nticks))  # ceil division, never 0
        return list(range(0, n_params, step))

    @staticmethod
    def _format_evalue(value):
        """Format an E-value as a mathtext label with two significant digits,
        e.g. 1e-50 -> $10^{-50}$, 2.5e-10 -> $2.5\\times10^{-10}$."""
        text = '%.2g' % value
        if 'e' not in text:
            return r'$%g$' % float(text)
        (mantissa, exponent) = text.split('e')
        mantissa = float(mantissa)
        exponent = int(exponent)
        if mantissa == 1:
            return r'$10^{%d}$' % exponent
        return r'$%g\times10^{%d}$' % (mantissa, exponent)

    @staticmethod
    def _grid_cols(n_plots, num_rows):
        """Subplot columns for n_plots over num_rows rows.

        Keeps the historical (n_plots + 1) // num_rows layout but never
        returns fewer columns than are needed to fit every plot.
        """
        return max((n_plots + 1) // num_rows, -(-n_plots // num_rows))

    def draw_plotset(self, plot_list, stats, family_set, symbols=None,
                     filename=None, dpi=100, draw_legend=True,
                     xtick_formatters=None, num_rows=2, num_cols=3):
        """Draw one draw_plot subplot per (cl_method, stype) in plot_list.

        The figure is saved to figures/<filename> when filename is given,
        otherwise shown.  At most num_rows * num_cols plots fit the grid.
        """
        pyplot.suptitle('Family set: %s' % family_set)

        for (i, (cl_method, stype)) in enumerate(plot_list):
            pyplot.subplot(num_rows, num_cols, i + 1)
            # finalize=False: save/show the whole figure once, below, rather
            # than after each subplot.
            self.draw_plot(cl_method, stype, stats, family_set, symbols=symbols,
                           draw_legend=draw_legend,
                           xtick_formatters=xtick_formatters,
                           finalize=False)

        pyplot.subplots_adjust(hspace=0.3, bottom=0.05, left=0.02, top=0.92, right=0.98)

        if filename is not None:
            pyplot.savefig('figures/%s' % filename, dpi=dpi)
            pyplot.close()
        else:
            pyplot.show()

    def draw_plot(self, cl_method, stype, stats, family_set, symbols=None,
                  draw_legend=True, standalone=False, filename=None, dpi=None,
                  xlog=False, xtick_formatters=None,
                  subplots_adjust=None, finalize=True):
        """Plot statistics against the clustering parameter for one variant.

        stats: sequence of (stat_name, adjust) pairs, one line per pair.
        standalone: label lines by statistic name only and omit the title.
        xlog: accepted for backward compatibility; currently unused.
        finalize: when True, save (standalone with a filename) or show the
            figure afterwards; draw_plotset passes False.
        """
        if symbols is None:
            symbols = ['b', 'g', 'r', 'c', 'y', 'm', 'k']

        params = sorted(self.stat_dict[family_set][cl_method][stype].keys())

        for (i, (stat, adjust)) in enumerate(stats):
            series = self._series(family_set, cl_method, stype, params, stat, adjust)
            if standalone:
                label = stat
            else:
                label = '%s_%s' % (stat, adjust)
            pyplot.plot(params, series, symbols[i], label=label, alpha=0.7)

        # Reverse the axis for any linkage method.  These are distances.
        if 'linkage' in cl_method:
            pyplot.xlim(pyplot.xlim()[::-1])

        if cl_method == 'mcl':
            title = "MCL (%s)" % stype
            # MCL's parameter is the inflation whatever the score type.
            xlabel = "Inflation"
        else:
            title = cl_method
            xlabel = {'nc_score': "NC score",
                      'e_value': "E-value",
                      'bit_score': "BIT score"}.get(stype, stype)

        if not standalone:
            pyplot.title(title.replace('_', ' '))
        pyplot.xlabel(xlabel.replace('_', ' '))

        params_str = self._tick_labels(stype, params, xtick_formatters)
        indices = self._tick_indices(len(params))
        pyplot.xticks(numpy.array(params)[indices], numpy.array(params_str)[indices],
                      rotation='vertical')

        pyplot.axis('tight')
        pyplot.yticks([a / 10.0 for a in range(11)])
        pyplot.grid(color='gray', alpha=0.5)
        pyplot.ylim(-0.01, 1.01)

        if draw_legend:
            l = pyplot.legend(pad=0, loc='best')
            l.legendPatch.set(fill=True, facecolor='gray', edgecolor='gray', alpha=0.5)

        if not finalize:
            return
        if standalone and filename is not None:
            if subplots_adjust is None:
                pyplot.subplots_adjust(left=0.05, right=0.95, top=0.95, bottom=0.1)
            else:
                pyplot.subplots_adjust(**subplots_adjust)
            pyplot.savefig('figures/%s' % filename, dpi=dpi)
            pyplot.close()
        else:
            pyplot.show()

    def calc_max_stat(self, stat_key, family_keys, adjust=False):
        """Find, per variant, the parameter maximizing stat_key.

        Returns {family_key: {cl_method: {stype: (param, value)}}} using the
        statistics computed with randomization flag `adjust`; the value is
        None for a variant with no parameters.  Ties keep the first seen.
        """
        stat_dict_max = {}
        sd = self.stat_dict
        for family_set in family_keys:
            per_method = stat_dict_max[family_set] = {}
            for cl_method in sd[family_set]:
                per_stype = per_method[cl_method] = {}
                for stype in sd[family_set][cl_method]:
                    best = None
                    for (position, by_adjust) in sd[family_set][cl_method][stype].items():
                        value = by_adjust[adjust][stat_key]
                        if best is None or best[1] < value:
                            best = (position, value)
                    per_stype[stype] = best
        return stat_dict_max

    def draw_comp_barplot(self, plot_list, stat_key, family_keys, filename=None,
                          dpi=100, draw_legend=True, symbols=None, adjust=False):
        """Horizontal bar chart of each variant's maximum stat_key per family.

        One group of bars per family key (first at the top), one bar per
        (cl_method, stype) in plot_list.  adjust selects which randomization
        flag's statistics are maximized.
        """
        if symbols is None:
            symbols = ['b', 'g', 'r', 'c', 'y', 'm', 'k', '#FF00FF']

        pyplot.suptitle('Maximum F')

        stat_dict_max = self.calc_max_stat(stat_key, family_keys, adjust=adjust)

        bar_width = 1.0 / (len(plot_list) + 1)
        label_offset = bar_width * len(plot_list) / 2
        n_groups = len(family_keys)
        pyplot.yticks([g + label_offset for g in range(n_groups)], family_keys)
        pyplot.ylim((n_groups, 0))  # inverted: first family at the top

        pyplot.xticks([a / 10.0 for a in range(10)] + [1])
        pyplot.xlim((0, 1.05))
        pyplot.grid(color='gray', alpha=0.5)

        for (i, family_key) in enumerate(family_keys):
            for (j, (cl_method, stype)) in enumerate(plot_list):
                (position, statistic) = stat_dict_max[family_key][cl_method][stype]
                pyplot.barh(i + j * bar_width, statistic, bar_width,
                            color=symbols[j],
                            label=cl_method + ', ' + stype)

            # Build the legend from the first group only, so that each
            # variant appears once.
            if i == 0 and draw_legend:
                l = pyplot.legend(pad=0)
                l.legendPatch.set(fill=True, facecolor='gray', edgecolor='gray', alpha=0.7)

        if filename is not None:
            pyplot.savefig('figures/%s' % filename, dpi=dpi)
            pyplot.close()
        else:
            pyplot.show()

    def draw_heatmap_set(self, plot_list, stat_key, family_keys,
                         filename=None, dpi=100, num_cols=None, num_rows=None,
                         adjust=False, draw_colorbar=True,
                         subplots_adjust=None,
                         xtick_formatters=None):
        """Draw one draw_heatmap subplot per (cl_method, stype) in plot_list.

        All heatmaps share a fixed 0-1 colour scale.  The figure is saved to
        figures/<filename> when filename is given; otherwise interactive mode
        is turned on and the figure is left open.
        """
        if filename is None:
            pyplot.ion()  # turn on figure updates

        if num_rows is None:
            num_rows = 1
        if num_cols is None:
            num_cols = self._grid_cols(len(plot_list), num_rows)

        norm = pyplot.Normalize(vmin=0, vmax=1)

        for (i, (cl_method, stype)) in enumerate(plot_list):
            ax = pyplot.subplot(num_rows, num_cols, i + 1)
            # Only the first column of the grid carries family labels.
            self.draw_heatmap(ax, family_keys, cl_method, stype, stat_key,
                              label_y=(i % num_cols == 0), adjust=adjust,
                              norm=norm,
                              xtick_formatters=xtick_formatters)

        if subplots_adjust is None:
            pyplot.subplots_adjust(hspace=0.15, wspace=0.1, bottom=0.05,
                                   left=0.06, top=0.95, right=0.98)
        else:
            pyplot.subplots_adjust(**subplots_adjust)

        if draw_colorbar:
            cb = pyplot.colorbar(ticks=[a / 10.0 for a in range(11)],
                                 boundaries=[a / 500.0 for a in range(501)],
                                 cmap=pyplot.cm.jet,
                                 norm=norm)
            cb.set_label(stat_key)

        if filename is not None:
            pyplot.savefig('figures/%s' % filename, dpi=dpi)
            pyplot.close()

    def draw_heatmap(self, ax, family_keys, cl_method, stype, stat_key, label_y=True,
                     adjust=False, norm=None, xtick_formatters=None):
        """Draw a (family key x parameter) heatmap of stat_key on the current axes.

        Row i is family_keys[i] (first at the top).  Column j is the j-th
        parameter of family_keys[0], sorted and reversed for linkage methods.
        ax is accepted for interface compatibility; drawing goes through pyplot.
        """
        if cl_method == 'mcl':
            title = "MCL (%s)" % stype
            xlabel = "Inflation"
        else:
            title = cl_method.replace('_', ' ')
            if stype == 'nc_score':
                xlabel = "NC score"
            elif stype == 'bit_score':
                xlabel = "BIT score"
            else:
                xlabel = stype.replace('_', ' ')

        pyplot.title(title)
        pyplot.xlabel(xlabel)

        # The x-axis parameters are assumed shared among all families.
        params = sorted(self.stat_dict[family_keys[0]][cl_method][stype].keys())
        if 'linkage' in cl_method:
            params.reverse()  # distances: reverse the axis

        matrix = numpy.array([
            self._series(family_key, cl_method, stype, params, stat_key, adjust)
            for family_key in family_keys])
        pyplot.pcolor(matrix, norm=norm, cmap=pyplot.cm.jet)

        if 'linkage' in cl_method and stype == 'e_value':
            params_str = [self._format_evalue(p) for p in params]
        else:
            params_str = self._tick_labels(stype, params, xtick_formatters)

        # pcolor puts column j over [j, j+1] and row i over [i, i+1]
        # regardless of the parameter values.  Ticks therefore go at
        # cell centres, and limits are set after pcolor so its
        # autoscaling cannot undo them.
        indices = self._tick_indices(len(params))
        pyplot.xticks([j + 0.5 for j in indices], [params_str[j] for j in indices],
                      rotation='vertical')
        pyplot.xlim(0, len(params))

        n_rows = len(family_keys)
        if label_y:
            pyplot.yticks([i + 0.5 for i in range(n_rows)], family_keys)
        else:
            pyplot.yticks([], [])
        pyplot.ylim((n_rows, 0))

    def build_pcl(self, family_keys, params, matrix):
        """Return the matrix as PCL text, for use in Java TreeView.

        Format: http://yfgdb.princeton.edu/pcl_format.txt.  Row i is
        family_keys[i]; column j is params[j], written as '%0.4f'.
        """
        lines = ["\t".join(["NAME", "DESC", "GWEIGHT"] + ["%s" % p for p in params]),
                 "\t".join(["EWEIGHT", "", ""] + ["1"] * len(params))]
        for (i, family_key) in enumerate(family_keys):
            # name, description, gene weight, then one value per parameter
            cells = ["%s" % family_key, "%s" % family_key, "1"]
            cells.extend('%0.4f' % matrix[i][j] for j in range(len(params)))
            lines.append("\t".join(cells))
        return "".join(line + "\n" for line in lines)

    def draw_comp_plotset(self, plot_list, stats, family_keys, symbols=None,
                          filename=None, dpi=100, draw_legend=True,
                          num_rows=2, num_cols=3):
        """Draw one draw_comp_plot subplot per (cl_method, stype) in plot_list.

        Saved to figures/<filename> when filename is given, otherwise shown.
        """
        pyplot.suptitle('F comparison')

        for (i, (cl_method, stype)) in enumerate(plot_list):
            pyplot.subplot(num_rows, num_cols, i + 1)
            self.draw_comp_plot(cl_method, stype, stats, family_keys, symbols=symbols,
                                draw_legend=draw_legend)

        pyplot.subplots_adjust(hspace=0.3, bottom=0.05, left=0.02, top=0.92, right=0.98)

        if filename is not None:
            pyplot.savefig('figures/%s' % filename, dpi=dpi)
            pyplot.close()
        else:
            pyplot.show()

    def draw_comp_plot(self, cl_method, stype, stats, family_keys,
                       symbols=None, draw_legend=True):
        """Overlay one line per (family key, (stat, adjust)) for one variant.

        Line styles cycle through symbols.
        """
        if symbols is None:
            symbols = ['b', 'g', 'r', 'c', 'y', 'm', 'k',
                       'b:', 'g:', 'r:', 'c:', 'y:', 'm:', 'k:',
                       'b--', 'g--', 'r--', 'c--', 'y--', 'm--', 'k--']

        # parameters (x-axis) should be shared among all family sets
        params = sorted(self.stat_dict[family_keys[0]][cl_method][stype].keys())

        symbol_ind = 0
        for family_set in family_keys:
            for (stat, adjust) in stats:
                series = self._series(family_set, cl_method, stype, params, stat, adjust)
                pyplot.plot(params, series, symbols[symbol_ind],
                            label=family_set + '-' + stat)
                symbol_ind = (symbol_ind + 1) % len(symbols)

        if cl_method in ('single_linkage', 'complete_linkage') and stype == 'e_value':
            # reversed log axis
            pyplot.xscale('log', basex=10)
            (left, right) = pyplot.xlim()
            pyplot.xlim(right, left)

        if cl_method == 'mcl':
            title = "MCL (%s)" % stype
            xlabel = "Inflation"
        else:
            title = cl_method
            xlabel = stype

        pyplot.title(title)
        pyplot.xlabel(xlabel)

        pyplot.axis('tight')
        pyplot.yticks([a / 10.0 for a in range(11)])
        pyplot.grid(color='gray', alpha=0.5)
        pyplot.ylim(-0.01, 1.01)

        if draw_legend:
            l = pyplot.legend(pad=0)
            l.legendPatch.set(fill=True, facecolor='white', edgecolor='gray')

    def _find_variant(self, cr_id):
        """Return (cl_method, stype) of the first cluster variant listing cr_id.

        Raises ValueError if no variant contains cr_id.
        """
        for cl_method in self.cluster_variants:
            for stype in self.cluster_variants[cl_method]:
                if cr_id in self.cluster_variants[cl_method][stype]:
                    return (cl_method, stype)
        raise ValueError("cluster run %s not found in cluster_variants" % (cr_id,))

    def draw_pfam_scatter(self, cr_id, ptype='identity', thresh=None, filename=None):
        """Scatter per-domain Pfam statistics for one clustering run.

        ptype is either 'identity' or 'promiscuity'.  For hierarchical runs
        thresh selects the cut (converted to a distance for nc_score).
        Requires the clustering objects built by build_stat_dict (they are
        not available when statistics were loaded from the pickle).
        """
        # column of pfam_clust_cnt()'s count array, and plot colour
        columns = {'identity': (1, 'b'), 'promiscuity': (2, 'g')}
        if ptype not in columns:
            raise ValueError("draw_pfam_scatter: unknown type: %s" % ptype)

        (cl_method, stype) = self._find_variant(cr_id)
        if not self.clusterings or cr_id not in self.clusterings:
            raise KeyError("draw_pfam_scatter: no clustering object for %s "
                           "(were statistics loaded from a pickle?)" % (cr_id,))
        c = self.clusterings[cr_id]

        if isinstance(c, cluster_stat.hclust_stat):
            if thresh is None:
                raise ValueError("draw_pfam_scatter: parameter 'thresh' required for hclust")

            # nc_score must be made a distance for fetch_cut()
            if stype == 'nc_score':
                dist = 1 - thresh
            else:
                dist = thresh

            title = "%s: %s, dist: %s" % (cl_method, stype, thresh)
            c.fetch_cut(dist)
        else:
            # NOTE(review): eval() of database content, read as a dict.
            # build_stat_dict's mcl branch notes params are now stored as a
            # list, so this path probably needs porting as well.
            inflation = eval(c.csql.cluster_params)['main_inflation']
            title = "%s: %s, inflation: %f" % (cl_method, stype, inflation)

        (domain_list, cnt_array) = c.pfam_clust_cnt()

        (column, symbol) = columns[ptype]
        pyplot.scatter(cnt_array[:, 0], cnt_array[:, column], c=symbol, marker='o')

        pyplot.title(title)
        pyplot.xlabel('# Clusters with domain')
        pyplot.ylabel(ptype)

        pyplot.axis('tight')
        pyplot.yticks([a / 10.0 for a in range(11)])
        pyplot.grid(color='gray', alpha=0.5)
        pyplot.ylim(-0.01, 1.01)

        if filename is not None:
            pyplot.savefig('figures/%s' % filename)
            pyplot.close()
        else:
            pyplot.show()

    def max_F(self, family_set, cl_method, stype, adjust=False):
        """Return the largest unrandomized F over all parameters of a variant.

        With adjust=True, the randomized (adjust=True) F at that same
        parameter is subtracted.  Returns None for a variant without
        parameters when adjust is False.
        """
        by_param = self.stat_dict[family_set][cl_method][stype]
        f_max = None
        param_max = None
        for param in by_param:
            f = by_param[param][False]['F']
            if f_max is None or f > f_max:
                (f_max, param_max) = (f, param)

        if adjust:
            f_max = f_max - by_param[param_max][True]['F']

        return f_max

    def max_F_table(self, adjust=False):
        """Return a text table of max_F per family set for single linkage and
        MCL on e_value and nc_score."""
        s = "==================================================\n"
        s += "                 Maximum F: (adjust=%s)\n" % adjust
        s += "==================================================\n"
        s += "Family       SL_e   SL_nc   MCL_e  MCL_nc\n"
        for family_set in sorted(self.family_sets):
            s += "%8s   %0.4f  %0.4f  %0.4f  %0.4f\n" % (
                family_set,
                self.max_F(family_set, 'single_linkage', 'e_value', adjust),
                self.max_F(family_set, 'single_linkage', 'nc_score', adjust),
                self.max_F(family_set, 'mcl', 'e_value', adjust),
                self.max_F(family_set, 'mcl', 'nc_score', adjust))

        s += "==================================================\n"
        return s
    
    
