#!/usr/bin/env python
# Jacob Joseph

from DurandDB import seqq
from JJutil import pickler, rate
#from scipy import sparse
import math, itertools, time, numpy

class blastq(seqq.seqq):
    """Class to access sequence sets and pairwise scores (e.g., BLAST,
    NC) from DurandLab2"""
    
    def __init__(self, **kwargs):
        """Initialize the object through seqq.seqq.__init__, forwarding
        every keyword argument unchanged, then give this instance an
        empty memory_cache dictionary."""
        base_class = seqq.seqq
        base_class.__init__(self, **kwargs)

        # per-instance cache dictionary, starts out empty
        self.memory_cache = dict()

    def fetch_blast_info( self, br_id):
        """Return (set_id, num_sequences, num_residues) for a blast run.

        br_id -- identifier of the blast_run row to read."""
        sql = """SELECT set_id, num_sequences, num_residues
        FROM blast_run
        WHERE br_id = %(br_id)s"""
        query_params = {'br_id': br_id}
        return self.dbw.fetchone(sql, query_params)

    def fetch_blast_info_d( self, br_id):
        """Return the blast_run row for br_id as a dictionary.

        The 'params' column holds Python source text.  It is evaluated,
        with the row's own columns available as local names, and the
        resulting object replaces the text before the row is returned."""
        sql = """SELECT * FROM blast_run
        WHERE br_id = %(br_id)s"""
        row = self.dbw.fetchone_d(sql, {'br_id': br_id})

        # NOTE(review): eval() executes whatever code is stored in
        # blast_run.params; this is only safe while that column is trusted.
        params_source = row['params']
        row['params'] = eval(params_source, None, row)

        return row

    def fetch_blast_params( self, br_id):
        """Return the evaluated 'params' field of the blast_run row for
        br_id (as produced by fetch_blast_info_d)."""
        return self.fetch_blast_info_d( br_id)['params']

    def fetch_nc_info_d( self, nc_id):
        """Return the nc_run row for nc_id as a dictionary (the result
        of self.dbw.fetchone_d)."""
        sql = """SELECT * FROM nc_run
        WHERE nc_id=%(nc_id)s"""
        return self.dbw.fetchone_d(sql, dict(nc_id=nc_id))
        
    def fetch_db_size( self, br_id):
        """Return num_sequences for a blast run (read from its
        blast_run row)."""
        run_info = self.fetch_blast_info_d(br_id)
        return run_info['num_sequences']

    def fetch_max_bit_score(self, br_id=None,
                            symmetric=False,
                            query_seq_id=None,
                            seq_id_0_set_id=None,
                            blast_hit_rank=0,
                            ):
        """Fetch the maximum bit score.  Useful when a distance needs
        to be created from the (max(bit) - bit).  This function does
        not filter by target set_id, and so could be an overestimate
        of the maximum bit_score.  It can never be smaller, however.

        Only the first element (bit_score[1]) of each row's score
        array is examined.

        br_id -- blast run to search.
        symmetric -- read blast_hit_symmetric_arr instead of
            blast_hit_arr; hit_rank is then not filtered.
        query_seq_id -- optional single seq_id, or a list, tuple, set,
            frozenset or numpy array of seq_ids, restricting the query
            sequences considered.
        seq_id_0_set_id -- optional set_id; only queries that are
            members of this set are considered.
        blast_hit_rank -- hit_rank used for non-symmetric runs.
        """
        blast_table = "blast_hit_symmetric_arr" if symmetric else "blast_hit_arr"

        # The table name is written into the SQL text directly because
        # psycopg2 would quote it if it were passed as a parameter.
        q_from = "\nFROM %s" % blast_table
        q_where = "\nWHERE br_id=%(br_id)s"

        if not symmetric:
            q_where += "\nAND hit_rank=%(blast_hit_rank)s"

        if seq_id_0_set_id is not None:
            q_from += "\nJOIN prot_seq_set_member AS ps0 ON (ps0.seq_id = seq_id_0)"
            q_where += "\nAND ps0.set_id=%(seq_id_0_set_id)s"

        if query_seq_id is not None:
            q_where += "\nAND seq_id_0 IN %(query_seq_id)s"

            # "IN %(query_seq_id)s" needs one flat tuple of plain Python
            # values; an incoming tuple must not be wrapped a second time,
            # and numpy scalars cannot be adapted by psycopg2.
            if isinstance(query_seq_id, numpy.ndarray):
                query_seq_id = query_seq_id.tolist()
            if isinstance(query_seq_id, (list, tuple, set, frozenset)):
                query_seq_id = tuple(query_seq_id)
            else:
                query_seq_id = (query_seq_id,)

        q = "SELECT max(bit_score[1])" + q_from + q_where

        params = {'br_id': br_id,
                  'blast_hit_rank': blast_hit_rank,
                  'seq_id_0_set_id': seq_id_0_set_id,
                  'query_seq_id': query_seq_id}
        return self.dbw.fetchsingle(q, params)
        

    def fetch_score_lists(self, stype=None, br_id=None, nc_id=None,
                          symmetric=False, query_seq_id=None,
                          field_list=None, blast_hit_rank=0,
                          seq_id_0_set_id=None,
                          batch_size=None):
        """Fetch the raw score hit and score lists from the database
        for blast (blast_hit_arr), symmetric blast
        (blast_hit_symmetric_arr), or NC (blast_hit_nc_arr).  Returns
        an iterator of rows: (seq_id_0, hit_list, <field_list>).  If
        field_list is not defined, it will contain the specified stype
        (e.g., bit_score).

        stype -- 'bit_score' or 'e_value' (requires br_id), or
            'nc_score' (requires nc_id).  Symmetric e_values are not
            stored in the database and are rejected.
        query_seq_id -- if given, return results only for this query
            sequence id.  It may also be a list, tuple, set, frozenset
            or numpy array of query seq_ids.
        field_list -- a single column name or a sequence of column
            names.  It must contain only columns in the respective
            database table.
        blast_hit_rank -- hit_rank used for non-symmetric blast runs.
        seq_id_0_set_id -- if given, only queries that are members of
            this set (prot_seq_set_member.set_id) are returned.
        batch_size -- chunk size passed to self.dbw.fetchall_iter
            (default 5000 rows).

        Raises AssertionError for a missing or unknown stype, a missing
        br_id / nc_id, or a symmetric e_value request.
        """

        assert stype is not None, "'stype' must be specified."

        assert not (stype=='e_value' and symmetric), "Symmetric e_values are not stored in the database"

        if field_list is None:
            field_list = (stype,)
        elif isinstance(field_list, str):
            field_list = (field_list,)
        else:
            field_list = tuple( field_list)

        field_list = ('seq_id_0', 'hit_list') + field_list

        # Table and column names are written into the SQL text directly
        # because psycopg2 would quote them if passed as parameters.
        q_base = "SELECT %s" % ", ".join(field_list)

        if stype in ('bit_score', 'e_value'):
            assert br_id is not None, "br_id must be specified for stype: %s" % stype

            if symmetric: blast_table = "blast_hit_symmetric_arr"
            else: blast_table = "blast_hit_arr"

            q_from = "\nFROM %s" % blast_table
            q_where = "\nWHERE br_id=%(br_id)s"

            if not symmetric:
                q_where += "\nAND hit_rank=%(blast_hit_rank)s"

        elif stype == 'nc_score':
            assert nc_id is not None, "nc_id must be specified for stype: %s" % stype

            q_from = "\nFROM blast_hit_nc_arr"
            q_where = "\nWHERE nc_id=%(nc_id)s"

        else:
            # Raise explicitly so this check survives ``python -O``;
            # otherwise q_from/q_where would be unbound below.
            raise AssertionError("Unknown stype: %s" % stype)

        if seq_id_0_set_id is not None:
            q_from += "\nJOIN prot_seq_set_member AS ps0 ON (ps0.seq_id = seq_id_0)"
            q_where += "\nAND ps0.set_id=%(seq_id_0_set_id)s"

        if query_seq_id is not None:
            q_where += "\nAND seq_id_0 IN %(query_seq_id)s"

            # "IN %(query_seq_id)s" needs one flat tuple of plain Python
            # values; an incoming tuple must not be wrapped a second time,
            # and numpy scalars cannot be adapted by psycopg2.
            if isinstance(query_seq_id, numpy.ndarray):
                query_seq_id = query_seq_id.tolist()
            if isinstance(query_seq_id, (list, tuple, set, frozenset)):
                query_seq_id = tuple(query_seq_id)
            else:
                query_seq_id = (query_seq_id,)

        q = q_base + q_from + q_where

        params = {'br_id': br_id,
                  'nc_id': nc_id,
                  'blast_hit_rank': blast_hit_rank,
                  'seq_id_0_set_id': seq_id_0_set_id,
                  'query_seq_id': query_seq_id}

        # query in 5000-row chunks unless told otherwise
        if batch_size is None: batch_size = 5000
        return self.dbw.fetchall_iter(q, params, batch_size=batch_size)


    def fetch_hits_direct( self, stype=None, br_id=None, nc_id=None,
                           field_list=None, self_hits=True,
                           query_seq_id=None, symmetric=False,
                           thresh=None, seq_id_0_set_id=None,
                           correct_e_value=False, blast_hit_rank=0,
                           set_id_filter=None):
        """Fetch all scores for a blast (blast_hit_arr, hit_rank
        blast_hit_rank), symmetric blast (blast_hit_symmetric_arr), or
        NC (blast_hit_nc_arr) run.  Returns a generator of (seq_id_0,
        seq_id_1, ...field list... ) tuples.  Being a generator, no
        work is done until the first item is requested.

        field_list may be a single column name or a sequence of column
        names; it defaults to one column of stype.

        thresh may be used to select records with either
        'bit_score'>=thresh, 'nc_score'>=thresh, or 'e_value'<=thresh,
        as these are stored in sorted order in the database.  When
        thresholding, stype must be in field_list.

        If query_seq_id is given, return results only for this query
        sequence id.  query_seq_id may also be a list of seq_ids.

        The returned sequences may be filtered by set_id.
        'seq_id_0_set_id' returns hits only from query sequences of a
        set.  'set_id_filter' returns hits only where both sequences
        are members of a set; it takes precedence over
        seq_id_0_set_id.

        If self_hits is false, hits with seq_id_0 == seq_id_1 are
        skipped.

        If correct_e_value is true and e_value is in field_list, then
        e_value is replaced by -log10(e_value / dbsize + 1E-323), where
        dbsize is the run's num_sequences.  Thresholds are applied
        *before* this transformation.  e_value may appear at any
        position in field_list."""

        if field_list is None:
            field_list = (stype,)
        elif isinstance(field_list, str):
            # a bare column name; don't treat it as a sequence of characters
            field_list = (field_list,)
        else:
            field_list = tuple(field_list)

        # Each hit tuple is (seq_id_1, <field_list...>), so the column at
        # field_list[i] lives at hit_tup[i + 1].

        # FIXME: use a database procedure for this filtering
        thresh_break_cond = None
        if thresh is not None:
            m = "If thresholding, stype must be in the field_list."
            assert stype in field_list, m

            stype_ind = field_list.index(stype) + 1

            # Scores are stored best-first, so the first hit past the
            # threshold ends the row.  Closures compare against the
            # caller's thresh object directly (no eval of its repr).
            if stype == 'e_value':
                thresh_break_cond = lambda a: a[stype_ind] > thresh
            else:
                thresh_break_cond = lambda a: a[stype_ind] < thresh

        filter_set = None
        if set_id_filter is not None:
            filter_set = set(self.fetch_seq_set(set_id=set_id_filter))

            # filter the query side in the database using the same set_id
            seq_id_0_set_id = set_id_filter

        tup_transform = None
        if 'e_value' in field_list and correct_e_value:
            dbsize = int(self.fetch_db_size( br_id=br_id))
            e_index = field_list.index('e_value') + 1

            def tup_transform(a):
                # replace the e_value entry by its corrected -log10 form
                corrected = -math.log10(a[e_index] / dbsize + 1E-323)
                return a[:e_index] + (corrected,) + a[e_index + 1:]

        row_iter = self.fetch_score_lists( br_id=br_id, nc_id=nc_id,
                                           stype=stype,
                                           symmetric=symmetric,
                                           query_seq_id=query_seq_id,
                                           field_list=field_list,
                                           blast_hit_rank=blast_hit_rank,
                                           seq_id_0_set_id=seq_id_0_set_id)
        for row in row_iter:
            seq_id_0 = row[0]

            # row[1:] = (hit_list, bit_score, ...)  parallel arrays
            for hit_tup in itertools.izip( *row[1:] ):
                # hit_tup = (seq_id_1, bit_score, ...)
                seq_id_1 = hit_tup[0]

                if not self_hits and seq_id_0 == seq_id_1:
                    continue

                if filter_set is not None and seq_id_1 not in filter_set:
                    continue

                # break when we pass the threshold in our score list
                if thresh_break_cond is not None and thresh_break_cond(hit_tup):
                    break

                # transform the tuple, e.g., for e_value correction
                if tup_transform is not None:
                    hit_tup = tup_transform(hit_tup)

                yield (seq_id_0,) + hit_tup

    fetch_hits_iter = fetch_hits_direct


    def fetch_hits( self, stype=None, br_id=None, nc_id=None,
                    self_hits=True, query_seq_id=None,
                    symmetric=False, thresh=None,
                    seq_id_0_set_id=None, correct_e_value=False,
                    set_id_filter = None,
                    blast_hit_rank=0, score_fn=None):
        """Return a dictionary of similarity scores of the form
{seq_id_0: {seq_id_1: score}}.

This function defaults to symmetric scores.

If specified, score_fn should be a function to transform scores.

See fetch_hits_direct for parameter details."""
        
        # fetch from the pickle cache if possible
        if self.cacheq:
            retval = pickler.cachefn(pickledir=self.pickledir)
            if retval: return retval

        if score_fn is None:
            score_fn = lambda a: a

        row_iter = self.fetch_hits_direct( stype=stype, br_id=br_id,
                                  nc_id=nc_id, self_hits=self_hits,
                                  query_seq_id=query_seq_id,
                                  symmetric=symmetric, thresh=thresh,
                                  seq_id_0_set_id=seq_id_0_set_id,
                                  correct_e_value=correct_e_value,
                                  set_id_filter=set_id_filter,
                                  blast_hit_rank=blast_hit_rank)

        blasthits = {}
        seq_ids = {}  # keep only one copy of each key in memory
        for (seq_id_0, seq_id_1, score) in row_iter:
            seq_id_0 = int(seq_id_0)
            seq_id_1 = int(seq_id_1)
            if not seq_ids.has_key(seq_id_0): seq_ids[seq_id_0] = seq_id_0
            else: seq_id_0 = seq_ids[seq_id_0]
            if not seq_ids.has_key(seq_id_1): seq_ids[seq_id_1] = seq_id_1
            else: seq_id_1 = seq_ids[seq_id_1]
            
            if not blasthits.has_key( seq_id_0):
                blasthits[seq_id_0] = {}

            # If symmetric, check the other direction, and use that
            # value if it exists already
            if symmetric and blasthits.get(seq_id_1,{}).has_key(seq_id_0):
                blasthits[seq_id_0][seq_id_1] = blasthits[seq_id_1][seq_id_0]
            else:
                blasthits[seq_id_0][seq_id_1] = score_fn(score)

        if self.cacheq: pickler.cachefn( retval = blasthits,
                                         pickledir=self.pickledir)
        return blasthits


    def fetch_hits_dictarray(self, stype=None, br_id=None, nc_id=None,
                             symmetric=False, query_seq_id=None,
                             blast_hit_rank=0, seq_id_0_set_id=None,
                             set_id_filter=None, seq_set=None,
                             self_hits=True):
        """Fetch a dict of numpy arrays, representing the raw score
        lists in the database).  Returns {'seq_id': (hit_list,
        score_list), ...}.

        hit_list and score list are sorted in order of seq_id in hit_list

        'query_seq_id', 'seq_id_0_set_id', 'set_id_filter', and
        'seq_set' may be used to filter the query results returned:
        
          * query_seq_id and seq_id_0_set_id affect only the source
            query, specifying a single query seq_id, or a database
            set_id, to fetch all queries in that set.

          * set_id_filter and seq_set affect both the query and target
            sequences; namely, both query and target must exist within
            the set specified.  set_id_filter must be a database set_id,
            and seq_set a Python set() of seq_ids.

          * If more than one of the above options is specified, their
            intersection will be used in filtering results returned.
            Explicitly using filters on the query sequence will occur
            in the database, and thus be faster. 
        
        See fetch_score_lists for other parameter descriptions.
        """

        if self.cacheq:
            retval = pickler.cachefn(pickledir=self.pickledir)
            if retval: return retval

        hit_filter = False
        hit_set = None
        if seq_set is not None:
            hit_set = seq_set if type(seq_set) is set else set(seq_set)
            hit_filter = True

        if set_id_filter is not None:
            set_id_seqs = set(self.fetch_seq_set(set_id=set_id_filter))            
            
            if hit_set is None:
                hit_filter = True
                hit_set = set_id_seqs
            else:
                # it is faster to do single queries if very few will be
                # performed.  otherwise, one big transfer is likely faster
                #if len(hit_set) < len(set_id_seqs):
                #    do_single_queries = True
                
                hit_set.intersection_update( set_id_seqs)

        if hit_filter:
            hit_set_arr = numpy.array( list(hit_set))

        hits = {}

        qrate = rate.rate()

        for row in self.fetch_score_lists(
            stype=stype, br_id=br_id, nc_id=nc_id,
            symmetric=symmetric, query_seq_id=query_seq_id,
            blast_hit_rank=blast_hit_rank, seq_id_0_set_id=seq_id_0_set_id):

            assert row is not None, "Should never happen?"

            (seq_id_0, hit_list, score_list) = row

            assert hit_list is not None, "seq_id %s: Hit list is None.  This is almost certainly an error in the database." % (
                seq_id_0,)
                
            assert score_list is not None, """seq_id %s: Score list is None.  This may be an error in the database,
or, more likely, be because e_values are not stored for symmetric
scores.""" % (seq_id_0,)

            assert len(hit_list) == len(score_list), """seq_id %s: Hit and score lists are of different lengths.
len(hit_list): %s, len(score_list)=%s""" % (
                seq_id_0, len(hit_list), len(score_list))

            qrate.increment()
            if qrate.count % 1000 == 0: print qrate            

            #if hit_filter and seq_id_0 not in hit_set: continue

            if True:
                # Using arrays from the start is rather (30%) faster
                hit_arr = numpy.array(hit_list, dtype=numpy.int32)
                score_arr = numpy.array(score_list, dtype=numpy.float64)

                # remove elements not in the sequence set
                if hit_filter:
                    indicies = numpy.in1d(hit_arr, hit_set_arr, assume_unique=True)

                    hit_arr = hit_arr[ indicies]
                    score_arr = score_arr[ indicies]

                # sort by seq_id in hit_arr
                # this will fail for symmetric e_value, since it's not stored in the database
                order = numpy.argsort( hit_arr)
                hit_arr = hit_arr[ order]
                score_arr = score_arr[ order]

                # remove self hit (this is faster after sorting)
                # FIXME: combine with hit_filter?
                if not self_hits:
                    ind = numpy.searchsorted( hit_arr, seq_id_0)
                    if ind < len(hit_arr) and hit_arr[ind] == seq_id_0:
                        # eliminate this element by shifting left
                        hit_arr[ind:-1] = hit_arr[ind+1:]
                        hit_arr.resize( [len(hit_arr)-1], refcheck=False )
                    
                        score_arr[ind:-1] = score_arr[ind+1:]
                        score_arr.resize( [len(score_arr)-1])
            
                hits[seq_id_0] = hit_arr, score_arr

            else:
                hit_zip = zip(hit_list, score_list)

                # remove elements not in the sequence set
                if hit_filter:
                    hit_zip = [ a for a in hit_zip if a[0] in hit_set]

                if not self_hits:
                    hit_zip = [ a for a in hit_zip if a[0] != seq_id_0]

                # sort by seq_id in hit_list, unzip
                hit_zip.sort()
                hit_list, score_list = zip(*hit_zip)
            
                hit_list = numpy.array(hit_list, dtype=numpy.int32)
                score_list = numpy.array(score_list)
            
                hits[seq_id_0] = hit_list, score_list
            

        if self.cacheq: pickler.cachefn( retval = hits,
                                         pickledir=self.pickledir)
        return hits



    def fetch_blast_nc_hits_iter( self, br_id, nc_id, stype='bit_score',
                                  thresh_s=None, thresh_nc=None, symmetric=True,
                                  self_hits=True):    
        """Return an iterator of (seq_id_0, seq_id_1, blast score,
nc_score) tuples.  The blast_hit_symmetric and blast_hit_nc tables are
queried.  Either 'bit_score' or 'e_value' may be specified for the blast score
type, stype.  Thresholds for blast or NC may be specified with the
thresh_s and thresh_nc parameters.
                      
This function streams from the server and must run to completion
before other queries may be executed."""

        assert False, "Not yet updated for postgres"

        q_base = """SELECT seq_id_0, seq_id_1, %s, nc_score """

        if symmetric: q_base += "FROM blast_hit_symmetric "
        else: q_base += "FROM blast_hit "

        q_base += """JOIN blast_hit_nc using (seq_id_0, seq_id_1)
        WHERE br_id = '%d'
        AND nc_id = '%d' """

        if not symmetric: q_base += " AND hit_rank=0"

        if stype not in ('bit_score', 'e_value'):
            print "fetch_blast_nc_hits_iter: unknown stype:", stype

        q = q_base % (stype, br_id, nc_id)

        if thresh_s:
            if stype=='bit_score': q += """AND bit_score >= '%s'""" % thresh_s
            elif stype=='e_value': q += """AND e_value <= '%s'""" % thresh_s
        if thresh_nc: q += """ AND nc_score >= '%s'""" % thresh_nc

        if not self_hits:
            q += " AND seq_id_0 != seq_id_1"
            
        self.dbw.ping()
        self.dbw.sscurs.execute(q)

        arraysize = 100000
        while True:
            results = self.dbw.sscurs.fetchmany(arraysize)
            if not results: break
            for (n1,n2,b_score,nc_score) in results:
                # convert node keys to ints
                yield (int(n1),int(n2), b_score, nc_score)


    def fetch_score_histogram_sql(self, br_id=None, nc_id=None, stype=None,
                              self_hits=False, symmetric=True,
                              set_id_filter=None):
        """Return a histogram of scores computed by an SQL GROUP BY.

        Scores are binned by taking the ceiling to the nearest
        multiple of 25 for bit_score, and 0.01 for nc_score, and, for
        e_value, by the ceiling of -log10(e_value / db_size + 1E-200).
        Returns the rows from self.dbw.fetchall: one
        (min_score, max_score, count) row per bin, ordered by
        min_score.  For e_value the reported scores are the
        transformed -log10 values.

        bit_score and e_value read blast_hit_symmetric (or blast_hit
        with hit_rank 0 when symmetric is false) and require br_id;
        nc_score reads blast_hit_nc and requires nc_id.  If
        set_id_filter is given, both sequences of each pair must be
        members of that set.  Self hits are excluded unless self_hits
        is true.

        Raises AssertionError for an unknown stype or a missing
        br_id / nc_id.
        """

        # This should work, but could be slow

        q_params = {}

        if stype in ('bit_score', 'e_value') and br_id is not None:
            q_params['where_cond'] = "br_id='%d'" % br_id
            if symmetric:
                q_params['table'] = 'blast_hit_symmetric'
            else:
                q_params['table'] = 'blast_hit'
                q_params['where_cond'] += "\n AND hit_rank=0"

            if stype=='bit_score':
                q_params['round_agg'] = "CEILING(bit_score/25)*25"
                q_params['stype'] = stype
            else:
                dbsize = self.fetch_db_size(br_id = br_id)
                q_params['round_agg'] = "CEILING( -log10(e_value / %d + 1E-200))" % dbsize
                q_params['stype'] = "-log10(e_value / %d + 1E-200)" % dbsize

        elif stype == 'nc_score' and nc_id is not None:
            q_params['where_cond'] = "nc_id='%d'" % nc_id
            q_params['table'] = 'blast_hit_nc'
            q_params['round_agg'] = "CEILING(nc_score*100)/100"
            q_params['stype'] = stype

        else:
            # Raise explicitly so the check survives ``python -O``
            # (an assert would be stripped, leading to a KeyError below).
            raise AssertionError("Unknown stype, br_id, or nc_id: %s, %s, %s" % (
                stype, br_id, nc_id))

        q_params['join'] = ""
        q_params['join_cond'] = ""
        if set_id_filter is not None:
            q_params['join'] = """ JOIN prot_seq_set_member as ps0 ON (ps0.seq_id = seq_id_0)
            JOIN prot_seq_set_member as ps1 ON (ps1.seq_id = seq_id_1)"""
            q_params['join_cond'] = """AND ps0.set_id=%(set_id)d
            AND ps1.set_id=%(set_id)d""" % {'set_id': set_id_filter}

        if not self_hits:
            q_params['where_cond'] += "\n AND seq_id_0 != seq_id_1"

        q = """SELECT MIN(%(stype)s) AS min_score, MAX(%(stype)s), COUNT(*)
        FROM %(table)s
        %(join)s
        WHERE %(where_cond)s
        %(join_cond)s
        GROUP BY %(round_agg)s
        ORDER BY min_score"""

        return self.dbw.fetchall( q % q_params)


    def fetch_score_histogram_slower(self, stype=None, br_id=None, nc_id=None,
                              self_hits=False, symmetric=False,
                              set_id_filter=None):
        """Build a score histogram in Python from fetch_hits_direct.

        Every score is binned by rounding up to the next multiple of
        0.01 (math.ceil(score * 100) / 100), whatever the stype.
        Returns the values() of the internal {bin: (min_score,
        max_score, count)} dictionary, in no particular order."""
        hits = self.fetch_hits_direct(stype=stype, br_id=br_id, nc_id=nc_id,
                                      self_hits=self_hits,
                                      symmetric=symmetric,
                                      set_id_filter=set_id_filter)

        # numpy.histogram is not used: it would not deal with self hits
        # and needs the bin boundaries known up front.
        bins = {}
        for _seq_id_0, _seq_id_1, score in hits:
            bin_key = math.ceil( score * 100) / 100
            previous = bins.get(bin_key)
            if previous is None:
                bins[bin_key] = (score, score, 1)
            else:
                low, high, count = previous
                bins[bin_key] = (min(low, score), max(high, score), count + 1)
        return bins.values()


    #FIXME: We probably need for this to return a more sophisticated
    #histogram class that allows addition and subtraction of
    #histograms.  It would store the grouping function as well
    def fetch_score_histogram(self, stype=None, br_id=None, nc_id=None,
                              self_hits=False, symmetric=False,
                              set_id_filter=None, # database-level query filter -- use for speed
                              query_seq_id=None,  # database-level query filter -- use for speed
                              seq_set=None,       # only count hits within this set
                              return_key_dict=False # return the internal dictionary of bins instead of just the values
                              ):
        """Return a histogram of scores computed in Python from
        fetch_hits_dictarray.

        Scores are binned by rounding up: to the next 0.01 for
        nc_score, to the next multiple of 25 for bit_score, and to the
        next multiple of 4 of -log10(e_value / db_size + 1E-200) for
        e_value.

        Returns a sorted list of (min_score, max_score, count) tuples,
        one per non-empty bin, or, if return_key_dict is true, the
        dictionary {bin_key: (min_score, max_score, count)}.

        set_id_filter restricts the query sequences (in the database)
        and the target sequences; query_seq_id restricts the query
        sequences; seq_set restricts the target sequences (see
        fetch_hits_dictarray).  When self.cacheq is set, the result is
        cached with pickler.cachefn keyed on all arguments.

        Raises AssertionError for an unknown stype.
        """
        # Compute the cache key once, before any work is done, so the
        # lookup and the store below always use the same key.
        cache_args = None
        if self.cacheq:
            cache_args = "%s_%s_%s_%s_%s_%s_%s_%s_%s" % (
                stype, br_id, nc_id,
                self_hits, symmetric,
                set_id_filter,
                hash(tuple(sorted(query_seq_id))) if hasattr(query_seq_id,'__iter__') else query_seq_id,
                hash(tuple(sorted(seq_set))) if hasattr(seq_set,'__iter__') else seq_set,
                return_key_dict
                )
            retval = pickler.cachefn(pickledir=self.pickledir, args=cache_args)
            if retval: return retval

        hit_dict = self.fetch_hits_dictarray(stype = stype, br_id = br_id, nc_id = nc_id,
                                             symmetric = symmetric, self_hits = self_hits,
                                             seq_id_0_set_id = set_id_filter,
                                             set_id_filter = set_id_filter,
                                             query_seq_id = query_seq_id,
                                             seq_set = seq_set)

        if stype == 'nc_score':
            group_fn = lambda a: math.ceil( a * 100) / 100
        elif stype == 'bit_score':
            group_fn = lambda a: math.ceil( a / 25) * 25
        elif stype == 'e_value':
            dbsize = int(self.fetch_db_size(br_id = br_id))
            group_fn = lambda a: math.ceil( -math.log10(a / dbsize + 1E-200) / 4) * 4
        else:
            # explicit raise so the check survives ``python -O``
            raise AssertionError("Unknown stype: %s" % stype)

        # FIXME: this could be done much faster by using numpy array functions

        # numpy.histogram is not used: it would not deal with self hits
        # and works well only when the best bin boundaries are known
        # in advance (like with NC).
        # { bin key: (min_score, max_score, cnt) }
        histogram = {}
        for (seq_id_0, (hit_list, score_list)) in hit_dict.iteritems():
            for (seq_id_1, score) in itertools.izip(hit_list, score_list):
                # fetch_hits_dictarray already drops self hits when
                # self_hits is false; this check is a cheap safety net
                if not self_hits and seq_id_0 == seq_id_1: continue

                key = group_fn(score)
                current = histogram.get(key)
                if current is None:
                    histogram[key] = (score, score, 1)
                else:
                    histogram[key] = ( min(current[0], score),
                                       max(current[1], score),
                                       current[2] + 1)

        if return_key_dict:
            ret = histogram
        else:
            ret = sorted(histogram.values())

        if self.cacheq: pickler.cachefn( retval = ret,
                                         pickledir=self.pickledir,
                                         args = cache_args)
        return ret
    
    def insert_blast_hit_symmetric( self, br_id, seq_id_0, record_list):
        """Insert a symmetric blast row into blast_hit_symmetric_arr.

        record_list = [ (seq_id_1, bit_score, e_value), ...]

        Hits are stored ordered by descending bit_score.  The caller's
        record_list is not modified.  The transaction is committed.

        Raises ValueError if record_list is empty.
        """

        i = """INSERT INTO blast_hit_symmetric_arr
        (br_id, seq_id_0, hit_list, bit_score, e_value)
        VALUES
        (%(br_id)s, %(seq_id_0)s,
        %(hit_list)s, %(bit_score)s, %(e_value)s)"""

        # sort a copy by descending bit_score; leave the caller's list alone
        records = sorted(record_list, key=lambda a: a[1], reverse=True)
        if not records:
            raise ValueError("record_list must contain at least one hit")

        hit_list, bit_list, e_list = zip( *records)

        # Pass values as query parameters: the driver adapts Python lists
        # to ARRAY[...] with full float precision (Python 2 str() keeps
        # only 12 digits).  Convert to plain Python numbers first because
        # numpy scalars are not adapted.
        params = {'br_id': int(br_id),
                  'seq_id_0': int(seq_id_0),
                  'hit_list': [int(h) for h in hit_list],
                  'bit_score': [float(b) for b in bit_list],
                  'e_value': [float(e) for e in e_list]}

        self.dbw.execute( i, params)
        self.dbw.commit()

    def insert_blast_hit_symmetric_dictarray( self,
        br_id, score_dict):
        """Insert a dictionary array into blast_hit_symmetric_arr with
        a given br_id.

        The dictionary should be of the form listed in
        fetch_hits_dictarray: {seq_id_0: (hit_list, bit_score[,
        e_value]), ...}.  The order (hit_list, bit_score, e_value) is
        required.  E_value may be omitted, with a tuple of
        (hit_list, bit_score); NULL is then stored for e_value.  The
        entries may be numpy arrays or plain sequences.

        Each row is stored ordered by descending bit_score.  All rows
        are committed together at the end; an empty score_dict inserts
        nothing and does not commit.

        Raises AssertionError if any entry is not a 2- or 3-tuple.
        """

        i = """INSERT INTO blast_hit_symmetric_arr
        (br_id, seq_id_0, hit_list, bit_score, e_value)
        VALUES
        ( %(br_id)s, %(seq_id_0)s,
        %(hit_list)s, %(bit_score)s, %(e_value)s)"""

        if not score_dict:
            return

        # Validate every entry before inserting anything.
        for tuparray in score_dict.itervalues():
            if len(tuparray) not in (2, 3):
                raise AssertionError("score_dict contains tuples of incorrect length.")

        for seq_id_0, tuparray in score_dict.iteritems():
            hit_arr = numpy.asarray(tuparray[0])
            bit_arr = numpy.asarray(tuparray[1])

            # order by *descending* bit_score
            order = numpy.argsort( bit_arr)[::-1]

            # psycopg doesn't handle adaptation of numpy arrays
            params = {'br_id': br_id,
                      'seq_id_0': seq_id_0,
                      'hit_list': hit_arr[order].tolist(),
                      'bit_score': bit_arr[order].tolist(),
                      'e_value': None}
            if len(tuparray) == 3:
                params['e_value'] = numpy.asarray(tuparray[2])[order].tolist()

            self.dbw.execute( i, params)

        self.dbw.commit()

        
    def insert_nc_run( self, br_id, e_thresh, bit_thresh, nc_thresh,
                       blast_hit_limit, smin, smin_factor,
                       use_symmetric, score_type, self_hits):
        """Insert an NC run into nc_run and return its new nc_id.

        The row's date column is set with NOW() in the database, and
        the transaction is committed before returning."""

        sql = """INSERT INTO nc_run
        (br_id, date, e_thresh, bit_thresh, nc_thresh, blast_hit_limit,
        smin, smin_factor,
        use_symmetric, score_type, self_hits)
        VALUES
        (%(br_id)s, NOW(), %(e_thresh)s, %(bit_thresh)s, %(nc_thresh)s,
        %(blast_hit_limit)s,
        %(smin)s, %(smin_factor)s, %(use_symmetric)s, %(score_type)s,
        %(self_hits)s)
        RETURNING nc_id"""

        run_params = dict(br_id=br_id, e_thresh=e_thresh,
                          bit_thresh=bit_thresh, nc_thresh=nc_thresh,
                          blast_hit_limit=blast_hit_limit,
                          smin=smin, smin_factor=smin_factor,
                          use_symmetric=use_symmetric,
                          score_type=score_type, self_hits=self_hits)

        self.dbw.execute( sql, run_params)
        new_nc_id = self.dbw.fetchsingle()

        self.dbw.commit()
        return new_nc_id

    
    #def insert_blast_hit_nc_dict(self, nc_id, seq_id_0, n_0, record_dict):
    #    "Insert an NC score row"
    #
    #    i = """INSERT INTO blast_hit_nc_arr
    #    (nc_id, seq_id_0, num_match_seq_0, hit_list, nc_score, num_match_seq_1, num_match_seq_both)
    #    VALUES
    #    (%(nc_id)d, %(seq_id_0)d, %(n_0)d,
    #    ARRAY[ %(hit_list_str)s ],
    #    ARRAY[ %(nc_score_str)s ],
    #    ARRAY[ %(num_match_seq_1_str)s ],
    #    ARRAY[ %(num_match_seq_both_str)s ])"""
    #
    #    # sort the hit and record arrays by descending nc_score
    #    hit_scores = sorted( record_dict.iteritems(),
    #                         key=lambda a: a[1][0],
    #                         reverse=True)
    #    # [seq_id_0, ...], [ (nc_score, num_match_seq_1, num_match_seq_both)
    #    hit_list, record_list = zip( *hit_scores)
    #    nc_score, num_match_seq_1, num_match_seq_both = zip( *record_list)
    #
    #    hit_list_str = reduce( lambda a,b: "%s, %s" % (a,b), hit_list)
    #    nc_score_str = reduce( lambda a,b: "%s, %s" % (a,b), nc_score)
    #    num_match_seq_1_str = reduce( lambda a,b: "%s, %s" % (a,b), num_match_seq_1)
    #    num_match_seq_both_str = reduce( lambda a,b: "%s, %s" % (a,b), num_match_seq_both)
    #
    #    self.dbw.execute(i % locals())
    #    self.dbw.commit()
    #    
    #    return

    def insert_blast_hit_nc(self, nc_id, seq_id_0, n_0, record_list):
        """Insert an NC score row.
        record_list = [ (seq_id_1, n_1, n_01, nc_score), ...]
        """

        # don't insert empty rows
        if len(record_list) < 1:
            return

        i = """INSERT INTO blast_hit_nc_arr
        (nc_id, seq_id_0, num_match_seq_0, hit_list, nc_score, num_match_seq_1, num_match_seq_both)
        VALUES
        (%(nc_id)d, %(seq_id_0)d, %(n_0)d,
        ARRAY[ %(hit_list_str)s ],
        ARRAY[ %(nc_score_str)s ],
        ARRAY[ %(num_match_seq_1_str)s ],
        ARRAY[ %(num_match_seq_both_str)s ])"""

        # sort by nc_score
        record_list.sort( key=lambda a: a[3], reverse=True)

        hit_list, num_match_seq_1, num_match_seq_both, nc_score = zip( *record_list)
        
        hit_list_str = reduce( lambda a,b: "%s, %s" % (a,b), hit_list)
        num_match_seq_1_str = reduce( lambda a,b: "%s, %s" % (a,b), num_match_seq_1)
        num_match_seq_both_str = reduce( lambda a,b: "%s, %s" % (a,b), num_match_seq_both)
        nc_score_str = reduce( lambda a,b: "%s, %s" % (a,b), nc_score)

        self.dbw.execute(i % locals())
        self.dbw.commit()
        
        return

    def insert_blast_hit( self, br_id, hit_rank, seq_id_0, record_dict,
                          commit=True):
        "Insert a blast hit row"
        
        d = """DELETE FROM blast_hit_arr where br_id=%(br_id)s
        and hit_rank=%(hit_rank)s and seq_id_0=%(seq_id_0)s
        RETURNING br_id, hit_rank, seq_id_0, array_upper(hit_list,1)"""

        #print br_id, hit_rank, seq_id_0
        self.dbw.fetchall(d, locals())

        i = """INSERT INTO blast_hit_arr
        (br_id, hit_rank, seq_id_0, hit_list, bit_score, e_value,
        seq_start_0, seq_end_0, seq_start_1, seq_end_1,
        identities, gaps, positives, low_complexity, coverage)
        VALUES
        (%(br_id)d, %(hit_rank)d, %(seq_id_0)d,
        ARRAY[ %(hit_list_str)s ],
        ARRAY[ %(bit_score_str)s ],
        ARRAY[ %(e_value_str)s ],
        ARRAY[ %(seq_start_0_str)s ],
        ARRAY[ %(seq_end_1_str)s ],
        ARRAY[ %(seq_start_1_str)s ],
        ARRAY[ %(seq_end_1_str)s ],
        ARRAY[ %(identities_str)s ],
        ARRAY[ %(gaps_str)s ],
        ARRAY[ %(positives_str)s ],
        ARRAY[ %(low_complexity_str)s ],
        ARRAY[ %(coverage_str)s ])"""

        # sort the hit and record arrays by descending bit_score
        hit_scores = sorted(record_dict.iteritems(),
                            key=lambda a: a[1][0],
                            reverse=True)
        # [seq_id_0, ...], [(bit_score, e_value, seq_start_0,
        #                    seq_end_0, seq_start_1, seq_end_1,
        #                    identities, gaps, positives, low_complexity,
        #                    coverage)
        hit_list, record_list = zip( *hit_scores)

        (bit_score, e_value, seq_start_0, seq_end_0, seq_start_1,
         seq_end_1, identities, gaps, positives, low_complexity,
         coverage) = zip( *record_list)
        
        hit_list_str = reduce( lambda a,b: "%s, %s" % (a,b), hit_list)
        bit_score_str = reduce( lambda a,b: "%s, %s" % (a,b), bit_score)
        e_value_str = reduce( lambda a,b: "%s, %s" % (a,b), e_value)
        seq_start_0_str = reduce( lambda a,b: "%s, %s" % (a,b), seq_start_0)
        seq_end_0_str = reduce( lambda a,b: "%s, %s" % (a,b), seq_end_0)
        seq_start_1_str = reduce( lambda a,b: "%s, %s" % (a,b), seq_start_1)
        seq_end_1_str = reduce( lambda a,b: "%s, %s" % (a,b), seq_end_1)
        identities_str = reduce( lambda a,b: "%s, %s" % (a,b), identities)
        gaps_str = reduce( lambda a,b: "%s, %s" % (a,b), gaps)
        positives_str = reduce( lambda a,b: "%s, %s" % (a,b), positives)
        low_complexity_str = reduce( lambda a,b: "%s, %s" % (a,b), low_complexity)
        coverage_str = reduce( lambda a,b: "%s, %s" % (a,b), coverage)
        
        self.dbw.execute( i % locals())
        if commit: self.dbw.commit()

        return

    def insert_blast_hit_xtra( self, seq_id_0, seq_id_1, br_id, hit_rank,
                               seq_start_0, seq_end_0, seq_start_1, seq_end_1,
                               identities, gaps, positives, low_complexity,
                               coverage): #, query_seq, match_seq, sbjct_seq):
        "Insert a blast hit's extra information"

        assert False, "Use insert_blast_hit instead"
        
        i = """INSERT INTO blast_hit_xtra
        (seq_id_0, seq_id_1, br_id, hit_rank,
        seq_start_0, seq_end_0, seq_start_1, seq_end_1,
        identities, gaps, positives, low_complexity,
        coverage)
        VALUES
        ('%d', '%d', '%d', '%d',
        '%d', '%d', '%d', '%d',
        '%d', '%d', '%d', '%d',
        '%s')"""
        
        inputtup = ( seq_id_0, seq_id_1, br_id, hit_rank,
                     seq_start_0, seq_end_0, seq_start_1, seq_end_1,
                     identities, gaps, positives, low_complexity,
                     coverage) #, query_seq, match_seq, sbjct_seq)
        self.dbw.execute( i % inputtup)
        return

    def populate_blast_hit_symmetric_sql( self, br_id):
        print "Populating blast_hit_symmetric (in sql)"
        
        assert False, "Not yet compatible with postgres.  Use populate_blast_hit_symmetric_python"


    def populate_blast_hit_symmetric_python( self, br_id):
        print "Populating blast_hit_symmetric (in python)"

        print "warning: this can take a lot of memory.  The full distance matrix is currently stored"
        all_scores = self.fetch_hits_direct( stype='bit_score',br_id=br_id,
                                             field_list=('bit_score','e_value'),
                                             symmetric=False)

        # store each node only once, rather than a new reference for
        # each edge endpoint
        vertmap = {}

        # build dictionary of top hits
        # [ seq0][seq1] = (bit,eval)
        print "  calculating..."
        score_dict = {}
        for (seq_id_0, seq_id_1, bit_score, e_value) in all_scores:
            try:
                seq_id_0 = vertmap[seq_id_0]
            except:
                vertmap[seq_id_0] = seq_id_0
                score_dict[seq_id_0] = {}

            try:
                seq_id_1 = vertmap[seq_id_1]
            except:
                vertmap[seq_id_1] = seq_id_1
                score_dict[seq_id_1] = {}

            # retain the greatest bit score
            # we always store both halves of the matrix, so this can check only one
            if (not seq_id_1 in score_dict[seq_id_0] or
                bit_score > score_dict[seq_id_0][seq_id_1][0]):
                score_dict[seq_id_0][seq_id_1] = (bit_score, e_value)
                score_dict[seq_id_1][seq_id_0] = (bit_score, e_value)

        # insert all dictionary items
        print "  inserting..."
        for seq_id_0 in score_dict.keys():
            record_list = [ (seq_id_1, bit_score, e_value)
                            for (seq_id_1, (bit_score, e_value))
                            in score_dict[seq_id_0].iteritems()]
            self.insert_blast_hit_symmetric(
                br_id, seq_id_0, record_list)

        return

    populate_blast_hit_symmetric = populate_blast_hit_symmetric_python


    def populate_blast_hit_symmetric_scipy(self, br_id):
        from scipy import sparse
        
        query_set = self.fetch_seq_set(
            self.fetch_blast_info_d( br_id)['query_set_id'])
        
        score_iter = self.fetch_score_lists( stype='bit_score',
                                             br_id = br_id,
                                             field_list = ('bit_score',),
                                             symmetric=False)

        nseqs = len(query_set)
        A_bit = sparse.lil_matrix((nseqs, nseqs), dtype=numpy.float64)

        seq_map = dict( ((seq,i) for i,seq in enumerate( query_set)) )

        for (seq_id_0, hit_list, score_list) in score_iter:
            #if not seq_id_0 in seq_map: continue
            try: i = seq_map[seq_id_0]
            except: continue
            
            for seq_id_1, bit_score in itertools.izip(hit_list, score_list):
                #if not seq_id_1 in seq_map: continue
                try: j = seq_map[seq_id_1]
                except: continue
                
                #if not (i,j) in A_bit or bit_score > A_bit[i,j]:
                if bit_score > A_bit[i,j]:
                    A_bit[i,j] = bit_score
                    A_bit[j,i] = bit_score
        return A_bit
    
    def populate_blast_hit_symmetric_python_bit(self, br_id):
        score_iter = self.fetch_score_lists( stype='bit_score',
                                             br_id = br_id,
                                             field_list = ('bit_score',),
                                             symmetric=False)
        score_dict = {}
        vertmap = {}

        for (seq_id_0, hit_list, bit_list) in score_iter:
            try: seq_id_0 = vertmap[seq_id_0]
            except:
                vertmap[seq_id_0] = seq_id_0
                score_dict[seq_id_0] = {}

            sd0 = score_dict[seq_id_0]
            
            for seq_id_1, bit_score in itertools.izip(hit_list, bit_list):
            #for i,seq_id_1 in enumerate(hit_list):
            #for seq_id_1, bit_score in zip(hit_list, bit_list):
                try: seq_id_1 = vertmap[seq_id_1]
                except:
                    vertmap[seq_id_1] = seq_id_1
                    score_dict[seq_id_1] = {}

                if seq_id_1 in sd0 and bit_score < sd0[seq_id_1]:
                    continue

                sd0[seq_id_1] = bit_score
                score_dict[seq_id_1][seq_id_0] = bit_score

        return score_dict
    
if __name__=="__main__":
    # Running the module directly only constructs a blastq instance bound
    # to the module-level name `bq` (cacheq=False is forwarded to
    # seqq.seqq.__init__).  Presumably intended for interactive use,
    # e.g. `python -i` -- NOTE(review): confirm.
    bq = blastq(cacheq=False)
    

    # Profiling snippets kept for reference; uncomment one to run it.
    #import cProfile
    #cProfile.runctx("fetchall()", globals(), locals())

    #import cProfile
    #cProfile.runctx("A = bq.populate_blast_hit_symmetric_scipy(75)", globals(), locals())
    #cProfile.runctx("A = bq.populate_blast_hit_symmetric_python(75)", globals(), locals())
    #cProfile.runctx("A = bq.populate_blast_hit_symmetric_python_bit(104)", globals(), locals())
