#!/usr/bin/env python

# Jacob Joseph
# 10 June 2008

# A class to store hierarchical clustering results (trees) in SQL
import cluster_obj
#from IPython.Shell import IPShellEmbed
from JJutil import pgutils, pickler
import os

class cluster(object):
    """Handle on a single clustering run, i.e. one row of jj_cluster_run.

    Constructing the object either attaches to an existing run (when
    cluster_run_id is given) or inserts a new jj_cluster_run row and
    remembers its cr_id.  The run's metadata columns are exposed as
    read-only properties, all backed by one lazily cached query.
    """

    def __init__(self, cluster_run_id=None, br_id=None, nc_id=None,
                 stype=None, set_id_filter=None, score_threshold=None,
                 comment='',  params='',
                 symmetric = True,
                 cacheq=False, pickledir=None, debug=True):
        """Open a database wrapper and bind to a cluster run.

        cluster_run_id -- existing cr_id to attach to.  When None, a
            new run is created from the remaining run arguments
            (br_id, nc_id, stype, set_id_filter, score_threshold,
            comment, params, symmetric).
        cacheq -- stored as self.cacheq; subclasses consult it before
            using the pickle cache.
        pickledir -- directory for pickled results; defaults to
            $HOME/tmp/pickles.
        debug -- passed through to pgutils.dbwrap.
        """
        self.debug = debug
        self.cacheq = cacheq

        # Filled on the first cluster_info() call; None means the
        # metadata row has not been fetched yet.
        self.cluster_info_cache = None

        if pickledir is not None:
            self.pickledir = pickledir
        else:
            self.pickledir = os.path.expandvars('$HOME/tmp/pickles')

        self.dbw = pgutils.dbwrap(debug=self.debug)

        if cluster_run_id is not None:
            self.cr_id = cluster_run_id
        else:
            self.cr_id = self.new_cluster_run(br_id, nc_id, stype,
                                              set_id_filter,
                                              score_threshold, comment,
                                              params, symmetric)
            # A single parenthesised argument prints the same text
            # whether print is the Python 2 statement or the Python 3
            # function.
            print("Initialized cluster run (%d): '%s'" % (self.cr_id, comment))

    def new_cluster_run( self, br_id, nc_id, stype, set_id_filter, score_threshold,
                         comment, params, symmetric):
        """Insert a new jj_cluster_run row and return its cr_id.

        When br_id is None but nc_id is given, br_id is first looked
        up from the nc_run table.
        """
        if br_id is None and nc_id is not None:
            br_id = self.dbw.fetchsingle("SELECT br_id FROM nc_run where nc_id=%s",
                                         (nc_id,))

        i = """INSERT INTO jj_cluster_run
        (date, br_id, nc_id, stype, set_id_filter, score_threshold, comment, params, use_symmetric) VALUES
        (NOW(), %(br_id)s, %(nc_id)s, %(stype)s, %(set_id_filter)s, %(score_threshold)s,
        %(comment)s, %(params)s, %(symmetric)s)
        RETURNING cr_id"""

        # Bind exactly the placeholders used above, rather than
        # locals(), which would also hand `self` and the query text to
        # the database layer.
        values = {'br_id': br_id,
                  'nc_id': nc_id,
                  'stype': stype,
                  'set_id_filter': set_id_filter,
                  'score_threshold': score_threshold,
                  'comment': comment,
                  'params': params,
                  'symmetric': symmetric}

        lastrowid = self.dbw.fetchsingle(i, values)

        return lastrowid

    def cluster_info( self):
        """Return the jj_cluster_run row of this run as a dictionary.

        The row is fetched with dbw.fetchone_d on first use and kept
        in self.cluster_info_cache for subsequent calls.
        """
        q = """SELECT date, br_id, nc_id, stype, set_id_filter, use_symmetric, score_threshold, comment, params
        FROM jj_cluster_run
        WHERE cr_id=%(cr_id)s"""

        if self.cluster_info_cache is None:
            self.cluster_info_cache = self.dbw.fetchone_d(
                q, {'cr_id': self.cr_id})

        return self.cluster_info_cache

    # Each property below reads one column of the cached run row.

    @property
    def cluster_params( self):
        """The 'params' column of the run."""
        return self.cluster_info()['params']

    @property
    def cluster_comment( self):
        """The 'comment' column of the run."""
        return self.cluster_info()['comment']

    @property
    def stype( self):
        """The 'stype' column of the run."""
        return self.cluster_info()['stype']

    @property
    def date( self):
        """The 'date' column of the run."""
        return self.cluster_info()['date']

    @property
    def score_threshold( self):
        """The 'score_threshold' column of the run."""
        return self.cluster_info()['score_threshold']

    @property
    def br_id( self):
        """The 'br_id' column of the run."""
        return self.cluster_info()['br_id']

    @property
    def nc_id( self):
        """The 'nc_id' column of the run."""
        return self.cluster_info()['nc_id']

    @property
    def set_id_filter( self):
        """The 'set_id_filter' column of the run."""
        return self.cluster_info()['set_id_filter']

    @property
    def symmetric( self):
        """The 'use_symmetric' column of the run."""
        return self.cluster_info()['use_symmetric']

