#!/usr/bin/env python

# Jacob Joseph
# 6 Oct 2011

# Consider a variety of schemes for differentiating and classifying
# clusters or domains based upon the mutual information of both

import os, cPickle
import numpy
import time
import matplotlib
from matplotlib import pyplot
from DurandDB import pfamq, familyq, blastq
from JJcluster import cluster_sql, describe
#from information import mutualinfo
from mi_plot import mi_plot

#class mi_factor(mutualinfo):
class mi_factor(mi_plot):

    # Per cluster:
    # -------------------
    # cluster size
    # cluster maximum
    # num distinct domains
    # min, mean, median, max domain MI
    # min, mean, median, max domain pSMI
    # max domain MI - pSMI of that domain
    # min, mean, median, max: domain MI - pSMI
    # cluster maximum - best MI
    # cluster maximum - best pSMI


    def __init__(self, clustering, annotation,
                 cluster_run_id = None, clustering_type=None):
        """Set up the per-cluster factor table and load cached matrices.

        clustering, annotation: passed unchanged to mi_plot.__init__.
        cluster_run_id, clustering_type: only consulted by the
          cluster-statistics branch below, which is currently disabled
          ('if False and ...'), so they have no effect at present.

        Side effects: creates a blastq handle (self.bq), builds
        self.factors and self.fact_index, then loads the MI matrices
        and the factor matrix from the hard-coded pickle files named
        below.  The PCA itself is not computed here (the call is
        commented out); calculate_pca() sets self.pca and
        self.projection.
        """
        #mutualinfo.__init__(self, clustering, annotation)
        mi_plot.__init__(self, clustering, annotation)

        self.bq = blastq.blastq()
        
        # Ordered list of (factor name, function of a cluster).  The
        # position of each entry becomes its column in cluster_fact
        # (see fact_index below), so reordering or (un)commenting
        # entries changes the column layout.
        # NOTE(review): a factor matrix pickled with a different
        # factor list would presumably no longer match fact_index --
        # confirm before loading an older pickle.
        self.factors = [
            ('cluster_size', self.clfn_cluster_size),    # cluster size
            #('cluster_max', self.clfn_max_cluster_mi),    # cluster maximum
            ('num_domains', self.clfn_num_domains),       # num distinct domains
            
            ('mean_num_domains', self.clfn_mean_num_domains),  # num distinct domains/sequence
            ('mean_seq_length', self.clfn_mean_sequence_length),

            # min, mean, median, max, and sum of domain MI
            #('min_domain_mi', self.clfn_min_domains_mi), 
            #('mean_domain_mi', self.clfn_mean_domains_mi),
            #('med_domain_mi', self.clfn_med_domains_mi),
            ('max_domain_mi', self.clfn_max_domains_mi),
            #('sum_domain_mi', self.clfn_sum_domains_mi),

            # min, mean, median, max, sum domain pSMI
            #('min_domain_psmi', self.clfn_min_domain_psmi), 
            #('mean_domain_psmi', self.clfn_mean_domain_psmi),
            ('med_domain_psmi', self.clfn_med_domain_psmi),
            #('max_domain_psmi', self.clfn_max_domain_psmi),
            #('sum_domain_psmi', self.clfn_sum_domain_psmi),

            # min, mean, median, max, sum:  domain MI - pSMI
            #('min_domain_mi_psmi', self.clfn_min_domain_mi_psmi), 
            #('mean_domain_mi_psmi', self.clfn_mean_domain_mi_psmi),
            ('med_domain_mi_psmi', self.clfn_med_domain_mi_psmi),
            #('max_domain_mi_psmi', self.clfn_max_domain_mi_psmi),
            #('sum_domain_mi_psmi', self.clfn_sum_domain_mi_psmi),

            # min, mean, median, max, sum: Cluster maximum - domain pSMI in that cluster
            #('min_clmax_less_psmi', self.clfn_min_clmax_less_psmi), 
            #('mean_clmax_less_psmi', self.clfn_mean_clmax_less_psmi),
            ('med_clmax_less_psmi', self.clfn_med_clmax_less_psmi),
            #('max_clmax_less_psmi', self.clfn_max_clmax_less_psmi),
            #('sum_clmax_less_psmi', self.clfn_sum_clmax_less_psmi),

            # min, mean, median, max, sum: Cluster maximum - domain MI
            #('min_clmax_less_mi', self.clfn_min_clmax_less_mi),
            #('mean_clmax_less_mi', self.clfn_mean_clmax_less_mi),
            ('med_clmax_less_mi', self.clfn_med_clmax_less_mi),
            #('max_clmax_less_mi', self.clfn_max_clmax_less_mi),
            #('sum_clmax_less_mi', self.clfn_sum_clmax_less_mi),
            
            # max domain MI - pSMI of that domain?
            ]               

        # Use additional statistics if a cluster run was provided
        # (disabled: the leading 'False' makes this branch unreachable,
        # so self.cluster_desc is never set and clfn_cl_stat cannot be
        # used without it)
        if False and cluster_run_id is not None and clustering_type is not None:

            self.cluster_desc = describe.describe(cluster_run_id=cluster_run_id,
                                                  clustering_type=clustering_type,
                                                  cacheq=True)

            # NOTE(review): the two 'stdev_*' factors below pass
            # statkey='mean', the same key as the 'mean_*' factors.
            # This looks like a copy-paste slip; confirm the key that
            # cluster_stats() uses for the standard deviation.
            self.factors.extend([
                # Sequence similarity of the cluster
                ('min_bit_score', lambda cluster: self.clfn_cl_stat(cluster, statkey='min', stype='bit_score')),
                ('max_bit_score', lambda cluster: self.clfn_cl_stat(cluster, statkey='max', stype='bit_score')),
                ('mean_bit_score',lambda cluster: self.clfn_cl_stat(cluster, statkey='mean', stype='bit_score')),
                ('stdev_bit_score', lambda cluster: self.clfn_cl_stat(cluster, statkey='mean', stype='bit_score')),
                ('density_bit_score', lambda cluster: self.clfn_cl_stat(cluster, statkey='density', stype='bit_score')),
                ('frac_edges_bit_score', lambda cluster: self.clfn_cl_stat(cluster, statkey='frac_edges', stype='bit_score')),

                # Score of the clustering (usually NC, but could be anything)
                ('min_score', lambda cluster: self.clfn_cl_stat(cluster, statkey='min')),
                ('max_score', lambda cluster: self.clfn_cl_stat(cluster, statkey='max')),
                ('mean_score',lambda cluster: self.clfn_cl_stat(cluster, statkey='mean')),
                ('stdev_score', lambda cluster: self.clfn_cl_stat(cluster, statkey='mean')),
                ('density', lambda cluster: self.clfn_cl_stat(cluster, statkey='density')),
                ('frac_edges', lambda cluster: self.clfn_cl_stat(cluster, statkey='frac_edges')),
                ])


        # factor name -> column index, in the order of self.factors
        self.fact_index = dict( (f[0], i) for i,f in enumerate(self.factors))

        # MI matrices: of the three alternatives below (compute,
        # compute and pickle, load a pickle) only the load is active.
        #mi_pickle = "$HOME/tmp/pickles/pca_analysis_109.pickle"
        #mi_pickle = "$HOME/tmp/pickles/pca_analysis_492_109.pickle"
        #mi_pickle = "$HOME/tmp/pickles/pca_analysis_493_112.pickle"
        #mi_pickle = "$HOME/tmp/pickles/pca_analysis_494_115.pickle"
        #mi_pickle = "$HOME/tmp/pickles/pca_analysis_114.pickle"
        mi_pickle = "$HOME/tmp/pickles/pca_analysis_117.pickle"
        #mi_pickle = "$HOME/tmp/pickles/pca_analysis_111.pickle"
        #mi_pickle = "$HOME/tmp/pickles/pca_analysis_116.pickle"
        #self.fetch_mi_matricies()
        #self.fetch_mi_matricies_pickle(mi_pickle)
        self.fetch_mi_matricies_unpickle(mi_pickle)
        
        # Factor matrix: likewise, only the load from pickle is active.
        #factor_pickle = "$HOME/tmp/pickles/pca_calculate_factor_matrix_109.pickle"
        #factor_pickle = "$HOME/tmp/pickles/pca_calculate_factor_matrix_492_109.pickle"
        #factor_pickle = "$HOME/tmp/pickles/pca_calculate_factor_matrix_493_112.pickle"
        #factor_pickle = "$HOME/tmp/pickles/pca_calculate_factor_matrix_494_115.pickle"
        #factor_pickle = "$HOME/tmp/pickles/pca_calculate_factor_matrix_114.pickle"
        factor_pickle = "$HOME/tmp/pickles/pca_calculate_factor_matrix_117.pickle"
        #factor_pickle = "$HOME/tmp/pickles/pca_calculate_factor_matrix_111.pickle"
        #factor_pickle = "$HOME/tmp/pickles/pca_calculate_factor_matrix_116.pickle"
        #self.calculate_factor_matrix()
        #self.calculate_factor_matrix_pickle(factor_pickle)
        self.calculate_factor_matrix_unpickle(factor_pickle)
        
        #self.calculate_pca()
        
        
    def clfn_cluster_size(self, cluster):
        return len(self.cluster_to_seq[cluster])

    def clfn_num_domains(self, cluster):
        return len(self.annotations_in_cluster(cluster))

    def clfn_mean_num_domains(self, cluster):
        return ( float(len(self.annotations_in_cluster(cluster))) /
                 len(self.cluster_to_seq[cluster]))

    def clfn_mean_sequence_length(self, cluster):
        # Cache successive calls for the same cluster
        if ('_cache_mean_sequence_length' in self.__dict__
            and self._cache_mean_sequence_length[0] == cluster):
            return self._cache_mean_sequence_length[1]

        seqs = tuple(self.cluster_to_seq[cluster])

        q = """SELECT AVG(length)
        FROM prot_seq_str
        JOIN prot_seq USING (seq_str_id)
        WHERE seq_id in %(seqs)s"""

        avg_length = float(self.bq.dbw.fetchsingle(q, locals()))
        self._cache_mean_sequence_length = (cluster, avg_length)
        return avg_length


    def clfn_cl_stat(self, cluster, statkey, stype=None):
        # Cache successive calls for the same cluster
        if ('_cache_cl_stat' in self.__dict__
            and self._cache_cl_stat[0] == (cluster, stype)):
            cl_stat = self._cache_cl_stat[1]
        else:
            cl_stat = self.cluster_desc.cluster_stats( cluster, stype=stype)
            self._cache_cl_stat = ((cluster, stype), cl_stat)

        # This function returns negative values when missing.  Replace
        # these with numpy.nan.
        return cl_stat[statkey] if cl_stat[statkey] >= 0 else numpy.nan
        

    def cluster_domain_indicies(self, cluster):
        # Cache successive calls for the same cluster
        if ('_cache_cluster_domain_indicies' in self.__dict__
            and self._cache_cluster_domain_indicies[0] == cluster):
            return self._cache_cluster_domain_indicies[1]

        
        domains = self.annotations_in_cluster(cluster)
        domains_j = numpy.array(
            [self.ann_index[domain] for domain in domains],
            dtype = numpy.int)

        #print cluster, domains, domains_j

        self._cache_cluster_domain_indicies = (cluster, domains_j)
            
        return domains_j
        
    def clfn_min_domains_mi(self, cluster):
        a = self.domain_mi_array[ self.cluster_domain_indicies(cluster)]
        if len(a) == 0: return numpy.nan
        return numpy.min(a)

    def clfn_mean_domains_mi(self, cluster):
        a = self.domain_mi_array[ self.cluster_domain_indicies(cluster)]
        if len(a) == 0: return numpy.nan
        return numpy.mean(a)

    def clfn_med_domains_mi(self, cluster):
        a = self.domain_mi_array[ self.cluster_domain_indicies(cluster)]
        if len(a) == 0: return numpy.nan
        return numpy.median(a)

    def clfn_max_domains_mi(self, cluster):
        a = self.domain_mi_array[ self.cluster_domain_indicies(cluster)]
        if len(a) == 0: return numpy.nan
        return numpy.max(a)

    def clfn_sum_domains_mi(self, cluster):
        a = self.domain_mi_array[ self.cluster_domain_indicies(cluster)]
        if len(a) == 0: return numpy.nan
        return numpy.sum(a)

    def clfn_min_domain_psmi(self, cluster):
        i = self.cl_index[cluster]
        a = self.pos_mi_matrix[i][self.cluster_domain_indicies(cluster)]
        if len(a) == 0: return numpy.nan
        return numpy.min(a)

    def clfn_mean_domain_psmi(self, cluster):
        i = self.cl_index[cluster]
        a = self.pos_mi_matrix[i][self.cluster_domain_indicies(cluster)]
        if len(a) == 0: return numpy.nan
        return numpy.mean(a)

    def clfn_med_domain_psmi(self, cluster):
        i = self.cl_index[cluster]
        a = self.pos_mi_matrix[i][self.cluster_domain_indicies(cluster)]
        if len(a) == 0: return numpy.nan
        return numpy.median(a)

    def clfn_max_domain_psmi(self, cluster):
        i = self.cl_index[cluster]
        a = self.pos_mi_matrix[i][self.cluster_domain_indicies(cluster)]
        if len(a) == 0: return numpy.nan
        return numpy.max(a)

    def clfn_sum_domain_psmi(self, cluster):
        i = self.cl_index[cluster]
        a = self.pos_mi_matrix[i][self.cluster_domain_indicies(cluster)]
        if len(a) == 0: return numpy.nan
        return numpy.sum(a)

    def clfn_min_domain_mi_psmi(self, cluster):
        i = self.cl_index[cluster]
        domains_j = self.cluster_domain_indicies(cluster)
        if len(domains_j) == 0: return numpy.nan
        return numpy.min( self.domain_mi_array[domains_j] -
                          self.pos_mi_matrix[i][domains_j] )
                                               
    def clfn_mean_domain_mi_psmi(self, cluster):
        i = self.cl_index[cluster]
        domains_j = self.cluster_domain_indicies(cluster)
        if len(domains_j) == 0: return numpy.nan
        return numpy.mean( self.domain_mi_array[domains_j] -
                           self.pos_mi_matrix[i][domains_j] )

    def clfn_med_domain_mi_psmi(self, cluster):
        i = self.cl_index[cluster]
        domains_j = self.cluster_domain_indicies(cluster)
        if len(domains_j) == 0: return numpy.nan
        return numpy.median( self.domain_mi_array[domains_j] -
                             self.pos_mi_matrix[i][domains_j] )

    def clfn_max_domain_mi_psmi(self, cluster):
        i = self.cl_index[cluster]
        domains_j = self.cluster_domain_indicies(cluster)
        if len(domains_j) == 0: return numpy.nan
        return numpy.max( self.domain_mi_array[domains_j] -
                          self.pos_mi_matrix[i][domains_j] )

    def clfn_sum_domain_mi_psmi(self, cluster):
        i = self.cl_index[cluster]
        domains_j = self.cluster_domain_indicies(cluster)
        if len(domains_j) == 0: return numpy.nan
        return numpy.sum( self.domain_mi_array[domains_j] -
                          self.pos_mi_matrix[i][domains_j] )

    def clfn_max_cluster_mi(self, cluster):
        # Cache successive calls for the same cluster
        if ('_cache_max_cluster_mi' in self.__dict__
            and self._cache_max_cluster_mi[0] == cluster):
            return self._cache_max_cluster_mi[1]

        max_mi = self.max_cluster_mi( cluster)
        self._cache_max_cluster_mi = (cluster, max_mi)
        return max_mi

    def clfn_min_clmax_less_psmi(self, cluster):
        i = self.cl_index[cluster]
        domains_j = self.cluster_domain_indicies(cluster)
        if len(domains_j) == 0: return numpy.nan
        return numpy.min( self.clfn_max_cluster_mi(cluster) -
                          self.pos_mi_matrix[i][domains_j])

    def clfn_mean_clmax_less_psmi(self, cluster):
        i = self.cl_index[cluster]
        domains_j = self.cluster_domain_indicies(cluster)
        if len(domains_j) == 0: return numpy.nan
        return numpy.mean( self.clfn_max_cluster_mi(cluster) -
                          self.pos_mi_matrix[i][domains_j])

    def clfn_med_clmax_less_psmi(self, cluster):
        i = self.cl_index[cluster]
        domains_j = self.cluster_domain_indicies(cluster)
        if len(domains_j) == 0: return numpy.nan
        return numpy.median( self.clfn_max_cluster_mi(cluster) -
                             self.pos_mi_matrix[i][domains_j])

    def clfn_max_clmax_less_psmi(self, cluster):
        i = self.cl_index[cluster]
        domains_j = self.cluster_domain_indicies(cluster)
        if len(domains_j) == 0: return numpy.nan
        return numpy.max( self.clfn_max_cluster_mi(cluster) -
                          self.pos_mi_matrix[i][domains_j])

    def clfn_sum_clmax_less_psmi(self, cluster):
        i = self.cl_index[cluster]
        domains_j = self.cluster_domain_indicies(cluster)
        if len(domains_j) == 0: return numpy.nan
        return numpy.sum( self.clfn_max_cluster_mi(cluster) -
                          self.pos_mi_matrix[i][domains_j])

    def clfn_min_clmax_less_mi(self, cluster):
        domains_j = self.cluster_domain_indicies(cluster)
        if len(domains_j) == 0: return numpy.nan
        return numpy.min( self.clfn_max_cluster_mi(cluster) -
                          self.domain_mi_array[domains_j])

    def clfn_mean_clmax_less_mi(self, cluster):
        domains_j = self.cluster_domain_indicies(cluster)
        if len(domains_j) == 0: return numpy.nan
        return numpy.mean( self.clfn_max_cluster_mi(cluster) -
                          self.domain_mi_array[domains_j])

    def clfn_med_clmax_less_mi(self, cluster):
        domains_j = self.cluster_domain_indicies(cluster)
        if len(domains_j) == 0: return numpy.nan
        return numpy.median( self.clfn_max_cluster_mi(cluster) -
                             self.domain_mi_array[domains_j])

    def clfn_max_clmax_less_mi(self, cluster):
        domains_j = self.cluster_domain_indicies(cluster)
        if len(domains_j) == 0: return numpy.nan
        return numpy.max( self.clfn_max_cluster_mi(cluster) -
                          self.domain_mi_array[domains_j])

    def clfn_sum_clmax_less_mi(self, cluster):
        domains_j = self.cluster_domain_indicies(cluster)
        if len(domains_j) == 0: return numpy.nan
        return numpy.sum( self.clfn_max_cluster_mi(cluster) -
                          self.domain_mi_array[domains_j])


    def fetch_mi_matricies(self):
        (self.domain_mi_array,
         self.pos_mi_matrix,
         self.neg_mi_matrix,
         self.cl_index,
         self.ann_index ) = self.mutual_information_specific_array()

        self.cl_index_inv = dict( (v,k) for (k,v) in self.cl_index.items())

        return

    def fetch_mi_matricies_unpickle(self, fname='$HOME/tmp/pickles/pca_analysis.pickle'):

        (self.domain_mi_array,
         self.pos_mi_matrix,
         self.neg_mi_matrix,
         self.cl_index,
         self.ann_index) = cPickle.load( open( os.path.expandvars(fname)))

        self.cl_index_inv = dict( (v,k) for (k,v) in self.cl_index.items())

        return

    def fetch_mi_matricies_pickle(self, fname='$HOME/tmp/pickles/pca_analysis.pickle'):

        fd = open( os.path.expandvars(fname), 'wb')

        cPickle.dump( (self.domain_mi_array,
                       self.pos_mi_matrix,
                       self.neg_mi_matrix,
                       self.cl_index,
                       self.ann_index),
                      fd, protocol=2)
        fd.close()
        
        return


    def calculate_factor_matrix(self):
        self.cluster_fact = numpy.empty( (len(self.cl_index), len(self.fact_index)),
                                         dtype=numpy.float64)

        for cluster, i in self.cl_index.items():
            for fact, fact_fn in self.factors:
                j = self.fact_index[fact]

                self.cluster_fact[i][j] = fact_fn( cluster)
        return

    def write_factor_matrix_orange(self, filename):
        """Output the factor matrix such that it may be read with
        Orage.data.Table"""

        fd = open(filename, 'w')

        # header
        s_name = ""
        s_type = ""
        s_class = ""
        for factor, i in sorted( self.fact_index.items(), key=lambda a: a[1]):

            if i > 0:
                for s in (s_name, s_type, s_class):
                    s += "\t"

            s_name += "%s" % factor
            s_type += "c"   # continuous variable 'c', or discrete 'd'
            s_class += ""   # one will be labelled "class"

        s_name += "\tcluster_id\n"
        s_type += "\td\n"
        s_class += "\tclass\n"

        for s in (s_name, s_type, s_class):
            fd.write(s)

        assert False, "unimplemented"


    def calculate_factor_matrix_pickle(self, fname='$HOME/tmp/pickles/pca_calculate_factor_matrix.pickle'):
        fd = open(os.path.expandvars(fname), 'wb')

        cPickle.dump( self.cluster_fact, fd, protocol=2)

        fd.close()
        return

    def calculate_factor_matrix_unpickle(self, fname='$HOME/tmp/pickles/pca_calculate_factor_matrix.pickle'):
        self.cluster_fact = cPickle.load( open(os.path.expandvars(fname)))

        return

    def calculate_pca(self, sample_fraction=None):
        """Fit a PCA to the factor matrix and project all clusters.

        sample_fraction: None to fit on every usable cluster, or a
          value with 0 < sample_fraction <= 1 to fit on a random subset
          of that proportion (e.g. for robustness tests).

        Sets self.pca (a matplotlib.mlab.PCA) and self.projection,
        which is the projection of the complete, unfiltered
        self.cluster_fact.
        """
        # Including NaNs in the input to the numpy PCA sometimes
        # doesn't converge, so rows are filtered before fitting.
        rows = self.cluster_fact

        # HACK 1: Remove clusters which have no domains.
        domain_counts = rows[:, self.fact_index['num_domains']]
        rows = rows[domain_counts.nonzero()]

        # HACK: Remove clusters with nan cluster_statistics.  These
        # factors exist only when cluster statistics are in use.
        for name in ('density', 'density_bit_score', 'frac_edges_bit_score'):
            if name not in self.fact_index:
                continue

            column = rows[:, self.fact_index[name]]
            keep = (~numpy.isnan(column)).nonzero()[0]
            rows = rows[keep]

        # Sample clusters; e.g., for robustness tests
        if sample_fraction is not None:
            assert 0 < sample_fraction <= 1, "sample fraction (%g) should be > 0, <= 1" % sample_fraction

            shuffled = numpy.random.permutation( rows.shape[0])
            n_keep = int(len(shuffled) * sample_fraction)
            chosen = shuffled[:n_keep]

            # reordering the matrix shouldn't change the PCA result,
            # but it is nicer to slice in contiguous order.
            chosen.sort()
            rows = rows[chosen]

        self.pca = matplotlib.mlab.PCA(rows)

        # Don't use self.pca.Y: the PCA may have been fitted on
        # truncated or reordered data (see above).  Re-project the
        # original data instead.
        self.projection = self.pca.project(self.cluster_fact)


    def pca_component_stability(self, nsamples, sample_fraction, ncomponents=5):

        Wt_array = numpy.zeros( (nsamples, ncomponents, len(self.fact_index)))
        frac_array = numpy.zeros( (nsamples, len(self.fact_index)))

        for i in range(nsamples):
            self.calculate_pca( sample_fraction)

            Wt_array[i] = self.pca.Wt[:ncomponents, :]     # weight of factors in component
            frac_array[i] = self.pca.fracs    # fraction of variance explained by component

        # restore the class to the full PCA calculation
        self.calculate_pca()

        return (Wt_array, frac_array)


    def plot_pca_component_scree(self, ncomponents=None,
                                 title_prefix=None, filename=None,
                                 dpi=None, nsamples=None, sample_fraction=None,
                                 subplots_adjust=None,
                                 draw_title=True):
        """A 'scree' plot of the variances explained by each component.

        ncomponents: number of components to consider; by default the
          index of the first component explaining less than 0.01 of the
          variance (all components if none is that small).
        title_prefix: optional first line of the title.
        filename: if given the figure is saved as 'figures/<filename>'
          and closed; otherwise it is shown interactively.
        dpi: passed to pyplot.savefig.
        nsamples, sample_fraction: when both are given, the PCA is
          repeated on random subsets (see pca_component_stability) and
          the mean fraction is plotted with standard-deviation error
          bars.  Otherwise the fractions of the current self.pca are
          plotted without error bars.
        subplots_adjust: optional dict of keyword arguments for
          pyplot.subplots_adjust.
        draw_title: whether to draw a title at all.

        Requires self.pca, i.e. calculate_pca() must have been called.
        """

        # Sampling needs both parameters; the title uses the same
        # condition so that giving only one of them cannot break the
        # '%d samples, %g proportion' formatting.
        resample = nsamples is not None and sample_fraction is not None

        pyplot.figure()

        if draw_title:
            pyplot.title("%sPCA Component 'scree' plot: %s" % (
                "" if title_prefix is None else title_prefix+"\n",
                "%d samples, %g proportion" % (nsamples, sample_fraction) if resample else ""))

        if ncomponents is None:
            # show components with variance >= 0.01
            below_cutoff = (self.pca.fracs < 0.01).nonzero()[0]
            if len(below_cutoff) > 0:
                ncomponents = below_cutoff[0]
            else:
                # every component explains at least 0.01: keep them all
                ncomponents = len(self.pca.fracs)

        # sample, keeping components that explained more than 0.01
        # variance in the full set
        if resample:
            Wt_array, frac_array = self.pca_component_stability( nsamples, sample_fraction,
                                                                 ncomponents = ncomponents)
            fracs_mean = numpy.mean( frac_array, axis=0)
            fracs_std = numpy.std( frac_array, axis=0)
        else:
            fracs_mean = self.pca.fracs[:ncomponents + 1]
            fracs_std = None

        pyplot.errorbar( range(len(fracs_mean)), fracs_mean,
                         marker = '',
                         yerr = fracs_std,
                         ecolor='black',
                         label='Component',
                         zorder=12,
                         markeredgewidth=0)

        pyplot.plot(numpy.cumsum(fracs_mean),
                    label="Cumulative contribution",
                    zorder=11)

        pyplot.ylim(-0.01,1.01)

        pyplot.xlabel("Principal component")
        pyplot.ylabel("Fraction of variance")
        
        leg = pyplot.legend(fancybox=True)
        leg.get_frame().set_alpha(0.6)
        leg.get_frame().set_facecolor('lightgrey')
        leg.set_zorder(14)

        if subplots_adjust is not None:
            pyplot.subplots_adjust(**subplots_adjust)
            
        pyplot.grid(color='grey',zorder=10)

        if filename is not None:
            pyplot.savefig( 'figures/%s' % filename, dpi=dpi)
            pyplot.close()
        else:
            pyplot.show()
                     
        return

        
    def plot_pca_factor_stability(self, nsamples, sample_fraction,
                                  comp_x = 0, comp_y = 1,
                                  title_prefix = None,
                                  filename = None,
                                  dpi = None,
                                  xlim = None, ylim = None,
                                  subplots_adjust=None,
                                  draw_title=True):
        """Plot how stable the factor weights are under resampling.

        The PCA is repeated nsamples times on random subsets of
        proportion sample_fraction (see pca_component_stability).  For
        every factor, its mean weight in component comp_x is plotted
        against its mean weight in component comp_y, with the standard
        deviations over the samples as error bars, and the point is
        labelled with the factor name.

        title_prefix: optional first line of the title.
        filename: if given the figure is saved as 'figures/<filename>'
          and closed; otherwise it is shown interactively.
        dpi: passed to pyplot.savefig.
        xlim, ylim: optional axis limits.
        subplots_adjust: optional dict of keyword arguments for
          pyplot.subplots_adjust.
        draw_title: whether to draw a title at all.

        Returns the (Wt_array, frac_array) pair produced by
        pca_component_stability; frac_array holds the per-sample
        variance fractions should the axis labels need to report them.
        """
        pyplot.figure()

        if draw_title:
            pyplot.title("%sPCA factor stability: %d samples, %g proportion" % (
                "" if title_prefix is None else title_prefix+"\n",
                nsamples, sample_fraction))
        
        Wt_array, frac_array = self.pca_component_stability( nsamples, sample_fraction,
                                                             ncomponents=max(comp_x,comp_y)+1)

        # avg position over all samples
        Wt_avg = numpy.mean( Wt_array, axis=0)
        Wt_std = numpy.std( Wt_array, axis=0)
        
        # plot average position and stdev error bars
        pyplot.scatter( Wt_avg[comp_x],
                        Wt_avg[comp_y],
                        marker = 'o',
                        zorder=11
                        )
        # NOTE(review): newer matplotlib releases appear to expect
        # fmt='none' rather than fmt=None to suppress the data
        # markers -- confirm against the installed version.
        pyplot.errorbar( Wt_avg[comp_x], Wt_avg[comp_y],
                         xerr=Wt_std[comp_x], yerr=Wt_std[comp_y],
                         ecolor='grey',
                         fmt=None,
                         zorder=12
                         )

        for factor, i in self.fact_index.items():
            pyplot.text( Wt_avg[comp_x][i], Wt_avg[comp_y][i],
                         # show underscores as spaces (same text as the
                         # earlier reduce()-based join, without relying
                         # on the Python 2 builtin)
                         factor.replace('_', ' '),
                         horizontalalignment='left',verticalalignment='top',
                         zorder=13,
                         fontsize=8)
            
        pyplot.xlabel("Component %d" % comp_x)
        pyplot.ylabel("Component %d" % comp_y)

        pyplot.grid(color='grey', alpha=0.5, zorder=10)

        pyplot.axis('tight')

        if subplots_adjust is not None:
            pyplot.subplots_adjust(**subplots_adjust)

        if xlim is not None:
            pyplot.xlim( xlim)

        if ylim is not None:
            pyplot.ylim( ylim)

        if filename is not None:
            pyplot.savefig( 'figures/%s' % filename, dpi=dpi)
            pyplot.close()
        else:
            pyplot.show()

        return (Wt_array, frac_array)
        
                                  
    def print_pca(self):
        m = "PCA data\n-----------------------------------\n"

        m += "# clusters: %d\n" % len(self.cluster_to_seq)
        m += "# clusters used for PCA: %d\n" % self.pca.numrows
        m += "# factors: %d\n" % self.pca.numcols

        m += "\nPrinciple Component Factor Fractions\n"
        m += "Factor                  0       1       2       3       4       5\n"
        m += "----------------------------------------------------------------------\n"
        for factor, i in sorted(self.fact_index.items(), key=lambda a: a[1]):
            m += "%20s % 6.4f % 6.4f % 6.4f % 6.4f % 6.4f % 6.4f\n" % (
                factor,
                self.pca.Wt[0][i],
                self.pca.Wt[1][i],
                self.pca.Wt[2][i],
                self.pca.Wt[3][i],
                self.pca.Wt[4][i],
                self.pca.Wt[5][i])

        m += "\n"
        m += "%20s % 6.4f % 6.4f % 6.4f % 6.4f % 6.4f % 6.4f\n" % (
            "Variance:",
            self.pca.fracs[0],
            self.pca.fracs[1],
            self.pca.fracs[2],
            self.pca.fracs[3],
            self.pca.fracs[4],
            self.pca.fracs[5])

        m += "\n" 
        m += "# Principle components: %d\n" % len(self.pca.fracs)

        return m

    def plot_pca_factors(self, num=3):
        """Plot the factor weights of the first `num` components.

        One line is drawn per component (self.pca.Wt[comp]), labelled
        with the component number and its fraction of variance.  The x
        axis carries one tick per factor.  The figure is neither shown
        nor saved here.
        """
        pyplot.figure()
        for comp in range(num):
            pyplot.plot(self.pca.Wt[comp],
                        label="%d (%6.4f)" % (comp, self.pca.fracs[comp]))

        # Order the factor names by their column index so tick k names
        # column k of Wt.  (Iterating the dict yields the names, so the
        # previous key 'a[1]' sorted by the second character of each
        # name rather than by index.)
        xlabels = sorted(self.fact_index, key=self.fact_index.get)
        pyplot.xticks( range(len(xlabels)), xlabels, rotation=-90)
        
        pyplot.legend()
        pyplot.subplots_adjust(bottom=0.3)
        pyplot.grid()
        return

    def plot_pca_data_scatter(self, comp_x=0, comp_y=1,
                              title_prefix = None,
                              filename = None,
                              dpi = None,
                              draw_title=True,
                              subplots_adjust=None,
                              cluster_categories = None,
                              draw_legend = True,
                              xlim = None,
                              ylim = None,
                              legend_loc = None,
                              color_by_factor = None,
                              draw_colorbar = True):
        fig = pyplot.figure()

        if draw_title:
            pyplot.title("%sPCA component scatter" % (
                "" if title_prefix is None else title_prefix+"\n"))

        # Don't use self.pca.Y, as the PCA may have been done with
        # truncated data, as to avoid SVD non-convergence.
        #projection = self.pca.project(self.cluster_fact)

        X = self.projection[:,comp_x]
        Y = self.projection[:,comp_y]

        if color_by_factor is not None:
            values = self.cluster_fact[:, self.fact_index[ color_by_factor]]
            norm = matplotlib.colors.LogNorm()
            normvalues = norm( values)
            colors = matplotlib.cm.hot( normvalues)

            if color_by_factor == 'num_domains':
                vmin = 1
                vmax = 10 ** int(numpy.log10( numpy.max(values)) + 1)
                norm = matplotlib.colors.LogNorm(vmin=vmin, vmax=vmax)
                cb_format = "%d"
                cb_ticks = None

            elif color_by_factor in ('cluster_max'):
                vmin = 1
                vmax = 10 ** int(numpy.log10( numpy.max(values)) + 1)
                norm = matplotlib.colors.LogNorm(vmin=vmin, vmax=vmax)
                cb_format = "%3.1f"
                cb_ticks = None

            elif color_by_factor in ('cluster_size'):
                vmin = 1
                vmax = 10 ** int(numpy.log10( numpy.max(values)) + 1)
                norm = matplotlib.colors.LogNorm(vmin=vmin, vmax=vmax)
                cb_format = "%d"
                cb_ticks = None

            elif color_by_factor == 'mean_num_domains':
                vmin = 0.1
                vmax = 10 ** int(numpy.log10( numpy.max(values)) + 1)
                #print vmin, vmax, numpy.max(values)
                norm = matplotlib.colors.LogNorm(vmin=vmin, vmax=vmax)
                cb_format = "%3.1f"
                cb_ticks = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1, 1.5, 2, 3, 4, 5, vmax]

            elif color_by_factor in ('med_clmax_less_psmi', 'med_domain_psmi'):
                norm = matplotlib.colors.LogNorm()
                cb_format = None
                cb_ticks = None

            else:
                norm = matplotlib.colors.Normalize()
                cb_format = None
                cb_ticks = None

            pyplot.scatter(X, Y, zorder=11, 
                           c=values,
                           cmap=pyplot.cm.hot,
                           norm = norm,
                           alpha=0.7,
                           edgecolor='none',
                           #facecolor=
                           )
        else:
            pyplot.scatter(X, Y, zorder=11)

        # Overlay cluster annotations
        if cluster_categories is not None:
            for (category, color, marker, clusters) in cluster_categories:
                
                cluster_inds = [self.cl_index[cluster] for cluster in clusters]

                X = self.projection[ cluster_inds, comp_x ]
                Y = self.projection[ cluster_inds, comp_y ]
                
                #color = matplotlib.cm.jet( float(i) / len(cluster_categories))
                


                pyplot.scatter(X, Y, label = category, 
                               marker = marker,
                               # highlight the first cluster in a category with a larger box
                               #s = [160] + [80]*(len(clusters)-1),
                               s = 100,
                               alpha=0.8,
                               #linestyle = ['solid'] + ['dotted']*(len(clusters)-1),
                               #linewidth = [3] + [1]*(len(clusters)-1),
                               linewidth = 1,
                               #facecolor = 'none',
                               facecolor = [color] + ['none']*(len(clusters)-1),
                               edgecolor = color,
                               zorder = 12)

                if draw_legend:
                    leg = pyplot.legend(fancybox=True, loc=legend_loc, ncol=2) # upper left
                    leg.get_frame().set_alpha(0.6)
                    leg.get_frame().set_facecolor('lightgrey')
                    leg.set_zorder(13)

        #pyplot.xlabel("Component %d (%6.4f)" % (comp_x, self.pca.fracs[comp_x]))
        #pyplot.ylabel("Component %d (%6.4f)" % (comp_y, self.pca.fracs[comp_y]))
        pyplot.xlabel("Component %d" % (comp_x))
        pyplot.ylabel("Component %d" % (comp_y))

        pyplot.axis('tight')
        fig.get_axes()[0].set_axis_bgcolor('grey')
        pyplot.grid(color='white', zorder=10)

        if xlim is not None:
            pyplot.xlim( xlim)

        if ylim is not None:
            pyplot.ylim( ylim)

        if subplots_adjust is not None:
            pyplot.subplots_adjust(**subplots_adjust)

        if draw_colorbar and color_by_factor is not None:
           pyplot.colorbar(fraction=0.05, pad=0.01, format=cb_format, ticks=cb_ticks)

        if filename is not None:
            pyplot.savefig( 'figures/%s' % filename, dpi=dpi)
            pyplot.close()
        else:
            pyplot.show()

        return

    def plot_factor_scatter(self, fact_x=None, fact_y=None,
                            title_prefix = None,
                            filename = None,
                            dpi = None):
        """Scatter one cluster factor against another.

        fact_x, fact_y -- factor names; each is looked up in
                          self.fact_index to pick a column of
                          self.cluster_fact.
        title_prefix   -- optional text placed on its own line above
                          the plot title.
        filename       -- when given, the figure is written to
                          'figures/<filename>' and closed; otherwise
                          it is shown interactively.
        dpi            -- forwarded to pyplot.savefig.
        """
        fig = pyplot.figure()

        if title_prefix is None:
            heading = ""
        else:
            heading = title_prefix + "\n"
        pyplot.title("%sFactor scatterplot" % heading)

        # One column of the factor matrix per axis
        col_x = self.fact_index[fact_x]
        col_y = self.fact_index[fact_y]
        pyplot.scatter(self.cluster_fact[:, col_x],
                       self.cluster_fact[:, col_y])

        pyplot.xlabel("Factor: %s" % fact_x)
        pyplot.ylabel("Factor: %s" % fact_y)

        pyplot.axis('tight')
        axes = fig.get_axes()[0]
        axes.set_axis_bgcolor('lightgrey')
        pyplot.grid(color='grey')

        pyplot.subplots_adjust(left=0.08, bottom=0.07, right=0.97, top=0.9)

        # Either save to disk or hand the figure to the interactive viewer
        if filename is None:
            pyplot.show()
        else:
            pyplot.savefig( 'figures/%s' % filename, dpi=dpi)
            pyplot.close()

        return

    def plot_pca_data_heatmap(self, comp_x=0, comp_y=1,
                              title_prefix = None,
                              filename=None,
                              dpi = None,
                              hres=1000, vres=1000,
                              draw_title=True,
                              subplots_adjust=None
                              ):
        """Draw a density heatmap of the clusters in PCA space.

        Every row of self.projection is binned onto a vres x hres grid
        by its coordinates along components comp_x and comp_y.  The
        per-cell counts are drawn with a logarithmic colour scale.
        Rows with a NaN coordinate in either component are left out.

        comp_x, comp_y  -- column indices into self.projection.
        title_prefix    -- optional text placed on its own line above
                           the plot title.
        filename        -- when given, the figure is written to
                           'figures/<filename>' and closed; otherwise
                           it is shown interactively.
        dpi             -- forwarded to pyplot.savefig.
        hres, vres      -- number of grid columns / rows.
        draw_title      -- whether to draw a title at all.
        subplots_adjust -- dict of keyword arguments for
                           pyplot.subplots_adjust, or None.

        Returns the (vres, hres) int32 array of per-cell counts.
        """

        fig = pyplot.figure()

        if draw_title:
            pyplot.title("%sPCA component heatmap" % (
                "" if title_prefix is None else title_prefix+"\n"))

        # find the range of values in each axis
        X = self.projection[:,comp_x]
        Y = self.projection[:,comp_y]

        x_min = numpy.nanmin(X)
        x_max = numpy.nanmax(X)
        y_min = numpy.nanmin(Y)
        y_max = numpy.nanmax(Y)

        # A degenerate axis (all values equal) used to give an infinite
        # scale and then crash on int(nan).  Collapse it onto bin 0.
        x_range = x_max - x_min
        y_range = y_max - y_min
        x_scale = float(hres - 1) / x_range if x_range > 0 else 0.0
        y_scale = float(vres - 1) / y_range if y_range > 0 else 0.0

        # Only clusters with non-NaN coordinates in both components
        # are counted.
        keep = ~(numpy.isnan(X) | numpy.isnan(Y))

        # Truncation always rounds down here, since the offsets from
        # the minimum are non-negative.
        x_bins = (x_scale * (X[keep] - x_min)).astype(numpy.intp)
        y_bins = (y_scale * (Y[keep] - y_min)).astype(numpy.intp)

        # Count points per cell in one pass.  The grid is laid out
        # [row, column] = [y, x], which is the indexing imshow expects.
        counts = numpy.bincount(y_bins * hres + x_bins,
                                minlength=vres * hres)
        xy = counts.reshape(vres, hres).astype(numpy.int32)

        pyplot.imshow(xy, cmap=pyplot.cm.hot,
                      origin='lower',
                      norm = matplotlib.colors.LogNorm(),
                      interpolation='nearest',
                      zorder=11
                      )

        # fix axis limits perfectly
        pyplot.axis('image')

        # 10 ticks, labelled with integer-truncated data coordinates
        x_ticks = numpy.int32(numpy.linspace( x_min, x_max, 10))
        y_ticks = numpy.int32(numpy.linspace( y_min, y_max, 10))

        # Tick positions are the grid cells of the truncated labels, so
        # they are only approximate with respect to the true data range.
        pyplot.xticks( numpy.int32(x_scale * (x_ticks - x_min)), x_ticks)
        pyplot.yticks( numpy.int32(y_scale * (y_ticks - y_min)), y_ticks)

        # NOTE(review): set_axis_bgcolor is the older matplotlib
        # spelling; newer releases name this set_facecolor -- confirm
        # against the installed version.
        fig.get_axes()[0].set_axis_bgcolor('lightgrey')
        pyplot.grid(color='grey', zorder=10)

        pyplot.xlabel("Component %d" % comp_x)
        pyplot.ylabel("Component %d" % comp_y)

        if subplots_adjust is not None:
            pyplot.subplots_adjust(**subplots_adjust)

        pyplot.colorbar(fraction=0.05, pad=0.01)

        if filename is not None:
            pyplot.savefig( 'figures/%s' % filename, dpi=dpi)
            pyplot.close()
        else:
            pyplot.show()

        return xy

    def plot_factor_heatmap(self, fact_x=None, fact_y=None,
                              title_prefix = None,
                              filename=None,
                              dpi = None,
                              hres=1000, vres=1000):
        """Draw a density heatmap of two cluster factors.

        The columns of self.cluster_fact named by fact_x and fact_y
        (resolved through self.fact_index) are binned onto a
        vres x hres grid.  The per-cell counts are drawn with a
        logarithmic colour scale.  Clusters where either value is NaN
        or negative are left out of the counts.

        fact_x, fact_y -- factor names.
        title_prefix   -- optional text placed on its own line above
                          the plot title.
        filename       -- when given, the figure is written to
                          'figures/<filename>' and closed; otherwise
                          it is shown interactively.
        dpi            -- forwarded to pyplot.savefig.
        hres, vres     -- number of grid columns / rows.

        Returns (X, Y, xy): the two raw factor columns and the
        (vres, hres) int32 array of per-cell counts.
        """

        fig = pyplot.figure()
        pyplot.title("%sFactor heatmap" % (
            "" if title_prefix is None else title_prefix+"\n"))

        facti_x = self.fact_index[fact_x]
        facti_y = self.fact_index[fact_y]

        # find the range of values in each axis
        X = self.cluster_fact[:, facti_x]
        Y = self.cluster_fact[:, facti_y]

        # NOTE: the range is taken over every non-NaN value, including
        # the negative ones that are excluded from the counts below.
        x_min = numpy.nanmin(X)
        x_max = numpy.nanmax(X)
        y_min = numpy.nanmin(Y)
        y_max = numpy.nanmax(Y)

        # A degenerate axis (all values equal) used to give an infinite
        # scale and then crash on int(nan).  Collapse it onto bin 0.
        x_range = x_max - x_min
        y_range = y_max - y_min
        x_scale = float(hres - 1) / x_range if x_range > 0 else 0.0
        y_scale = float(vres - 1) / y_range if y_range > 0 else 0.0

        # Drop NaNs first, then negatives (<0 marks missing score
        # values), so that NaN is never fed to a comparison.
        # TODO: report on the plot how many clusters were skipped.
        keep = numpy.flatnonzero(~(numpy.isnan(X) | numpy.isnan(Y)))
        keep = keep[(X[keep] >= 0) & (Y[keep] >= 0)]

        # Truncation always rounds down here, since the offsets from
        # the minimum are non-negative.
        x_bins = (x_scale * (X[keep] - x_min)).astype(numpy.intp)
        y_bins = (y_scale * (Y[keep] - y_min)).astype(numpy.intp)

        # Count points per cell in one pass.  The grid is laid out
        # [row, column] = [y, x], which is the indexing imshow expects.
        counts = numpy.bincount(y_bins * hres + x_bins,
                                minlength=vres * hres)
        xy = counts.reshape(vres, hres).astype(numpy.int32)

        pyplot.imshow(xy, cmap=pyplot.cm.hot,
                      origin='lower',
                      norm = matplotlib.colors.LogNorm(),
                      interpolation='nearest'
                      )

        # fit axis limits
        pyplot.axis('image')

        # 10 ticks
        x_ticks = numpy.linspace( x_min, x_max, 10)
        y_ticks = numpy.linspace( y_min, y_max, 10)

        # Tick positions are truncated to whole grid cells, so they are
        # only approximate with respect to the printed labels.
        pyplot.xticks( numpy.int32(x_scale * (x_ticks - x_min)),
                       ["%4.3g" % f for f in x_ticks])
        pyplot.yticks( numpy.int32(y_scale * (y_ticks - y_min)),
                       ["%4.3g" % f for f in y_ticks])

        # NOTE(review): set_axis_bgcolor is the older matplotlib
        # spelling; newer releases name this set_facecolor -- confirm
        # against the installed version.
        fig.get_axes()[0].set_axis_bgcolor('grey')
        pyplot.grid(color='white')

        pyplot.xlabel("Factor: %s" % fact_x)
        pyplot.ylabel("Factor: %s" % fact_y)

        pyplot.subplots_adjust( bottom=0.05, left=0.01, top=0.90, right=0.95)
        pyplot.colorbar(fraction=0.05, pad=0.01)

        if filename is not None:
            pyplot.savefig( 'figures/%s' % filename, dpi=dpi, transparent=False)
            pyplot.close()
        else:
            pyplot.show()

        return X, Y, xy


    def print_cluster_data(self, cluster_id):

        m = "Cluster: %s\n" % cluster_id

        m += "PCA Factors: %s\n" % self.projection[self.cl_index[cluster_id]]

        m += "# sequences: %d\n" % len( self.cluster_to_seq[cluster_id])

        annotations = self.annotations_in_cluster(cluster_id)
        m += "Domains: (%d total)\n" % len(annotations)
        
        m += "Domain    #seqs #seqsall #clust TotSize    Entropy         MI Name\n"
        m += "------------------------------------------------------------------\n"

        for domain in sorted(annotations):
            joint_cnt = self.joint_cluster_ann_count(domain)
            total_cluster_cnt = sum((len(self.cluster_to_seq[cluster])
                                     for cluster in joint_cnt))
            
            m += "%-10s %4d%9d%7d%8d %0.4e %0.4e %s\n" % (
                domain,
                joint_cnt[cluster_id],
                len(self.annotation_to_seq[domain]),
                len(joint_cnt),
                total_cluster_cnt,
                self.annotation_entropy(domain),
                self.mutual_information(domain),
                pq.lookup_domain_name(domain))
            
        return m
        

    def constrain_clusters( self, constraints):
        """Input.  A list of (component #, range_min, range_max)
        constraints.  Either of the ranges can be None."""

        bitmap = numpy.ones(len(self.cluster_to_seq), dtype=numpy.bool)

        for (component, range_min, range_max) in constraints:
            if range_min is not None:
                bitmap &= self.projection[:, component] > range_min
            if range_max is not None:
                bitmap &= self.projection[:, component] < range_max

        clusters = bitmap.nonzero()[0]
        clusters = [self.cl_index_inv[c] for c in clusters]
        
        return clusters


    def prune_factors(self, factors):
        """Prune all but the specified factors from factor_index and
        cluster_fact.  This is useful when re-using a pickled matrix
        of all.  Otherwise, don't spend the time calculating them in
        the first place."""

        # be careful with indexing, in case factors is unordered
        new_fact_index = dict( (k,i) for i,k in enumerate(factors))
        #print factors
        #print enumerate(factors)
        #print new_fact_index

        fact_i = [ self.fact_index[fact] for fact, newi in 
                   sorted(new_fact_index.items(), key=lambda a: a[1])]

        self.cluster_fact = self.cluster_fact[:, fact_i]
        self.fact_index = new_fact_index
        
        return

