#!/usr/bin/env python

# Jacob Joseph
# 16 July 2008

import os
from JJutil import pgutils, pickler
from JJutil.UnionFind import UnionFind

class pfamq:
    """Class to facilitate in calculations against PFAM.  Note that
    only the PFAM data in Uniprot should be considered robust.  Other
    data, such as in tables 'pfam22*' is not generated with Uniprot
    updates."""

    # Database wrapper (pgutils.dbwrap); assigned in __init__.
    dbw = None

    def __init__(self, cacheq=False, pickledir=None):
        """cacheq: if True, cache expensive query results via
        JJutil.pickler.  pickledir: directory for pickle files;
        defaults to $HOME/tmp/pickles."""
        self.cacheq = cacheq

        if pickledir is not None:
            self.pickledir = pickledir
        else:
            self.pickledir = os.path.expandvars('$HOME/tmp/pickles')

        self.dbw = pgutils.dbwrap()

    def fetch_pfam_source_id(self):
        """Return the source_id of the 'Pfam' entry in prot_seq_source."""
        q_pfam_source_id = """SELECT source_id
        FROM prot_seq_source
        WHERE source_name='Pfam'"""

        source_id = self.dbw.fetchsingle( q_pfam_source_id)
        return source_id

    @staticmethod
    def _build_seq_id_str(seq_id_list):
        """Return a comma-separated seq_id string for an SQL IN (...)
        clause.  Accepts either a single integer seq_id or a list of
        them.  (Shared by fetch_domains and fetch_sequence_domains,
        which previously duplicated this logic inline.)"""
        if isinstance(seq_id_list, int):
            seq_id_list = [seq_id_list]
        # seq_ids are forced through %d, so only integers can be
        # interpolated into the query.
        return ",".join(["%d" % seq_id for seq_id in seq_id_list])

    def fetch_domains(self, set_id, seq_id_list=None):
        """Fetch a list of (pfam id, count) tuples for a set_id, and,
        optionally, a subset of seq_ids."""

        q = """SELECT id_0, COUNT(seq_id)
        FROM prot_seq_set_member
        JOIN sp_cross_ref USING (seq_id)
        WHERE set_id=%(set_id)d
        AND source_id=%(source_id)d"""

        seq_id_str = ""
        if seq_id_list is not None:
            q += "\n AND seq_id in (%(seq_id_str)s)"
            seq_id_str = self._build_seq_id_str(seq_id_list)

        q += "\nGROUP BY id_0"

        source_id = self.fetch_pfam_source_id()

        dbret = self.dbw.fetchall( q % {'set_id': set_id,
                                        'source_id': source_id,
                                        'seq_id_str': seq_id_str})
        return dbret

    def fetch_domains_seqwise(self, seq_id):
        """Fetch a list of (domain id, instances) tuples for a single
        sequence.  Here, we rely upon the id_2 column for domain count."""

        q = """SELECT id_0, id_2
        FROM sp_cross_ref
        WHERE seq_id=%(seq_id)s
        AND source_id=%(source_id)s"""

        source_id = self.fetch_pfam_source_id()

        # locals() supplies seq_id and source_id as query parameters.
        dbret = self.dbw.fetchall( q, locals())
        return dbret

    def get_domains(self, seq_id_list, set_id=None):
        """Return the set of pfam domains contained within a list of
        seq_ids.

        set_id is required by the underlying query; it is a defaulted
        trailing parameter only to preserve the original call
        signature.  (This method previously always failed: it omitted
        the required set_id argument and referenced the misspelled
        name 'domain_list'.)"""

        domain_lst = self.fetch_domains(set_id, seq_id_list=seq_id_list)
        return set([a[0] for a in domain_lst])

    def get_identity_map(self):
        """Return a dictionary mapping domain id_0 -> identity, read
        from pfam_22_stat."""

        if self.cacheq:
            # NOTE(review): pickler.cachefn called without retval
            # appears to return a previously cached value for this
            # call site -- confirm against JJutil.pickler.
            retval = pickler.cachefn(pickledir=self.pickledir)
            if retval: return retval

        q = """SELECT id_0, identity
        FROM pfam_22_stat"""

        self.dbw.execute(q)

        ident_map = {}
        for (id_0, identity) in self.dbw.fetchall():
            ident_map[id_0] = identity

        if self.cacheq:
            pickler.cachefn(pickledir = self.pickledir,
                            retval = ident_map)

        return ident_map

    def get_promiscuity_map(self, set_id, seq_id_list = None,
                            promisc_type = 'count',
                            use_cnt_promisc_tuple=False):
        """Return a dictionary of domain promiscuities. Calculated
        using a set_id and, possibly, a subset of seq_ids.
        promisc_type may be one of 'count', 'component', or
        'neighborhood'.

        Raises ValueError for an unknown promisc_type."""

        if self.cacheq:
            retval = pickler.cachefn(pickledir=self.pickledir)
            if retval: return retval

        if promisc_type == 'count':
            promisc_map = self.get_promisc_map_count(
                set_id = set_id, seq_id_list = seq_id_list,
                use_cnt_promisc_tuple = use_cnt_promisc_tuple)

        elif promisc_type == 'component':
            promisc_map = self.get_promisc_map_ccomp(
                set_id = set_id, seq_id_list = seq_id_list,
                use_cnt_promisc_tuple = use_cnt_promisc_tuple)

        elif promisc_type == 'neighborhood':
            promisc_map = self.get_promisc_map_neigh(
                set_id = set_id, seq_id_list = seq_id_list,
                use_cnt_promisc_tuple = use_cnt_promisc_tuple)

        else:
            # raise instead of 'assert False': asserts are stripped
            # under -O, which would have left promisc_map unbound.
            raise ValueError("Unknown promisc_type: %s" % promisc_type)

        if self.cacheq:
            pickler.cachefn(pickledir = self.pickledir,
                            retval = promisc_map)

        return promisc_map

    def get_promisc_map_count(self, set_id, seq_id_list = None,
                              use_cnt_promisc_tuple=False):
        """Return a dictionary of domain promiscuities (per-domain
        count normalized by the greatest count).  Calculated using a
        set_id and, possibly, a subset of seq_ids.

        Returns {} when no domains are found (previously this raised
        on a None max_count)."""

        dom_cnt = self.fetch_domains( set_id, seq_id_list)
        if not dom_cnt:
            return {}

        promis_map = {}
        for (domain, count) in dom_cnt:
            promis_map[domain] = count
        max_count = max(promis_map.values())

        # normalize by the greatest domain promiscuity count
        for (domain, count) in promis_map.items():
            if use_cnt_promisc_tuple:
                promis_map[domain] = (count, float(count) / max_count)
            else:
                promis_map[domain] = float(count) / max_count

        return promis_map

    def get_promisc_map_ccomp(self, set_id, seq_id_list = None,
                              use_cnt_promisc_tuple=False):
        """Promiscuity as connected-component size: each domain's raw
        score is the size of its component minus one (itself),
        normalized by the largest such value."""
        components = self.get_domain_components( set_id, seq_id_list)

        # count the component size of each domain
        max_neighs = 0
        num_neighs = {}
        for comp in components:
            n = len(comp) - 1              # don't count self
            max_neighs = max(n, max_neighs)

            for dom in comp:
                num_neighs[dom] = n

        # If every component is a singleton, max_neighs is 0 and all
        # scores are 0.0 (previously a ZeroDivisionError).
        denom = float(max_neighs) if max_neighs > 0 else 1.0

        for (dom, n) in num_neighs.items():
            if use_cnt_promisc_tuple:
                num_neighs[dom] = (n, n / denom)
            else:
                num_neighs[dom] = n / denom

        return num_neighs

    def get_promisc_map_neigh(self, set_id, seq_id_list = None,
                              use_cnt_promisc_tuple=False):
        """Promiscuity as neighborhood size: the number of distinct
        domains co-occurring with each domain in some sequence,
        normalized by the largest neighborhood."""
        seq_domains = self.fetch_sequence_domains( set_id, seq_id_list)

        # build sets of domain neighbors. These include self
        domain_neighs = {}
        max_num_neighs = 0
        for domains in seq_domains.values():
            for domain in domains:
                if domain not in domain_neighs:
                    domain_neighs[domain] = set()

                domain_neighs[domain].update( domains)
                max_num_neighs = max(len(domain_neighs[domain]),
                                     max_num_neighs)

        # subtract self from max_num_neighs; if every sequence carries
        # a single domain the result is 0, so guard the denominator
        # (previously a ZeroDivisionError).
        max_num_neighs -= 1
        denom = float(max_num_neighs) if max_num_neighs > 0 else 1.0

        for (dom, nn_set) in domain_neighs.items():
            n = len(nn_set) - 1 # subtract self
            if use_cnt_promisc_tuple:
                domain_neighs[dom] = (n, n / denom)
            else:
                domain_neighs[dom] = n / denom

        return domain_neighs

    def fetch_sequence_domains(self, set_id, seq_id_list = None,
                               tuple_domain_cnt=False,
                               tuple_domain_cnt_and_name=False):
        """Fetch a set of domains for each sequence.

        Returns a dict: seq_id -> set of id_0 values; of (id_0, id_2)
        tuples when tuple_domain_cnt; of (id_0, id_2, id_1) tuples
        when tuple_domain_cnt_and_name."""

        q = """SELECT seq_id, id_0, id_1, id_2
        FROM prot_seq_set_member
        JOIN sp_cross_ref USING (seq_id)
        WHERE set_id=%(set_id)d
        AND source_id=%(source_id)d
        """

        seq_id_str = ""
        if seq_id_list is not None:
            q += "\n AND seq_id in (%(seq_id_str)s)"
            seq_id_str = self._build_seq_id_str(seq_id_list)

        q += "\n ORDER BY seq_id"

        source_id = self.fetch_pfam_source_id()

        dbret = self.dbw.fetchall( q % {'set_id': set_id,
                                        'source_id': source_id,
                                        'seq_id_str': seq_id_str})
        seqs = {}
        for (seq_id, pf_id, name, cnt) in dbret:
            if seq_id not in seqs:
                seqs[seq_id] = set()

            if tuple_domain_cnt_and_name:
                seqs[seq_id].add( (pf_id, cnt, name))

            elif tuple_domain_cnt:
                seqs[seq_id].add( (pf_id, cnt))

            else:
                seqs[seq_id].add( pf_id)

        return seqs

    def get_domain_components(self, set_id, seq_id_list = None):
        """Build a set of connected components in the domain graph,
        where an edge exists between domains that are found within a
        single protein.  The domain graph is represented only
        implicitly, and Union-Find is used to efficiently identify
        components.  Returns a list of sets of domain ids."""

        seq_domains = self.fetch_sequence_domains( set_id, seq_id_list)

        # Use Union-Find to efficiently build the set of connected components
        union_tree = UnionFind()

        for domains in seq_domains.values():
            union_tree.union( *domains)

        # Group objects by their component representative.
        # NOTE(review): iteritems() here is the UnionFind class's own
        # API (JJutil.UnionFind), not a dict method.
        components = {}
        for (obj, comp) in union_tree.iteritems():
            if comp not in components:
                components[comp] = set()

            components[comp].add( obj)

        return list(components.values())

    def fetch_domain_name(self, pf_list, set_id):
        """Return a dict mapping each PFAM accession in pf_list to its
        name (id_1), restricted to sequences in set_id.

        Accessions are looked up one at a time: a batched IN (...)
        query was abandoned because it sometimes returned nothing for
        long lists (a maximum query list length near 1351 items?)."""

        # allow a single PFAM name, or a list of them
        if isinstance(pf_list, str):
            pf_list = [ pf_list]

        q_single = """SELECT id_1
        FROM sp_cross_ref
        JOIN prot_seq_set_member using (seq_id)
        WHERE source_id=%(source_id)s
        AND id_0=%(pf)s
        AND set_id=%(set_id)s
        LIMIT 1"""

        source_id = self.fetch_pfam_source_id()

        name_map = {}
        for pf in pf_list:
            # locals() supplies source_id, pf, and set_id as query
            # parameters.
            name = self.dbw.fetchsingle( q_single, locals())
            name_map[pf] = name

        return name_map
    