class flatcluster(cluster):
    """A flat (non-hierarchical) clustering, stored in jj_flatcluster
    as one (cr_id, cluster_id, seq_id) row per cluster member."""

    def cluster_insert( self, cluster_id, seq_ids):
        """Insert every seq_id of one cluster, then commit."""
        i = """INSERT INTO jj_flatcluster
        (cr_id, cluster_id, seq_id) VALUES
        (%(cr_id)s, %(cluster_id)s, %(seq_id)s)"""

        for seq_id in seq_ids:
            self.dbw.execute( i, { 'cr_id': self.cr_id,
                                   'cluster_id': cluster_id,
                                   'seq_id': seq_id})
        self.dbw.commit()
        return

    def store(self, cluster_map):
        """Store a whole clustering given as {cluster_id: seq_ids}."""
        # items() exists on both Python 2 and Python 3 dictionaries,
        # whereas iteritems() is Python 2 only.
        for (cluster_id, seq_ids) in cluster_map.items():
            self.cluster_insert( cluster_id, seq_ids)

        return

    def fetch_clusters(self):
        """Return the clustering as {int cluster_id: set of int seq_ids}.

        A run with no stored rows yields an empty dictionary.  When
        self.cacheq is set, the result is read from and written to
        the pickle cache keyed on the cr_id.
        """
        if self.cacheq:
            retval = pickler.cachefn(pickledir=self.pickledir,
                                     args = "%d" % self.cr_id)
            if retval is not None: return retval

        q = """SELECT cluster_id, seq_id
        FROM jj_flatcluster
        WHERE cr_id = %(cr_id)s
        ORDER BY cluster_id"""

        self.dbw.execute(q, {'cr_id': self.cr_id})

        # Group rows by cluster id.  Accumulating straight into the
        # dictionary needs no "flush the last group" step, which is
        # what used to fail with int(None) on an empty result set.
        cluster_map = {}
        for (cluster_id, seq_id) in self.dbw.fetchall():
            cluster_map.setdefault(int(cluster_id), set()).add(int(seq_id))

        if self.cacheq: pickler.cachefn( pickledir = self.pickledir,
                                         args = "%d" % self.cr_id,
                                         retval = cluster_map)
        return cluster_map

    def fetch_cluster(self, cluster_id):
        """Return the set of seq_ids in one cluster.

        Raises KeyError when the cluster_id is not part of this run.
        """
        # since flat clusterings are almost always very small, just
        # cache the entire clustering on first use
        if not hasattr(self, 'cluster_cache'):
            self.cluster_cache = self.fetch_clusters()

        return self.cluster_cache[cluster_id]
        
    