if __name__ == "__main__":

    pyplot.ioff() # turn off figure updates
    pyplot.rc('font', size=10, family = 'serif',
              serif = ['Computer Modern Roman']
              )

    # makes underscores a pain
    pyplot.rc('text', usetex=True)
    pyplot.rc('legend', fontsize = 10)
    
    pq = pfamq.pfamq()

    date = time.strftime('%Y%m%d')

    cr_id = None
    clustering_type = 'hierarchical'

    if False:  # Just the human and mouse families
        set_id = 109 # ppod human, mouse
        family_set_name = 'ppod_20_cleanqfo2'
        #family_set_name = 'hgnc_apr10_selected'

        title_prefix= "family_set_name: %s, set_id: %d" % (family_set_name,set_id)
        fname_prefix= date+"family_set_name_%s_set_id_%d" % (family_set_name,set_id)
        print title_prefix
        
        fq = familyq.familyq(family_set_name=family_set_name)
        clustering = fq.fetch_families()[0]
        
        annotation, seq_map = pq.fetch_domain_map(
            set_id = set_id,
            seq_ids = tuple(reduce(
            lambda a,b:a+b, [list(clustering[fam]) for fam in clustering])))

    # Full mouse and human set
    if False:
        set_id = 109   # ppod human, mouse
        cr_id = 475    # clustering using the full 600k set.  Correct log calculation
        nc_score = 0.425 # 0.2 is near the optimal on the family set, by F

        title_prefix = "set_id: %(set_id)d, cr_id: %(cr_id)d, nc_score: %(nc_score)g" % locals()
        fname_prefix = date+"_set_id_%(set_id)d_cr_id_%(cr_id)d_nc_score_%(nc_score)g" % locals()
        print title_prefix
        
        clusclass = cluster_sql.hcluster( cluster_run_id = cr_id,
                                          cacheq=True)
        clustering = clusclass.cut_tree( distance = 1-nc_score,
                                         set_id_filter = set_id)
        annotation, seq_map = pq.fetch_domain_map(set_id = set_id)

        # NC >= 0.425 human and mouse.  Cluster run 475
        family_categories = [
            ('ACSL (1)', 'red', 'o',
                     (22313659,
                      )),
            ('FGF (1)', 'red', 'v', (
                    22296838,
                    )),
            ('FOX (1)', 'red', '*', (
                    22288496,
                    )),
            ('Tbox (1)', 'red', 's', (
                    22305088,
                    )),
            ('TNF (11)', 'red', 'D', (
                    22314154, 22315129, 22270027, 22287177, 22261463, 22298474,
                    22170530, 22172871, 22244730, 22215648, 22291349,
                    )),
            ('WNT (1)', 'red', 'h', 
                    (22286892,
                     )),

            ('DVL (1)', 'cyan', 'o', (
                    22287331,
                    )),
            ('GATA (1)', 'cyan', 'v', (
                22311997,
                )),
            ('KIR (1)', 'cyan', '*', (
                22305526,
                )),
            ('Notch (1)', 'cyan', 's', (
                22314034,
                )),
            ('TRAF (1)', 'cyan', 'D', (
                22312291,
                )),
            ('USP (4)', 'cyan', 'h', (
                22313794, 22267491, 22301080, 22311108,
                )),
            
            ('ADAM (2)', 'orange', 'o', (
                22299624, 22305655,
                )),
            ('Kinase (31)', 'orange', 'v', (
                22315320, 22315736, 22303229, 22300635, 22303550, 22300112,
                22313490, 22302611, 22310422, 22305757, 22277080, 22260493,
                22299807, 22309362, 22307724, 22315149, 22292735, 22293439,
                22299098, 22301173, 22270377, 22280490, 22305008, 22306424,
                22308631, 22311141, 22309205, 22282324, 22300028, 22304974,
                22310621,
                )),
            ('Kinesin (4)', 'orange', '*', (
                22315618, 22315593, 21281259, 22311367,
                )),
            ('Myosin (1)', 'orange', 's', (
                22309053,
                )),
            ('Laminin (1)', 'orange', 'D', (
                22298677,
                )),
            ('PDE (4)', 'orange', 'h', (
                22310871, 22305574, 22062107, 22182893,
                )),
            ('SEMA (1)', 'orange', '^', (
                22313103,
                )),
            ('TNFR (10)', 'orange', 'p', (
                22315702, 22296245, 22279088, 22187182, 22185181, 22258821,
                22301119, 22233521, 22304998, 22309449,
                )),
            ]

    if False:
        set_id = 109   # ppod human, mouse
        cr_id = 492    # clustering using only human, mouse
        nc_score = 0.425 # 0.2 is near the optimal on the family set, by F

        title_prefix = "set_id: %(set_id)d, cr_id: %(cr_id)d, nc_score: %(nc_score)g" % locals()
        fname_prefix = date+"_set_id_%(set_id)d_cr_id_%(cr_id)d_nc_score_%(nc_score)g" % locals()
        print title_prefix
        
        clusclass = cluster_sql.hcluster( cluster_run_id = cr_id,
                                          cacheq=True)
        clustering = clusclass.cut_tree( distance = 1-nc_score,
                                         set_id_filter = set_id)
        annotation, seq_map = pq.fetch_domain_map(set_id = set_id)

        # NC >= 0.425 human and mouse clustering.  Cluster run 492
        family_categories = [
            ('ACSL (1)', 'red', 'o', (
                    21313189, 
                      )),
            ('FGF (1)', 'red', 'v', (
                    21315146,
                    )),
            ('FOX (1)', 'red', '*', (
                    21313208,
                    )),
            ('Tbox (1)', 'red', 's', (
                    21311153,
                    )),
            ('TNF (10)', 'red', 'D', (
                    21316487, 21272651, 21316821, 21299390, 21300913, 21310500, 
                    21307641, 21308486, 21315211, 21314609,
                    )),
            ('WNT (1)', 'red', 'h', (
                    21303249,
                     )),

            ('DVL (1)', 'cyan', 'o', (
                    21304630,
                    )),
            ('GATA (1)', 'cyan', 'v', (
                    21316577,
                )),
            ('KIR (1)', 'cyan', '*', (
                    21316514,
                )),
            ('Notch (1)', 'cyan', 's', (
                    21316692,
                )),
            ('TRAF (1)', 'cyan', 'D', (
                     21314576,
                )),
            ('USP (2)', 'cyan', 'h', (
                    21316737, 21315992, 21304961,  21311664,
                )),

            ('ADAM (2)', 'orange', 'o', (
                    21313069, 21236866,
                )),
            ('Kinase (26)', 'orange', 'v', (
                    21316419, 21316376, 21315534, 21277258, 21278407, 21314469, 
                    21316842, 21314831, 21316383, 21299608, 21301681, 21305535, 
                    21307676, 21308342, 21311039, 21284844, 21296022, 21306467, 
                    21316541, 21311060, 21316201, 21316768, 21297654, 21312761, 
                    21315620, 21314960,
                )),
            ('Kinesin (3)', 'orange', '*', (
                    21314100, 21276203, 21281259,
                )),
            ('Myosin (1)', 'orange', 's', (
                    21314369,
                )),
            ('Laminin (1)', 'orange', 'D', (
                    21316181,
                )),
            ('PDE (4)', 'orange', 'h', (
                    21313177, 21314552,  21285611,  21297469,
                )),
            ('SEMA (1)', 'orange', '^', (
                    21313870,
                )),
            ('TNFR (8)', 'orange', 'p', (
                    21316655, 21315830, 21305405, 21307847, 21308513, 21311412, 
                    21296209, 21316381,
                )),
            ]


    if False:
        set_id = 112   # ppod human
        cr_id = 493    # clustering using only human, mouse
        nc_score = 0.425 # 0.2 is near the optimal on the family set, by F

        title_prefix = "set_id: %(set_id)d, cr_id: %(cr_id)d, nc_score: %(nc_score)g" % locals()
        fname_prefix = date+"_set_id_%(set_id)d_cr_id_%(cr_id)d_nc_score_%(nc_score)g" % locals()
        print title_prefix
        
        clusclass = cluster_sql.hcluster( cluster_run_id = cr_id,
                                          cacheq=True)
        clustering = clusclass.cut_tree( distance = 1-nc_score,
                                         set_id_filter = set_id)
        annotation, seq_map = pq.fetch_domain_map(set_id = set_id)

        # NC >= 0.425 human only clustering.  Cluster run 493
        family_categories = [
            ('ACSL (1)', 'red', 'o', (
                     21242653,
                      )),
            ('FGF (1)', 'red', 'v', (
                    21244298,
                    )),
            ('FOX (1)', 'red', '*', (
                    21243581,
                    )),
            ('Tbox (1)', 'red', 's', (
                    21242169,
                    )),
            ('TNF (9)', 'red', 'D', (
                    21245688, 21245126, 21243962, 21215287, 21215304,
                    21229672, 21229676, 21229678, 21229679
                    )),
            ('WNT (1)', 'red', 'h', (
                    21237720,
                     )),

            ('DVL (1)', 'cyan', 'o', (
                    21238897,
                    )),
            ('GATA (1)', 'cyan', 'v', (
                    21245679,
                )),
            ('KIR (1)', 'cyan', '*', (
                    21243239,
                )),
            ('Notch (1)', 'cyan', 's', (
                    21245934,
                )),
            ('TRAF (1)', 'cyan', 'D', (
                    21244073,
                )),
            ('USP (4)', 'cyan', 'h', (
                 21245878, 21245218,  21244316,  21241838
                )),
            
            ('ADAM (1)', 'orange', 'o', (
                    21243312,             
                )),
            ('Kinase (23)', 'orange', 'v', (
                    21245461, 21245787, 21245933, 21244171, 21243163, 21244358, 
                    21244837, 21245056, 21245914, 21236259, 21241815, 21245147, 
                    21245524, 21245639, 21213153, 21220069, 21225284, 21225852, 
                    21226687, 21227798, 21228303, 21228595, 21229987
                )),
            ('Kinesin (1)', 'orange', '*', (
                    21244058,
                )),
            ('Myosin (1)', 'orange', 's', (
                    21243970,
                )),
            ('Laminin (1)', 'orange', 'D', (
                    21245934,
                )),
            ('PDE (4)', 'orange', 'h', (
                    21242946,  21244041,  21236130,  21225050
                )),
            ('SEMA (1)', 'orange', '^', (
                    21243430,
                )),
            ('TNFR (9)', 'orange', 'p', (
                    21245504, 21241020, 21245222, 21245428, 21229701, 21229704,
                    21229705, 21229822, 21229823
                )),
            ]

    if False:
        set_id = 115   # ppod mouse
        cr_id = 494    # clustering using only human, mouse
        nc_score = 0.425 # 0.2 is near the optimal on the family set, by F

        title_prefix = "set_id: %(set_id)d, cr_id: %(cr_id)d, nc_score: %(nc_score)g" % locals()
        fname_prefix = date+"_set_id_%(set_id)d_cr_id_%(cr_id)d_nc_score_%(nc_score)g" % locals()
        print title_prefix
        
        clusclass = cluster_sql.hcluster( cluster_run_id = cr_id,
                                          cacheq=True)
        clustering = clusclass.cut_tree( distance = 1-nc_score,
                                         set_id_filter = set_id)
        annotation, seq_map = pq.fetch_domain_map(set_id = set_id)

        # NC >= 0.425 mouse only clustering.  Cluster run 494
        family_categories = [
            ('ACSL (1)', 'red', 'o', (
                    21295164,
                      )),
            ('FGF (1)', 'red', 'v', (
                    21296685,
                    )),
            ('FOX (1)', 'red', '*', (
                    21294736,
                    )),
            ('Tbox (1)', 'red', 's', (
                    21293390,
                    )),
            ('TNF (10)', 'red', 'D', (
                    21297858, 21237649, 21246494, 21246528, 21272651, 21272667,
                    21272672, 21272673, 21296374, 21297944
                    )),
            ('WNT (1)', 'red', 'h', (
                    21288949,
                     )),

            ('DVL (1)', 'cyan', 'o', (
                    21290200,
                    )),
            ('GATA (1)', 'cyan', 'v', (
                    21298036,
                )),
            ('KIR (1)', 'cyan', '*', (
                    21297865,
                )),
            ('Notch (1)', 'cyan', 's', (
                    21298136,
                )),
            ('TRAF (1)', 'cyan', 'D', (
                    21296351,
                )),
            ('USP (2)', 'cyan', 'h', (
                    21298077,  21295032,
                )),

            ('ADAM (2)', 'orange', 'o', (
                    21294409, 21236866,
                )),
            ('Kinase (25)', 'orange', 'v', (
                    21297511, 21242865, 21245103, 21254660, 21263932, 21266608, 
                    21268818, 21269909, 21270469, 21273295, 21274691, 21277258, 
                    21278407, 21284351, 21287176, 21293652, 21294100, 21295548, 
                    21295621, 21296492, 21296731, 21296991, 21297697, 21297851, 
                    21297997
                )),
            ('Kinesin (4)', 'orange', '*', (
                    21295574, 21276203, 21281259, 21298163
                )),
            ('Myosin (1)', 'orange', 's', (
                    21295971,
                )),
            ('Laminin (1)', 'orange', 'D', (
                    21298136,
                )),
            ('PDE (4)', 'orange', 'h', (
                    21294829, 21263461, 21287346, 21295936
                )),
            ('SEMA (1)', 'orange', '^', (
                    21295544,
                )),
            ('TNFR (9)', 'orange', 'p', (
                    21298085, 21272710, 21272715, 21272718, 21272745, 21272977, 
                    21272978, 21297198, 21297720
                )),
            ]



    # S cerevisiae
    if False:
        set_id = 111
        cr_id = 475      # clustering using the full 600k set.  Correct log calculation
        nc_score = 0.425 # 0.2 is near the optimal on the family set, by F

        title_prefix = "set_id: %(set_id)d, cr_id: %(cr_id)d, nc_score: %(nc_score)g" % locals()
        fname_prefix = date+"_set_id_%(set_id)d_cr_id_%(cr_id)d_nc_score_%(nc_score)g" % locals()
        print title_prefix
        
        clusclass = cluster_sql.hcluster( cluster_run_id = cr_id,
                                          cacheq=True)
        clustering = clusclass.cut_tree( distance = 1-nc_score,
                                         set_id_filter = set_id)
        annotation, seq_map = pq.fetch_domain_map(set_id = set_id)

    # S purpuratus
    if False:
        set_id = 114
        cr_id = 475      # clustering using the full 600k set.  Correct log calculation
        nc_score = 0.425 # 0.2 is near the optimal on the family set, by F

        title_prefix = "set_id: %(set_id)d, cr_id: %(cr_id)d, nc_score: %(nc_score)g" % locals()
        fname_prefix = date+"_set_id_%(set_id)d_cr_id_%(cr_id)d_nc_score_%(nc_score)g" % locals()
        print title_prefix
        
        clusclass = cluster_sql.hcluster( cluster_run_id = cr_id,
                                          cacheq=True)
        clustering = clusclass.cut_tree( distance = 1-nc_score,
                                         set_id_filter = set_id)
        annotation, seq_map = pq.fetch_domain_map(set_id = set_id)

    # C elegans
    if False:
        set_id = 116
        cr_id = 475      # clustering using the full 600k set.  Correct log calculation
        nc_score = 0.425 # 0.2 is near the optimal on the family set, by F

        title_prefix = "set_id: %(set_id)d, cr_id: %(cr_id)d, nc_score: %(nc_score)g" % locals()
        fname_prefix = date+"_set_id_%(set_id)d_cr_id_%(cr_id)d_nc_score_%(nc_score)g" % locals()
        print title_prefix
        
        clusclass = cluster_sql.hcluster( cluster_run_id = cr_id,
                                          cacheq=True)
        clustering = clusclass.cut_tree( distance = 1-nc_score,
                                         set_id_filter = set_id)
        annotation, seq_map = pq.fetch_domain_map(set_id = set_id)


    # Drosophila
    if True:
        set_id = 117
        cr_id = 475      # clustering using the full 600k set.  Correct log calculation
        nc_score = 0.425 # 0.2 is near the optimal on the family set, by F

        title_prefix = "set_id: %(set_id)d, cr_id: %(cr_id)d, nc_score: %(nc_score)g" % locals()
        fname_prefix = date+"_set_id_%(set_id)d_cr_id_%(cr_id)d_nc_score_%(nc_score)g" % locals()
        print title_prefix
        
        clusclass = cluster_sql.hcluster( cluster_run_id = cr_id,
                                          cacheq=True)
        clustering = clusclass.cut_tree( distance = 1-nc_score,
                                         set_id_filter = set_id)
        annotation, seq_map = pq.fetch_domain_map(set_id = set_id)



    ########
    # Build the factor object from the clustering and annotation selected
    # above, then run the PCA that the plotting sections below rely on.
    fact_inst = mi_factor(clustering,
                          annotation,
                          cluster_run_id=cr_id,
                          clustering_type=clustering_type)

    # Manual toggle: hand the listed factor names to prune_factors() before
    # the PCA.  The trailing number pairs are the author's own notes.
    if False:
        factors_to_prune = [
            'med_clmax_less_mi',    # -0.28, 0.1
            'mean_seq_length',      # 0.05, 0
            #'frac_edges_bit_score',
            'cluster_max',          # 0.05, 0.22
            'med_clmax_less_psmi',  # 0.05, 0.22
            'num_domains',          # 0.05, 0.16 Must include for pca hack
            'mean_num_domains',     # 0.1, -0.02

            'med_domain_mi_psmi',   # 0.28, -0.08
            #'med_domain_mi',       # 0.28, 0
            'max_domain_mi',        # 0.28, 0

            # relevant only to an actual clustering
            #'density_bit_score',   # 0.05, -0.12
            #'density',             # -0.07, -0.25  #Must include for pca hack
            #'mean_score',
        ]
        fact_inst.prune_factors(factors_to_prune)

    fact_inst.calculate_pca()

    #################################
    # Plotting
    #################################
    # Manual toggle: PCA scatter plots with family_categories passed as the
    # cluster categories -- one full view with a legend, followed by three
    # zoomed views without one.
    if False:
        comp_x = 0
        comp_y = 1
        lfname_prefix = fname_prefix + "_pca_data_scatter_families_%d_%d" % (comp_x, comp_y)

        # Full view on the larger figure.
        pyplot.rc('figure', figsize=(6,5))
        fact_inst.plot_pca_data_scatter(
            comp_x=comp_x,
            comp_y=comp_y,
            cluster_categories=family_categories,
            subplots_adjust={'bottom': 0.08, 'top': 0.99, 'right': 0.99, 'left': 0.09},
            draw_title=False,
            dpi=600,
            filename=lfname_prefix + ".png",
            legend_loc='lower right')

        # Zoomed views: same small figure and margins, differing only in the
        # axis window and the filename suffix.
        pyplot.rc('figure', figsize=(3.2,3.2))
        zoom_windows = (
            ("_zoom0.png", (-1.1, 5.5), (-5.5, 6.5)),
            ("_zoom1.png", (-1.1, 1.75), (-2.6, 3.25)),
            ("_zoom2.png", (-0.86, 0), (-0.5, 0.65)),
        )
        for zoom_suffix, zoom_xlim, zoom_ylim in zoom_windows:
            fact_inst.plot_pca_data_scatter(
                comp_x=comp_x,
                comp_y=comp_y,
                cluster_categories=family_categories,
                # a fresh margins dict per call, as in the unrolled version
                subplots_adjust={'bottom': 0.11, 'top': 0.99, 'right': 0.98, 'left': 0.17},
                draw_title=False,
                dpi=600,
                xlim=zoom_xlim,
                ylim=zoom_ylim,
                filename=lfname_prefix + zoom_suffix,
                draw_legend=False)
    
    # Manual toggle: for each listed factor, PCA scatter plots with that
    # factor passed as color_by_factor -- one full view, then three zoomed
    # views drawn without a colorbar.
    if False:
        comp_x = 0
        comp_y = 1

        color_factors = ('cluster_size',
                         'num_domains',
                         'mean_num_domains',
                         'med_clmax_less_mi',
                         'med_clmax_less_psmi',
                         'med_domain_psmi')
        zoom_windows = (
            ("_zoom0.png", (-1.1, 5.5), (-5.5, 6.5)),
            ("_zoom1.png", (-1.1, 1.75), (-2.6, 3.25)),
            ("_zoom2.png", (-0.86, 0), (-0.5, 0.65)),
        )

        for fact in color_factors:
            lfname_prefix = fname_prefix + "_pca_data_scatter_%d_%d_%s" % (comp_x, comp_y, fact)

            # Full view on the larger figure.
            pyplot.rc('figure', figsize=(6,5))
            fact_inst.plot_pca_data_scatter(
                comp_x=comp_x,
                comp_y=comp_y,
                color_by_factor=fact,
                subplots_adjust={'bottom': 0.08, 'top': 0.98, 'right': 0.93, 'left': 0.09},
                draw_title=False,
                dpi=600,
                filename=lfname_prefix + ".png",
                draw_legend=False)

            # Zoomed views on the small figure; only the axis window and the
            # filename suffix vary.
            pyplot.rc('figure', figsize=(3.2,3.2))
            for zoom_suffix, zoom_xlim, zoom_ylim in zoom_windows:
                fact_inst.plot_pca_data_scatter(
                    comp_x=comp_x,
                    comp_y=comp_y,
                    color_by_factor=fact,
                    # a fresh margins dict per call, as in the unrolled version
                    subplots_adjust={'bottom': 0.11, 'top': 0.99, 'right': 0.95, 'left': 0.17},
                    draw_title=False,
                    dpi=600,
                    xlim=zoom_xlim,
                    ylim=zoom_ylim,
                    filename=lfname_prefix + zoom_suffix,
                    draw_colorbar=False)



    # Manual toggle: scree plot of the PCA components, with the sample count
    # and sample fraction below passed through to the plotting method.
    if False:
        nsamples = 1000
        fraction = 0.75

        pyplot.rc('figure', figsize=(6,3))

        scree_filename = fname_prefix + "_pca_component_scree.pdf"
        scree_margins = {'bottom': 0.12, 'top': 0.99, 'right': 0.99, 'left': 0.07}
        fact_inst.plot_pca_component_scree(
            title_prefix=title_prefix,
            dpi=150,
            filename=scree_filename,
            nsamples=nsamples,
            sample_fraction=fraction,
            draw_title=False,
            subplots_adjust=scree_margins)

    # Enabled: factor-stability plots for component pairs (0,1), (0,2) and
    # (1,2).  Axis limits are hand-tuned for two known (cr_id, set_id) runs
    # and left as None (no explicit limit passed) for anything else.
    if True:
        nsamples = 1000
        fraction = 0.75

        on_set_109 = cr_id == 475 and set_id == 109
        on_set_117 = cr_id == 475 and set_id == 117

        # Components 0 vs 1 on the larger figure.
        if on_set_109:
            main_xlim, main_ylim = (-0.3, 0.5), (-0.25, 0.3)
        elif on_set_117:
            main_xlim, main_ylim = (-0.1, 0.6), (-0.5, 0.2)
        else:
            main_xlim, main_ylim = None, None

        pyplot.rc('figure', figsize=(6, 5))
        Wt_array, frac_array = fact_inst.plot_pca_factor_stability(
            nsamples, fraction,
            comp_x=0,
            comp_y=1,
            title_prefix=title_prefix,
            dpi=100,
            filename=fname_prefix + "_pca_factor_stability_%d_%f_%d_%d.pdf" % (
                nsamples, fraction, 0, 1),
            xlim=main_xlim,
            ylim=main_ylim,
            subplots_adjust={'bottom': 0.09, 'top': 0.985, 'right': 0.98, 'left': 0.1},
            draw_title=False)

        # Remaining component pairs on the small figure.
        pyplot.rc('figure', figsize=(3.2,3.2))
        for comp_x, comp_y in ((0,2), (1,2)):
            if on_set_109:
                pair_xlim = (-0.3, 0.65)
            elif on_set_117 and comp_x == 0:
                pair_xlim = (-0.1, 0.6)
            else:
                pair_xlim = None
            # Only the set 109 run has a tuned y window for these pairs.
            if on_set_109:
                pair_ylim = (-0.15, 0.2)
            else:
                pair_ylim = None

            Wt_array, frac_array = fact_inst.plot_pca_factor_stability(
                nsamples, fraction,
                comp_x=comp_x,
                comp_y=comp_y,
                title_prefix=title_prefix,
                dpi=100,
                filename=fname_prefix + "_pca_factor_stability_%d_%f_%d_%d.pdf" % (
                    nsamples, fraction, comp_x, comp_y),
                draw_title=False,
                subplots_adjust={'bottom': 0.11, 'top': 0.98, 'right': 0.91, 'left': 0.19},
                xlim=pair_xlim,
                ylim=pair_ylim)

    # Manual toggle: PCA data heatmaps for three component pairs, each with
    # its own figure size.
    if False:
        heatmap_views = ((0, 1, (5.75, 5)),
                         (0, 2, (3.2, 2.6)),
                         (1, 2, (3.2, 2.6)))

        for comp_x, comp_y, figsize in heatmap_views:
            # The 5.75-wide figure uses a smaller bottom margin than the
            # others; the remaining margins are shared.
            bottom_margin = 0.08 if figsize[0] == 5.75 else 0.15
            subplots_adjust = {'bottom': bottom_margin, 'top': 0.99, 'right': 0.95, 'left': 0.05}

            pyplot.rc('figure', figsize=figsize)
            xy = fact_inst.plot_pca_data_heatmap(
                comp_x=comp_x,
                comp_y=comp_y,
                title_prefix=title_prefix,
                dpi=150,
                filename=fname_prefix + "_pca_heatmap_%d_%d.pdf" % (comp_x, comp_y),
                hres=100,
                vres=100,
                draw_title=False,
                subplots_adjust=subplots_adjust)

            # Disabled companion scatter plot for the same component pair:
            #fact_inst.plot_pca_data_scatter( 
            #    comp_x = comp_x,
            #    comp_y = comp_y,
            #    title_prefix = title_prefix,
            #    dpi = 150,
            #    filename = fname_prefix + "_pca_data_scatter_%d_%d.pdf" % (comp_x, comp_y),
            #    draw_title=False,
            #    subplots_adjust={'bottom': 0.11, 'top': 0.99, 'right': 0.91, 'left':0.17},
            #    )
                
    # Manual toggle: heatmap and scatter plot for every unordered pair of
    # factors in fact_inst.fact_index.
    if False:
        pyplot.rc('figure', figsize=(9,8))

        # Hand-picked pairs used previously, kept for reference:
        #for fact_x, fact_y in (('cluster_max','num_domains'),
        #                       ('cluster_max','mean_num_domains'),
        #                       ('cluster_max', 'max_domain_mi'),
        #                       ('max_domain_mi', 'med_clmax_less_mi'),
        #                       ('max_domain_mi', 'density'),
        #                       ('density', 'density_bit_score'),
        #                       ('density', 'cluster_max'),
        #                       ('density', 'mean_seq_length'),
        #                       ('density', 'mean_num_domains'),
        #                       ('density_bit_score', 'mean_seq_length'),
        #                       ):
        factors = fact_inst.fact_index.keys()

        # Each factor paired with every factor after it in the key list, in
        # the same order the nested loops would visit them.
        factor_pairs = [(fact_x, fact_y)
                        for i, fact_x in enumerate(factors)
                        for fact_y in factors[i+1:]]

        for fact_x, fact_y in factor_pairs:
            pair_names = (fact_x, fact_y)

            X, Y, xy = fact_inst.plot_factor_heatmap(
                fact_x=fact_x,
                fact_y=fact_y,
                title_prefix=title_prefix,
                dpi=100,
                filename=fname_prefix + "_factor_heatmap_%s_%s.png" % pair_names,
                hres=100,
                vres=100)

            fact_inst.plot_factor_scatter(
                fact_x=fact_x,
                fact_y=fact_y,
                title_prefix=title_prefix,
                dpi=100,
                filename=fname_prefix + "_factor_scatter_%s_%s.png" % pair_names)


    # Manual toggle: per-cluster MI scatter plots, with clusters ordered by
    # their projection onto a PCA component.
    if False:
        pyplot.rc('figure', figsize=(6.5,4))
        #pyplot.rc('text', usetex=False)  # breaks for some reason

        # Overview plot per component: clusters are listed in argsort order
        # of that component's column of fact_inst.projection.
        for comp in (1,2):
            order = fact_inst.projection[:,comp].argsort()
            clusters = [fact_inst.cl_index_inv[i] for i in order]
            f = fact_inst.scatter_cluster_mi_specific(
                cluster_list=clusters,
                dpi=300,               # png
                grid=False,
                xticks=False,
                pointsize=0.5,
                xlimit=(-100, 10100),  # fit the points well inside the axis frame
                filename=fname_prefix + "_scatter_cluster_mi_specific_pca%d.png" % comp)

        # NOTE(review): from here on `clusters` is whatever the last loop
        # iteration left behind (comp == 2), yet every filename below says
        # "pca0" -- confirm which component these detail plots should show.
        if False:
            pyplot.rc('figure', figsize=(6.5,4))
            #pyplot.rc('text', usetex=False)  # breaks for some reason

            fact_inst.scatter_cluster_mi_specific(
                cluster_list=clusters[:5300],
                dpi=300,             # png
                grid=True,
                xticks=False,
                pointsize=0.5,
                xlimit=(-20, 5320),  # fit the points well inside the axis frame
                filename=fname_prefix + "_scatter_cluster_mi_specific_pca0_0_5300.png")

        # Detail plots: (window into `clusters`, filename suffix, plot
        # options).  Each xlimit is chosen to fit the points well inside the
        # axis frame.  The last three entries pass no xticks argument.
        # NOTE(review): the last two windows are 50 clusters wide while their
        # filenames and xlimit describe 45 -- confirm the intended range.
        detail_plots = (
            (slice(None, 1000), "_scatter_cluster_mi_specific_pca0_0_1000.png",
             dict(dpi=300, grid=False, xticks=True, pointsize=0.5, xlimit=(-10, 1000))),
            (slice(1000, 2000), "_scatter_cluster_mi_specific_pca0_1000_2000.png",
             dict(dpi=300, grid=True, xticks=False, pointsize=0.5, xlimit=(-10, 1000))),
            (slice(1500, 1600), "_scatter_cluster_mi_specific_pca0_1500_1600.pdf",
             dict(dpi=150, grid=True, xticks=False, pointsize=5, xlimit=(-1, 101))),
            (slice(1600, 1700), "_scatter_cluster_mi_specific_pca0_1600_1700.pdf",
             dict(dpi=150, grid=True, xticks=False, pointsize=10, xlimit=(-1, 101))),
            (slice(5000, 5300), "_scatter_cluster_mi_specific_pca0_5000_5300.png",
             dict(dpi=300, grid=True, xticks=False, pointsize=10, xlimit=(-5, 305))),
            (slice(5225, 5300), "_scatter_cluster_mi_specific_pca0_5225_5300.png",
             dict(dpi=300, grid=True, xticks=False, pointsize=10, xlimit=(-1, 76))),
            (slice(None, 45), "_scatter_cluster_mi_specific_pca0_0_44.pdf",
             dict(dpi=150, grid=True, pointsize=10, xlimit=(-0.5, 44.5))),
            (slice(4900, 4950), "_scatter_cluster_mi_specific_pca0_4900_4944.pdf",
             dict(dpi=150, grid=True, pointsize=10, xlimit=(-0.5, 44.5))),
            (slice(5200, 5250), "_scatter_cluster_mi_specific_pca0_5200_5244.pdf",
             dict(dpi=150, grid=True, pointsize=10, xlimit=(-0.5, 44.5))),
        )
        for window, fname_suffix, plot_opts in detail_plots:
            fact_inst.scatter_cluster_mi_specific(
                cluster_list=clusters[window],
                filename=fname_prefix + fname_suffix,
                **plot_opts)

# projection = fact_inst.pca.project(fact_inst.cluster_fact)
# (projection[:,0] > 60).nonzero()[0][0]
# fact_inst.cl_index_inv[6826]

# for cluster in fact_inst.constrain_clusters( [(0,None,0),(1,2,None)]):
#     print fact_inst.print_cluster_data(cluster)

# order = fact_inst.projection[:,0].argsort()
# clusters = [fact_inst.cl_index_inv[i] for i in order]
# fact_inst.scatter_cluster_mi_specific(cluster_list=clusters)

