#!/usr/bin/env python
# Jacob Joseph
# June 20, 2007

import array, os
import blastq
from JJutil import pickler
import alignment_coverage
# Moved in ipython 0.11
#from IPython.Shell import IPShellEmbed

class familyq(blastq.blastq):
    """Class to abstract family queries.  Inherits blastq.

    An instance is bound to a single family set (self.fam_set_id) and
    keeps that set's membership in memory in the two dictionaries below.
    """

    # {family_abbrev: set(seq_id, ...)}; assigned in __init__ from
    # fetch_families().  None until an instance has been initialised.
    family_sets = None
    # {seq_id: family_abbrev}; assigned in __init__ from fetch_families().
    family_members = None

    def __init__(self,
                 family_set_id=None,
                 family_set_name=None,
                 **kwargs):
        """Bind the instance to one family set and load its membership.

        family_set_id   -- identifier of the family set.  Takes precedence
                           over family_set_name when both are given.
        family_set_name -- name of the family set; resolved to an id with
                           fetch_family_set_info().
        **kwargs        -- passed unchanged to blastq.blastq.__init__.

        Raises AssertionError when no family set id can be determined.
        """
        blastq.blastq.__init__(self, **kwargs)

        if family_set_id is not None:
            self.fam_set_id = family_set_id
        elif family_set_name is not None:
            self.fam_set_id = self.fetch_family_set_info(
                name=family_set_name)['fam_set_id']

        # Read through getattr(): when neither argument is supplied nothing
        # above assigns fam_set_id, and a plain attribute access would fail
        # with AttributeError before the intended message could be reported.
        # An explicit raise (rather than `assert`) keeps the check alive
        # under `python -O`.
        if getattr(self, 'fam_set_id', None) is None:
            raise AssertionError(
                "One of family_set_id or family_set_name must be specified")

        (self.family_sets, self.family_members) = self.fetch_families()

    def fetch_family_set_info(self, name=None):
        q = """SELECT fam_set_id, name, description
        FROM family_set"""

        if name is not None:
            q += "\nWHERE name=%(name)s"            
        else:
            fam_set_id = self.fam_set_id
            q += "\nWHERE fam_set_id=%(fam_set_id)s"

        return self.dbw.fetchone_d(q, locals())

    def fetch_family_size( self, families):
        "Return num_sequences in a set of families"
        k = 0
        for family in families:
            k += self.fetch_family_size_family( family)
        return k

    def fetch_family_size_family( self, family_abbrev):
        "Return num_sequences in a particular family"
        if self.cacheq:
            retval = pickler.cachefn(pickledir=self.pickledir,
                                     args = "%d_%s" % (self.fam_set_id, family_abbrev))
            if retval: return retval
            
        q =  """SELECT count(*)
        FROM family_member
        JOIN family USING (family_id)
        WHERE fam_set_id=%(fam_set_id)s
        AND abbrev=%(family_abbrev)s"""
        
        fam_set_id = self.fam_set_id
        retval = self.dbw.fetchsingle(q, locals())

        if self.cacheq: pickler.cachefn( pickledir=self.pickledir,
                                         retval = retval,
                                         args = "%d_%s" % (set_id, family_abbrev))
        return retval

    def fetch_seq_family(self, seq_id):
        """Return family_abbrev (or None) for a particular sequence."""
        if self.family_members.has_key(seq_id):
            return self.family_members[seq_id]
        return None

    def fetch_family_seqs(self, family_abbrev):
        """Return the set of seq_ids (or None) for a particular family abbreviation."""
        if self.family_sets.has_key(family_abbrev):
            return self.family_sets[family_abbrev]
        return None

    def fetch_families( self):
        """Return two dictionaries:
({ 'family_abbrev':set(seq_id,...), }, {seq_id: 'family_abbrev', })"""
        if self.cacheq:
            retval = pickler.cachefn(pickledir = self.pickledir,
                                     args = "%d" % self.fam_set_id)
            if retval: return retval

        q = """SELECT seq_id, abbrev, family_id
        FROM family_member
        JOIN family USING (family_id)
        WHERE fam_set_id=%(fam_set_id)s"""

        fam_set_id = self.fam_set_id
        family_sets = {}
        family_members = {}
        for (seq_id, family_abbrev, family_id) in self.dbw.fetchall(q, locals()):
            if not family_sets.has_key(family_abbrev):
                family_sets[family_abbrev] = set()
            family_sets[family_abbrev].add(seq_id)
            family_members[seq_id] = family_abbrev

        retval = (family_sets, family_members)
        if self.cacheq: pickler.cachefn( pickledir = self.pickledir,
                                         retval = retval,
                                         args = "%d" % self.fam_set_id)
        return retval

    def fetch_family_coverage( self, family_abbrev, br_id, hit_rank=0,
                               correct_e_value=True):
        """Return iterator of (seq_id_0, seq_id_1, hit_rank, bit_score, e_value,
coverage, seq_start_0, seq_end_0, seq_start_1, seq_end_1, ff_boolean)
tuples for a single family.  These scores are not symmetric.  Results
are sorted by (seq_id_0, hit_rank). If hit_rank should be an integer or 'ALL'."""

        assert False, "Not yet updated for new family set schema"

        # correct e_value.  Add a very small value to handle e_value==0.
        # min(e_value / dbsize)= 4.52E-185 for e_value > 0
        if correct_e_value:
            dbsize = self.fetch_db_size( br_id)
            e_value = "-log10(bs.e_value / '%d' + 1E-200)" % dbsize
        else: e_value = 'bs.e_value'

        if type(hit_rank) is int: hit_rank_cond = "AND hit_rank='%d'" % hit_rank
        elif hit_rank=='ALL': hit_rank_cond = ""
        else: assert False, "Unknown hit_rank: %s" % hit_rank
        
        q = """
-- FO, FF pairs
(SELECT seq_id_0, seq_id_1, hit_rank, bs.bit_score, %(e_value)s, coverage,
seq_start_0, seq_end_0, seq_start_1, seq_end_1,
fm0.family_id=fm1.family_id
FROM family_member as fm0
JOIN family USING (family_id)
JOIN blast_hit_xtra as bx ON (fm0.seq_id=seq_id_0)
JOIN blast_hit_symmetric as bs USING (br_id, seq_id_0, seq_id_1)
LEFT JOIN family_member as fm1 ON (fm1.seq_id=seq_id_1)
WHERE family_abbrev='%(family_abbrev)s'
AND set_id='%(set_id)d'
AND br_id='%(br_id)d'
%(hit_rank_cond)s)
UNION ALL
-- OF pairs
(SELECT seq_id_0, seq_id_1, hit_rank, bs.bit_score, %(e_value)s, coverage,
seq_start_0, seq_end_0, seq_start_1, seq_end_1,
fm0.family_id=fm1.family_id
FROM family_member as fm1
JOIN family USING (family_id)
JOIN blast_hit_xtra AS bx ON (fm1.seq_id=seq_id_1)
JOIN blast_hit_symmetric AS bs USING (br_id, seq_id_0, seq_id_1)
LEFT JOIN family_member as fm0 ON (fm0.seq_id=seq_id_0)
WHERE family_abbrev='%(family_abbrev)s'
AND set_id='%(set_id)d'
AND (fm0.family_id is NULL or fm0.family_id != fm1.family_id)
AND br_id='%(br_id)d'
%(hit_rank_cond)s)
ORDER BY seq_id_0, hit_rank
"""
        self.dbw.ssexecute(q % {'e_value': e_value, 'family_abbrev': family_abbrev,
                                'br_id': br_id, 'hit_rank_cond': hit_rank_cond,
                                'set_id': self.fam_set_id } )

        arraysize = 100000
        while True:
            results = self.dbw.sscurs.fetchmany(arraysize)
            if not results: break
            for tup in results:
                yield tup

        
    def fetch_family_coverage_combined( self, family_abbrev, br_id, correct_e_value=True):
        """Return a list of the same format as fetch_family_coverage,
        combined coverage scores.  Not symmetric."""
        
        if self.cacheq:
            retval = pickler.cachefn(
                pickledir=self.pickledir,
                args = "%d_%s_%d_%s" % (self.fam_set_id, family_abbrev, br_id, correct_e_value))
            if retval: return retval

        hits = []
        
        # assume dbhits is ordered by seq_id_0, hit_rank
        dbhits = self.fetch_family_coverage( family_abbrev, br_id, hit_rank='ALL',
                                          correct_e_value=True)

        a_seq_id_0 = None
        algn = {}
        for (seq_id_0, seq_id_1, hit_rank, bit_score, e_value, coverage,
             seq_start_0, seq_end_0, seq_start_1, seq_end_1,
             ff_bool) in dbhits:
            if not a_seq_id_0: a_seq_id_0 = seq_id_0
            
            if a_seq_id_0 != seq_id_0:
                # this query is done
                for (a_seq_id_1, a) in algn.items():
                    hits.append( (a_seq_id_0, a_seq_id_1, -1, a.bit_score, a.e_value, a.coverage,
                                  -1, -1, -1, -1, a.ff_bool))
                a_seq_id_0 = seq_id_0
                algn = {}
                
            if not algn.has_key(seq_id_1):
                assert hit_rank==0, "New alignments should occur only when \
                hit_rank==0. %d %d %d" % (seq_id_0, seq_id_1, hit_rank)
                algn[seq_id_1] = alignment_coverage.combined(
                        bit_score=bit_score, e_value=e_value, ff_bool=ff_bool)
            
            algn[seq_id_1].add_unless_conflict( seq_start_0, seq_end_0,
                                                seq_start_1, seq_end_1,
                                                coverage)

        retval = hits
        if self.cacheq: pickler.cachefn(
            pickledir=self.pickledir, retval = retval,
            args = "%d_%s_%d_%s" % (self.fam_set_id, family_abbrev, br_id, correct_e_value))
        return retval


    def fetch_family_score(self, family_abbrev, br_id=None, nc_id=None, 
                           stype='bit_score', symmetric=True,
                           correct_e_value = True, min_score=None,
                           max_score=None):
        """Return (seq_id_0, seq_id_1, score, ff_bool) comprised of
the blast or nc scores for familiy and family-other pairs.
stype={bit_score,e_value, or nc_score}. If specified, return only
scores where min_score <= score < max_score."""

        assert False, "Not yet updated for new family set schema"

        if stype in ('bit_score', 'e_value') and br_id is not None:
            cond = "AND br_id='%d'\n" % br_id
            if symmetric: table = 'blast_hit_symmetric'
            else:
                table = 'blast_hit'
                cond += "AND hit_rank=0\n"
        elif stype == 'nc_score' and nc_id is not None:
            cond = "AND nc_id='%d'\n" % nc_id
            table = 'blast_hit_nc'
        else:
            assert False, "Unknown stype, br_id, or nc_id: %s, %s, %s" % (
                stype, br_id, nc_id)

        # correct e_value.  Add a very small value to handle e_value==0.
        if stype=='e_value' and correct_e_value:
            dbsize = self.fetch_db_size( br_id)
            m_stype = "-log10(e_value / '%d' + 1E-200)" % dbsize
        else:
            m_stype=stype

        thresh_cond = ""
        if min_score is not None:
            thresh_s = "AND %s >= %f\n"
            thresh_cond += thresh_s % (m_stype, min_score)
        if max_score is not None:
            thresh_s = "AND %s < %f\n"
            thresh_cond += thresh_s % (m_stype, max_score)

        q = """
-- FO, FF pairs
SELECT seq_id_0, seq_id_1, %(m_stype)s, fm0.family_id=fm1.family_id
FROM family_member as fm0
JOIN family USING (family_id)
JOIN %(table)s ON (fm0.seq_id=seq_id_0)
LEFT JOIN family_member as fm1 ON (fm1.seq_id=seq_id_1)
WHERE family_abbrev='%(family_abbrev)s'
AND set_id='%(set_id)d'
%(cond)s
%(thresh_cond)s
UNION ALL
-- OF pairs
SELECT seq_id_0, seq_id_1, %(m_stype)s, fm0.family_id=fm1.family_id
FROM family_member as fm1
JOIN family USING (family_id)
JOIN %(table)s ON (fm1.seq_id=seq_id_1)
LEFT JOIN family_member as fm0 ON (fm0.seq_id=seq_id_0)
WHERE family_abbrev='%(family_abbrev)s'
AND set_id='%(set_id)d'
AND (fm0.family_id is NULL or fm0.family_id != fm1.family_id)
%(cond)s
%(thresh_cond)s
"""
        self.dbw.ssexecute( q % {'m_stype': m_stype, 'table': table,
                                 'family_abbrev': family_abbrev, 'cond': cond,
                                 'thresh_cond': thresh_cond, 'set_id': self.fam_set_id } )

        arraysize = 100000
        while True:
            results = self.dbw.sscurs.fetchmany(arraysize)
            if not results: break
            for tup in results:
                yield tup


    def fetch_fffo_list_score(self, families, br_id=None, nc_id=None,
                              stype='bit_score', symmetric=True,
                              correct_e_value = True):
        """Return (ff_list, fo_list), comprised of the blast or nc
scores for familiy and family-other pairs.  stype={bit_score,e_value,
or nc_score}"""

        ff_list = array.array('d')
        fo_list = array.array('d')
        for family_abbrev in families:
            (ff_fam, fo_fam) = self.fetch_fffo_list_score_family(
                family_abbrev, br_id, nc_id, stype, symmetric, correct_e_value)
            ff_list.extend(ff_fam)
            fo_list.extend(fo_fam)
        
        return (ff_list, fo_list)

    def fetch_fffo_list_score_family(self, family_abbrev, br_id=None, nc_id=None, 
                                     stype='bit_score', symmetric=True,
                                     correct_e_value = True):
        """Return (ff_list, fo_list), comprised of the blast or nc
scores for familiy and family-other pairs.  stype={bit_score,e_value,
or nc_score}"""
        if self.cacheq:
            retval = pickler.cachefn(
            pickledir=self.pickledir,
            args = "%d_%s_%s_%s_%s_%s_%s" % (self.fam_set_id, family_abbrev, br_id, nc_id,
                                             stype, symmetric, correct_e_value))
            if retval: return retval

        dbret = self.fetch_family_score(family_abbrev, br_id, nc_id, stype,
                                        symmetric, correct_e_value)

        ff_list = array.array('d')
        fo_list = array.array('d')
        for (seq_id_0, seq_id_1, score, ff) in dbret:
            if ff == 1:  ff_list.append(score)
            else: fo_list.append(score)

        # work around array pickling bug
        # https://sourceforge.net/tracker/?func=detail&atid=105470&aid=1693079&group_id=5470
        if len(ff_list)==0: ff_list = []
        if len(fo_list)==0: fo_list = []

        retval = (ff_list, fo_list)
        if self.cacheq: pickler.cachefn(
            pickledir=self.pickledir, retval = retval,
            args = "%d_%s_%s_%s_%s_%s_%s" % (self.fam_set_id, family_abbrev, br_id, nc_id, stype,
                                          symmetric, correct_e_value))
        return retval
    
    
    def fetch_fffo_list_coverage( self, families, br_id, combined_coverage=False,
                                  list_type='coverage', correct_e_value=True,
                                  coverage_thresh=0.0, omit_below_thresh=False):
        """Return (ff_list, fo_list) with values of list_type =
{'coverage', 'coverage_combined', 'bit_score', or 'e_value'}.  If
coverage < coverage_thresh and list_type != 'coverage' or
'coverage_combined', a minimal pseudoscore is substituted for the real
score."""
        ff_list = array.array('d')
        fo_list = array.array('d')
        for family_abbrev in families:
            (ff_fam, fo_fam) = self.fetch_fffo_list_coverage_family( family_abbrev, br_id,
                                                                     combined_coverage, list_type,
                                                                     correct_e_value,
                                                                     coverage_thresh,
                                                                     omit_below_thresh)
            ff_list.extend(ff_fam)
            fo_list.extend(fo_fam)
        return (ff_list, fo_list)     


    def fetch_fffo_list_coverage_family( self, family_abbrev, br_id,
                                         combined_coverage=False,
                                         list_type='coverage',
                                         correct_e_value=True,
                                         coverage_thresh=0.0,
                                         omit_below_thresh=False):
        """Return (ff_list, fo_list) with values of list_type =
{'coverage', 'coverage_combined', 'bit_score', or 'e_value'}.  If
coverage < coverage_thresh and list_type != 'coverage' or
'coverage_combined', a minimal pseudoscore is substituted for the real
score."""
        if self.cacheq:
            retval = pickler.cachefn(
                pickledir=self.pickledir,
                args = "%s_%s_%s_%s_%s_%s" % (family_abbrev, br_id, combined_coverage,
                                          list_type, correct_e_value, coverage_thresh))
            if retval: return retval

        if combined_coverage:
            dbhits = self.fetch_family_coverage_combined(family_abbrev, br_id=br_id)
        else:
            dbhits = self.fetch_family_coverage( family_abbrev, br_id=br_id, hit_rank=0)

        # Make coverage symmetric.  Assume scores are already.
        d = {}
        for (seq_id_0, seq_id_1, hit_rank, bit_score, e_value, coverage,
             seq_start_0, seq_end_0, seq_start_1, seq_end_1,
             ff_bool) in dbhits:
            if not d.has_key(seq_id_0): d[seq_id_0] = {}
            if not d.has_key(seq_id_1): d[seq_id_1] = {}
            if not d[seq_id_0].has_key(seq_id_1) or d[seq_id_0][seq_id_1][0] < coverage:
                if list_type=='bit_score': score = bit_score
                elif list_type=='e_value': score = e_value
                elif list_type=='coverage': score = None
                elif list_type=='coverage_combined': score = None
                else: assert False, "Unimplemented list_type: %s" % list_type
                d[seq_id_0][seq_id_1] = (coverage, ff_bool, score)
                d[seq_id_1][seq_id_0] = (coverage, ff_bool, score)

        ff_list = array.array('d')
        fo_list = array.array('d')
        for seq_id_0 in d:
            for (seq_id_1, (coverage, ff_bool, score)) in d[seq_id_0].iteritems():
                if list_type=='coverage':
                    sval = coverage
                elif coverage < coverage_thresh:
                    if omit_below_thresh: continue
                    else:
                        # set minimum score
                        sval = 0.0
                else:
                    sval = score
                    
                if ff_bool == 1: ff_list.append(sval)
                else: fo_list.append(sval)
                    
        # work around array pickling bug
        # https://sourceforge.net/tracker/?func=detail&atid=105470&aid=1693079&group_id=5470
        if len(ff_list)==0: ff_list = []
        if len(fo_list)==0: fo_list = []

        retval = (ff_list, fo_list)
        if self.cacheq: pickler.cachefn(
            pickledir=self.pickledir,
            retval = retval,
            args = "%s_%s_%s_%s_%s_%s" % (family_abbrev, br_id, combined_coverage,
                                          list_type, correct_e_value, coverage_thresh))
        return retval


    def fetch_fffo_list_das( self, families, fpath='/net/ciona/usr0/nsong/Res_sum/Write/nc_paper/DAC_experiment/output'):
        """Return (ff_list, fo_list) for domain architecture
similarity.  Note that these scores are not stored in the database,
but in a series of flat files from Nan."""
        ff_list = array.array('d')
        fo_list = array.array('d')
        for family_abbrev in families:
            (ff_fam, fo_fam) = self.fetch_fffo_list_das_family( family_abbrev, fpath)
            ff_list.extend(ff_fam)
            fo_list.extend(fo_fam)
        return (ff_list, fo_list)     


    def fetch_fffo_list_das_family( self, family_abbrev, fpath):
        """Return (ff_list, fo_list) for domain architecture
similarity.  Note that these scores are not stored in the database,
but in a series of flat files from Nan."""

        fname_base = "contentPromisWeight_%s_%s"
        ff_list=array.array('d')
        fo_list=array.array('d')
        for (score_list, ltype) in ((ff_list, 'FF'), (fo_list, 'FO')):
            fd = open(os.path.join(fpath, fname_base % (family_abbrev, ltype)))
            for l in fd:
                # header
                if l.startswith('p1 p2 common_weight'): continue
                
                (sp0, sp1, com_wt, com_wt_jaccard, com_wt_cos) = l.split()
                score_list.append( float(com_wt_cos))
        return (ff_list, fo_list)

    
    def insert_family(self, family_name, family_abbrev, seq_ids, family_description=None):

        i = """INSERT INTO family
        (family_name, family_abbrev, family_description)
        VALUES (%(family_name)s, %(family_abbrev)s, %(family_description)s)
        RETURNING family_id"""

        i_seq = """INSERT INTO family_member
        (seq_id, fam_set_id, family_id)
        VALUES (%(seq_id)s, %(fam_set_id)s, %(family_id)s)"""

        family_id = self.dbw.fetchsingle( i, locals())
        fam_set_id = self.fam_set_id
        
        for seq_id in seq_ids:
            self.dbw.execute( i_seq, locals())

        self.dbw.commit()