class hcluster(cluster):
    leaf_cache = None

    def cluster_insert( self, cluster_id, distance,
                        parent_distance, parent_id, seq_id):
        """Execute the INSERT of one tree node into jj_hcluster for
        this run.  Nothing is committed here; the caller commits."""

        row = dict(cr_id=self.cr_id,
                   cluster_id=cluster_id,
                   distance=distance,
                   parent_distance=parent_distance,
                   parent_id=parent_id,
                   seq_id=seq_id)

        insert_sql = """INSERT INTO jj_hcluster
        (cr_id, cluster_id, parent_distance, distance, parent_id, seq_id)
        VALUES
        (%(cr_id)s, %(cluster_id)s, %(parent_distance)s, %(distance)s,
        %(parent_id)s, %(seq_id)s)"""

        self.dbw.execute(insert_sql, row)

    def store_tree(self, root):
        """Add hierarchical clustering tree to the database.

        The tree is walked breadth-first from root, inserting one
        jj_hcluster row per node via cluster_insert(), and committed
        once at the end.  A child whose first item is not a
        cluster_obj.cobj is treated as a leaf and that item is stored
        as its seq_id.

        A warning is printed when two leaves carry the same seq_id;
        the node is still inserted.
        """
        from collections import deque

        # Entries are
        # (clust_id, distance, parent_distance, parent_id, seq_id, [children]).
        # A deque gives O(1) removal from the front; list.pop(0) shifts
        # the whole queue on every step.
        clust_queue = deque()
        clust_queue.append( (root.cluster_id(), root.distance(),
                             None, None, None, root.items()) )

        used_seq_ids = set()

        while clust_queue:
            (cluster_id, distance, parent_distance, parent_id,
             seq_id, children) = clust_queue.popleft()

            if seq_id is not None:
                if seq_id in used_seq_ids:
                    # %s rather than %d so that reporting can never
                    # itself fail on an unusual id type.
                    print("Cluster %s uses an existing seq_id (%s)"
                          % (cluster_id, seq_id))
                # Record the id; without this the check above could
                # never trigger.
                used_seq_ids.add(seq_id)

            self.cluster_insert( cluster_id, distance,
                                 parent_distance, parent_id, seq_id)

            for child in children:
                child_items = child.items()

                # leaves have sequence ids, not additional clusters
                if not isinstance(child_items[0], cluster_obj.cobj):
                    child_seq_id = child_items[0]
                    c_children = []
                else:
                    child_seq_id = None
                    c_children = child_items

                # the node just inserted is the parent of each child
                clust_queue.append( (child.cluster_id(), child.distance(),
                                     distance, cluster_id,
                                     child_seq_id, c_children) )

        self.dbw.commit()
        return

    def get_children(self, cluster_id):
        """Return the set of (integer) cluster ids whose parent_id is
        the given cluster, i.e. its direct children."""

        lookup = {'cr_id': self.cr_id,
                  'parent_id': cluster_id}

        child_sql = """SELECT cluster_id
        FROM jj_hcluster
        WHERE cr_id = %(cr_id)s
        AND parent_id = %(parent_id)s"""

        child_ids = set()
        for raw_id in self.dbw.fetchcolumn(child_sql, lookup):
            child_ids.add(int(raw_id))
        return child_ids


    def get_cluster_row(self, cluster_id):
        """Look up one cluster of this run and return its row as a
        dictionary, as produced by dbw.fetchone_d.  Columns selected:
        distance, parent_distance, parent_id, seq_id, lft, rgt."""

        key = {'cr_id': self.cr_id, 'cluster_id': cluster_id}

        row_sql = """SELECT distance, parent_distance, parent_id, seq_id,
        lft, rgt
        FROM jj_hcluster
        WHERE cr_id = %(cr_id)s
        AND cluster_id = %(cluster_id)s"""

        return self.dbw.fetchone_d(row_sql, key)


    def get_distance(self, cluster_id):
        """Return the 'distance' column of a specific cluster, read
        through get_cluster_row()."""

        return self.get_cluster_row(cluster_id)['distance']

        
    def get_leaves(self, cluster_id, set_id_filter=None):
        """Return the set of integer seq_ids found beneath a cluster.

        Uses the nested-set columns: every row of this run whose lft
        lies within the [lft, rgt] interval of the given cluster and
        whose seq_id is not NULL.  When set_id_filter is given, rows
        are additionally joined to prot_seq_set_member and restricted
        to that set_id.
        """
        # Pickle caching is switched off here by the constant False.
        if False and self.cacheq:
            retval = pickler.cachefn(pickledir=self.pickledir,
                                     args = "%d_%d" % (self.cr_id, cluster_id))
            if retval is not None: return retval

        select_part = """WITH node AS (
        SELECT lft, rgt
        FROM jj_hcluster
        WHERE cr_id = %(cr_id)s
        AND cluster_id = %(cluster_id)s)
        SELECT seq_id
        FROM jj_hcluster"""

        range_part = """
        WHERE cr_id = %(cr_id)s
        AND lft >= (SELECT lft FROM node)
        AND lft <= (SELECT rgt FROM node)
        AND seq_id IS NOT NULL"""

        if set_id_filter is None:
            q = select_part + range_part
        else:
            # restrict to members of one sequence set
            q = (select_part + """
        JOIN prot_seq_set_member USING (seq_id)""" + range_part + """
        AND set_id = %(set_id_filter)s""")

        bind = {'cr_id': self.cr_id,
                'cluster_id': cluster_id,
                'set_id_filter': set_id_filter}
        self.dbw.execute(q, bind)

        seq_ids = set()
        for raw_seq in self.dbw.fetchcolumn():
            seq_ids.add(int(raw_seq))

        # Disabled, as above.
        if False and self.cacheq:
            pickler.cachefn(pickledir=self.pickledir,
                            args = "%d_%d" % (self.cr_id, cluster_id),
                            retval=seq_ids)
        return seq_ids
    fetch_cluster = get_leaves

    def is_child(self, cluster_id, parent_id):
        """Test whether a cluster descends from a particular parent.

        Walks parent_id links upward, one query per step, until a
        lookup yields None.  Returns a tuple: a boolean (parent_id
        was met on the way up) and the list of ancestor ids, ordered
        from the cluster's own parent towards the root."""

        parent_sql = """SELECT parent_id
        FROM jj_hcluster
        WHERE cr_id = %(cr_id)s
        AND cluster_id = %(cluster_id)s"""

        lineage = []
        node = cluster_id
        while node is not None:
            self.dbw.execute(parent_sql, {'cr_id': self.cr_id,
                                          'cluster_id': node})
            node = self.dbw.fetchsingle()
            if node is not None:
                lineage.append(node)

        found = parent_id in lineage
        return (found, lineage)
            

    def get_common_parent(self, cluster_ids=None, seq_ids=None):
        """Return the closest internal node that includes all clusters
        in cluster_ids or seq_ids.

        cluster_ids takes precedence when both are given.  Relies on
        the lft/rgt columns of jj_hcluster being filled in.

        Returns an integer cluster_id, or None when the id collection
        is empty or nothing matching was found.  For a single id, the
        cluster_id of the matching row itself is returned; for
        several ids, the result strictly encloses all of them (lft
        below the smallest lft, rgt above the largest rgt).

        Raises ValueError when neither argument is supplied.
        """

        if cluster_ids is not None:
            list_field = 'cluster_id'
            list_ids = cluster_ids
        elif seq_ids is not None:
            list_field = 'seq_id'
            list_ids = seq_ids
        else:
            # An explicit raise survives "python -O"; an assert does not.
            raise ValueError(
                "Must specify either a list of clusters or sequences.")

        # Coerce once to plain integers.  list_field comes from the
        # two literals above and the ids are forced to int, so the
        # text spliced into the query below cannot carry SQL.
        ids = [int(list_id) for list_id in list_ids]

        # Nothing to enclose; "IN ()" would be malformed SQL.
        if len(ids) == 0:
            return None

        # Single cluster: return its id
        if len(ids) == 1:
            q_single = """SELECT cluster_id
            FROM jj_hcluster
            WHERE cr_id = %(cr_id)d
            AND %(cond)s = %(list_id)d"""

            cluster_id = self.dbw.fetchsingle(q_single % {'cr_id': self.cr_id,
                                                          'cond' : list_field,
                                                          'list_id': ids[0]} )
            if cluster_id is not None: cluster_id = int(cluster_id)
            return cluster_id

        # Otherwise find the nearest parent: first the interval
        # spanned by all of the given nodes ...
        q_ind = """SELECT MIN(lft), MAX(rgt)
        FROM jj_hcluster
        WHERE cr_id = %(cr_id)d
        AND %(list_field)s in (%(list_id_str)s)"""

        list_id_str = ",".join("%d" % list_id for list_id in ids)

        (lft, rgt) = self.dbw.fetchone( q_ind % {'cr_id': self.cr_id,
                                                 'list_id_str': list_id_str,
                                                 'list_field': list_field} )

        # No row matched (or the tree has not been nested): comparing
        # against NULL can never select a row, so skip the query.
        if lft is None or rgt is None:
            return None

        # ... then the tightest node strictly enclosing that interval,
        # which is the encloser with the largest lft.
        q = """SELECT cluster_id
        FROM jj_hcluster
        WHERE cr_id = %(cr_id)s
        AND lft < %(lft)s
        AND rgt > %(rgt)s
        ORDER BY -lft
        LIMIT 1"""

        parent_id = self.dbw.fetchsingle( q, {'cr_id': self.cr_id,
                                              'lft': lft,
                                              'rgt': rgt} )
        if parent_id is not None: parent_id = int(parent_id)
        return parent_id

                          
    def get_leaves_recursive(self, cluster_id):
        """Collect the integer seq_ids beneath a cluster by following
        parent_id links downward, one query per internal node.  Does
        not use the lft/rgt columns.  For debugging.  Returns a list."""

        q_leaf = """SELECT seq_id
        FROM jj_hcluster
        WHERE cr_id = %(cr_id)s
        AND cluster_id = %(cluster_id)s"""

        q_rec = """SELECT cluster_id, seq_id
        FROM jj_hcluster
        WHERE cr_id = %(cr_id)s
        AND parent_id = %(parent_id)s"""

        self.dbw.execute(q_leaf, {'cr_id': self.cr_id,
                                  'cluster_id': cluster_id})
        own_seq = self.dbw.fetchsingle()

        # the given cluster is itself a leaf
        if own_seq is not None:
            return [int(own_seq)]

        # otherwise descend from the intermediate node, depth first
        found = []
        todo = [cluster_id]
        while todo:
            current = todo.pop()
            self.dbw.execute(q_rec, {'cr_id': self.cr_id,
                                     'parent_id': current})
            for (child_id, child_seq) in self.dbw.fetchall():
                if child_seq is None:
                    todo.append(int(child_id))
                else:
                    found.append(int(child_seq))

        return found

    def cut_tree(self, distance, set_id_filter=None):
        """Cut the tree at level 'distance' and return the resulting
        clusters as {cluster_id: set of seq_ids}.

        A cluster is selected when its own distance is at or below
        the cut while its parent's distance is above it (or it has
        no parent).  Leaves come from get_leaves() with the same
        set_id_filter; clusters left with no leaves are omitted.
        When self.cacheq is set, the result goes through the pickle
        cache, keyed on cr_id, distance and set_id_filter."""

        cache_key = None
        if self.cacheq:
            cache_key = "%d_%g_%s" % (self.cr_id, distance, set_id_filter)
            cached = pickler.cachefn(pickledir=self.pickledir,
                                     args = cache_key)
            if cached is not None:
                return cached

        cut_sql = """SELECT cluster_id
        FROM jj_hcluster
        WHERE cr_id = %(cr_id)s
        AND distance <= %(distance)s
        AND (parent_distance > %(distance)s
             OR parent_id is NULL)"""

        self.dbw.execute(cut_sql, {'cr_id': self.cr_id,
                                   'distance': distance})
        selected = self.dbw.fetchcolumn()

        result = {}
        for cid in selected:
            members = self.get_leaves(cid, set_id_filter = set_id_filter)
            # the filter may leave a cluster without any leaves; such
            # clusters are not reported
            if len(members) != 0:
                result[cid] = members

        if self.cacheq:
            pickler.cachefn(pickledir=self.pickledir,
                            args = cache_key,
                            retval = result)
        return result

    def cut_tree_components(self, distance, br_id, nc_id, stype):
        """Sanity check for single linkage clustering: build a graph
        of the hits passing the threshold and return its connected
        components as {index: set of nodes}.

        For stype 'nc_score' hits need a score of at least
        1-distance; for 'e_value' a score of at most distance.  Any
        other stype fails an assertion."""

        from DurandDB import blastq
        import networkx as NX
        bq = blastq.blastq()

        if stype == 'nc_score':
            (min_score, max_score) = (1-distance, None)
        elif stype == 'e_value':
            (min_score, max_score) = (None, distance)
        else: assert False, "Unsupported stype: %s" % stype

        graph = NX.Graph()

        # Every member of the sequence set becomes a vertex, so that
        # sequences without any hit still show up as components.
        canonical = {}
        for member in bq.fetch_seq_set( br_id=br_id):
            canonical[member] = member
            graph.add_node(member)

        hit_iter = bq.fetch_hits_iter( br_id=br_id, nc_id=nc_id,
                                       stype=stype, min_score=min_score,
                                       max_score = max_score,
                                       correct_e_value=False, symmetric=True,
                                       self_hits=True)

        for (src, dst, _score) in hit_iter:
            # Use the one stored object per node, rather than a new
            # reference for each edge endpoint
            graph.add_edge(canonical[src], canonical[dst])

        return dict( (idx, set(comp)) for (idx, comp)
                     in enumerate(NX.connected_components(graph)) )

    def fetch_structure(self, return_cluster_list=False,
                        left_lim=None, right_lim=None):
        """Rebuild the stored tree as cluster_obj.cobj nodes and
        return its root.  Assumes the tree was previously 'nested'
        (see nest_sets).

        When both left_lim and right_lim are given, only rows with
        lft >= left_lim and rgt <= right_lim are loaded, and the first
        of them in lft order becomes the root.  With
        return_cluster_list, returns (root, all created nodes)."""

        sql = """SELECT cluster_id, parent_id, distance, seq_id, lft, rgt
        FROM jj_hcluster
        WHERE cr_id = %(cr_id)s"""

        if not (left_lim is None or right_lim is None):
            sql += """\nAND lft >= %(left_lim)s
        AND rgt <= %(right_lim)s"""

        sql += """\nORDER BY lft"""

        rows = self.dbw.fetchall(sql, {'cr_id': self.cr_id,
                                       'left_lim': left_lim,
                                       'right_lim': right_lim})

        nodes_by_id = {}
        root = None

        for row in rows:
            (cluster_id, parent_id, distance, seq_id, lft, rgt) = row

            if seq_id is None:
                items = []
            else:
                items = [seq_id]

            assert lft is not None, "Must call nest_sets before this function"

            node = cluster_obj.cobj( distance = distance,
                                     cluster_id = cluster_id,
                                     left = lft,
                                     right = rgt,
                                     items = items)
            nodes_by_id[cluster_id] = node

            # The first node seen is taken as root (rather than the
            # one without a parent), to allow starting at a subtree.
            if len(nodes_by_id) == 1:
                root = node
                continue

            # the parent should always exist with ORDER BY lft
            nodes_by_id[parent_id].add_child(node)

        assert root is not None, "Did not find root"

        if not return_cluster_list:
            return root
        return (root, nodes_by_id.values())
    
    def nest_sets(self):
        """Complete the lft and rgt columns of the tree, for a nested
        sets representation.

        Walks the tree depth-first with an explicit stack, starting
        at the row whose parent_id is NULL.  A single counter hands
        out lft when a node is first reached and rgt when it is
        finished; the root receives lft = 1.  Each node is written
        back with one UPDATE as it leaves the stack, and everything
        is committed once at the end."""

        q_root = """SELECT cluster_id
        FROM jj_hcluster
        WHERE cr_id = %(cr_id)s
        AND parent_id IS NULL"""

        # order by really isn't necessary, but gives predictable results
        q_children = """SELECT cluster_id
        FROM jj_hcluster
        WHERE cr_id = %(cr_id)s
        AND parent_id = %(parent_id)s
        ORDER BY cluster_id
        """
        u = """UPDATE jj_hcluster
        SET lft = %(lft)s, rgt = %(rgt)s
        WHERE cr_id = %(cr_id)s
        AND cluster_id = %(cluster_id)s"""

        # Nodes that have been reached but not yet finished; the top
        # of the stack is the node currently being processed.
        stack = []

        self.dbw.execute( q_root, {'cr_id': self.cr_id})
        root_id = self.dbw.fetchsingle()
        stack.append( cluster_obj.cobj( cluster_id=root_id,
                                        left=1, right=None))
        # last lft/rgt value handed out so far
        last_i = 1

        while len(stack) > 0:
            clust = stack[-1]
            
            # set left if not already
            if clust.left() is None:
                last_i += 1
                clust.set_left( last_i)

            # Children are queried only the first time a node is on
            # top of the stack.  NOTE(review): this relies on
            # cobj.traversed() / set_traversed() acting as a
            # "visited" flag -- confirm in cluster_obj.
            if not clust.traversed():
                self.dbw.execute( q_children, {'cr_id': self.cr_id,
                                               'parent_id': clust.cluster_id() })
                cluster_ids = self.dbw.fetchcolumn()
                clust.set_traversed()
            else:
                cluster_ids = []

            # children do exist
            if len(cluster_ids) > 0:

                children = [ cluster_obj.cobj(cluster_id=cid,
                                              left=None, right=None)
                             for cid in cluster_ids ]
                # sql ORDER BY can only use indices for +cluster_id,
                # whereas we'd like the minimum cluster_id at the end
                # of the stack
                children.reverse()

                # set left on the last child in the list: the one
                # with the smallest cluster_id, which ends up on top
                # of the stack and is processed next
                last_i += 1
                children[-1].set_left( last_i)

                stack.extend( children)
            # no children (or they have all been handled already).
            # Set right to the current max+1, pop the node from the
            # stack and write its lft/rgt to the database.
            else:
                last_i += 1
                clust.set_right( last_i)
                # pop the current cluster from the stack
                stack.pop()
                self.dbw.execute( u, {'cr_id': self.cr_id,
                                      'cluster_id': clust.cluster_id(),
                                      'lft': clust.left(),
                                      'rgt': clust.right()})

        self.dbw.commit()
                
        return
