#!/usr/bin/env python

# Jacob Joseph
# 16 July 2008

import os, sys
from JJutil import pgutils, pickler
from JJutil.UnionFind import UnionFind

class pfamq:
    """Class to facilitate in calculations against PFAM, as stored in
    prot_seq_pfam."""

    dbw = None

    def __init__(self, cacheq=False, pickledir=None):
        self.cacheq = cacheq
        
        if pickledir != None: self.pickledir = pickledir
        else: self.pickledir = os.path.expandvars('$HOME/tmp/pickles')

        if not os.path.exists(self.pickledir): os.mkdir(self.pickledir)

        self.dbw = pgutils.dbwrap()

    def lookup_domain_name(self, hmm_acc):

        q = """SELECT hmm_name
        FROM prot_seq_pfam
        WHERE hmm_acc = %(hmm_acc)s
        LIMIT 1"""

        ret = self.dbw.fetchsingle(q, locals())

        if ret is None:
            print >> sys.stderr, "WARNING: hmm_acc '%(hmm_acc)s' not found in current sequences." % locals()

        return ret

    def fetch_seq_domains(self, seq_id, field_list=None):
        """Fetch the list of domains for a given sequence or set of
        sequences"""

        if field_list is None:
            field_list = ('hmm_acc', 'hmm_name')
        else:
            field_list = tuple(field_list)

        field_str = reduce( lambda a,b: a + ", " + b, field_list)

        q_base = """SELECT %(field_str)s
        FROM prot_seq_pfam""" % {'field_str': field_str}

        q_where = """
        WHERE seq_id = %(seq_id)s"""

        q = q_base + q_where

        retlist = self.dbw.fetchall( q, locals())

        return retlist


    def fetch_domain_map(self, set_id = None, seq_ids = None):
        """For a given set_id or set/list of sequences, return
        dictionaries to (1) look up all sequences with a particular
        domain, and (2) all domains in a sequence.e.g.,
          ({domain0: set(seq_id0, seq_id1, ...), ...},
           {seq_id0: set(domain0, domain1, ...), ...}

        Note that order, or multiple copies of a domain in a sequence
        will not be represented.
        """

        assert set_id is not None or seq_ids is not None, "At least one of set_id or seq_ids must be specified."
        
        q_base = """SELECT distinct seq_id, hmm_acc
        FROM prot_seq_pfam
        %(q_join)s
        WHERE %(q_cond)s"""

        q_join = ""
        q_cond = ""
        if set_id is not None:
            q_join += """\nJOIN prot_seq_set_member USING (seq_id)"""
            if len(q_cond) > 0: q_cond += "\n"
            q_cond += """set_id = %(set_id)s"""

        if seq_ids is not None:
            if len(q_cond) > 0: q_cond += "\nAND "
            q_cond += """seq_id IN %(seq_ids)s"""

        q = q_base % locals()

        dbret = self.dbw.fetchall( q, locals())
        
        domain_map = {}
        seq_map = {}

        for seq_id, hmm_acc in dbret:
            try:
                domain_map[hmm_acc].add(seq_id)
            except KeyError:
                domain_map[hmm_acc] = set( (seq_id,))

            try:
                seq_map[seq_id].add(hmm_acc)
            except KeyError:
                seq_map[seq_id] = set( (hmm_acc,))

        return (domain_map, seq_map)
