#!/usr/bin/env python

# Jacob Joseph
# 2008 June 4
# Agglomerative clustering routines

from DurandDB import blastq
import networkx as NX
from JJcluster import cluster_obj
import time
from IPython.Shell import IPShellEmbed
from JJutil.pqueue import pqueue
from JJutil.priority_dict import priority_dict

class agglomerative_base(object):
    """Base class for agglomerative (bottom-up) hierarchical clustering
    of sequences, using similarity scores fetched from a BLAST results
    database (DurandDB blastq).  Subclasses supply the linkage strategy,
    either by assigning self.linkage_fn or by overriding cluster()."""

    def __init__(self, br_id=None, nc_id=None,
                 stype='e_value', self_hits = True,
                 symmetric = True,
                 set_id_filter=None,
                 score_threshold=None,
                 param_list=None  
                 ):
        """br_id:           BLAST run id (used with 'e_value'/'bit_score')
        nc_id:           network-score run id (used with 'nc_score')
        stype:           score type: 'e_value', 'bit_score', or 'nc_score'
        self_hits:       whether self-alignments are considered
        symmetric:       treat pairwise scores as symmetric
        set_id_filter:   restrict clustering to sequences in this set
        score_threshold: ignore hits scoring worse than this threshold
        param_list:      unused in these classes
        """
        self.br_id = br_id
        self.nc_id = nc_id
        self.stype = stype
        self.self_hits = self_hits
        self.symmetric = symmetric
        self.set_id_filter = set_id_filter
        self.score_threshold = score_threshold
        self.param_list = param_list  # unused in these classes

        # stub asserts if a subclass forgets to install a real linkage_fn
        self.linkage_fn = self.linkage_fn_stub

        self.bq = blastq.blastq()
        # the full set of sequence ids to be clustered
        self.all_seqs = self.bq.fetch_seq_set( nc_id = nc_id,
                                               br_id=br_id,
                                               set_id=self.set_id_filter)

        self.set_max_distance()

        # seq_id -> cluster object currently containing that sequence
        self.seq_to_clust = {}
        # cluster object -> set of member seq_ids (live clusters only)
        self.clust_to_seq = {}

        # last cluster id handed out; see next_cluster_id()
        self.last_clust_id = -1

        # cluster -> cached nearest-neighbor tuple (used by subclasses)
        self.clust_nn = {}

        self.last_level=0

    def set_max_distance(self):
        """Set self.max_distance, the distance assigned to the final
        root merge (an upper bound on any pairwise distance for the
        chosen score type)."""
        if self.stype == 'nc_score':
            # nc distances are formed as (1 - nc_score), so 1.0 is maximal
            self.max_distance = 1.0
        elif self.stype == 'bit_score':
            # bit-score distances are (max_bit_score - score)
            self.max_distance = self.bq.fetch_max_bit_score( br_id = self.br_id,
                                                             symmetric = self.symmetric,
                                                             seq_id_0_set_id = self.set_id_filter)
        elif self.stype == 'e_value':
            # e-values are used directly as distances; the run's
            # expectation cutoff bounds every reported hit
            self.max_distance = self.bq.fetch_blast_params(self.br_id)['expectation']
        else:
            assert False, "Unknown stype: %s" % self.stype
        return

    def linkage_fn_stub(self, *args):
        """Placeholder linkage function; subclasses must replace it."""
        assert False, "No agglomerative linkage function defined"


    def next_cluster_id(self):
        """Return the next unused (sequential) cluster id."""
        self.last_clust_id += 1
        return self.last_clust_id
    

    def pair_components(self, min_pairs):
        """Return the set of connected components that represent pairs
        of equivalent distance. Each component will be represented by
        a single cluster at the next higher level.

        min_pairs is a dict-of-dicts {c0: {c1: dist}}, which NetworkX
        interprets as an adjacency structure."""

        G = NX.Graph( min_pairs)
        components = NX.connected_components(G)

        return components


    def get_leaves_recursive(self, clust):
        """Return the set of sequence ids associated with a cluster.

        Iterative depth-first walk (despite the name): a leaf cluster
        is recognized by holding a bare seq_id rather than a cobj."""

        clusters = []
        seq_ids = []

        clusters.append(clust)

        while len(clusters) > 0:
            c = clusters.pop()

            # sequence_id
            # NOTE(review): only items()[0] is inspected; assumes a
            # cluster never mixes bare seq_ids with sub-clusters
            if not isinstance(c.items()[0], cluster_obj.cobj):
                seq_ids.append( c.items()[0])
            else:
                clusters.extend( c.items())

        return seq_ids

    def verify_tree(self):
        """Verify that every cluster in clust_to_seq.keys() represents
        a unique set of leaves.  Drops into an embedded IPython shell
        on any inconsistency (interactive debugging aid)."""

        seq_id_set = set()

        for c in self.clust_to_seq:
            seq_ids = self.get_leaves_recursive(c)
            if len(seq_id_set.intersection( seq_ids)) > 0:
                print "Cluster_id %d overlaps with existing sequence set" % c.cluster_id()
                ipsh = IPShellEmbed()
                ipsh("In verify_tree():")
            else:
                seq_set = set(seq_ids)
                if not seq_set==self.clust_to_seq[c]:
                    print "Cluster_id %d sequence set disagrees with recursive search" % c.cluster_id()
                    ipsh = IPShellEmbed()
                    ipsh("In verify_tree():")
                    
                
                seq_id_set.update( seq_set)

    def verify_distances(self, parent):
        """Verify that cluster distance never decreases from child to
        parent in the subtree beneath cluster 'parent'.  Drops into an
        embedded IPython shell on any violation (debugging aid)."""

        parents = [parent]
        while len(parents) > 0:
            p = parents.pop()
            
            children = p.items()
            for child in children:
                if not isinstance(child, cluster_obj.cobj):
                    # leaf: a bare seq_id; a leaf cluster should hold
                    # exactly one item
                    if len(children) != 1:
                        print "Cluster %d seems to be a leaf, but has multiple children." % (
                            p.cluster_id() )
                        ipsh = IPShellEmbed()
                        ipsh("In verify_distances():")
                    continue

                if child.distance() > p.distance():
                    print "Child %d of parent %d, with distance %g > %g" % (
                        child.cluster_id(), p.cluster_id(),
                        child.distance(), p.distance())
                    ipsh = IPShellEmbed()
                    ipsh("In verify_distances():")
                
                parents.append(child)
                
        return

    def print_hierarchy(self, cluster):
        """Print a hierarchical subtree (delegates to the cobj)."""

        return cluster.print_hierarchy()

    
    def cluster(self):
        """Generic clustering driver: start with one leaf cluster per
        sequence, then repeatedly merge the minimal-distance groups
        returned by self.linkage_fn() until it reports None, at which
        point all remaining clusters are joined under a single root at
        self.max_distance."""
        print "Beginning clustering - ", time.strftime('%H:%M:%S (%d %h %Y)')
        
        # create a single cluster for each sequence
        for seq_id in self.all_seqs:
            c = cluster_obj.cobj(distance=0,
                                 cluster_id=self.next_cluster_id(),
                                 items=[seq_id])

            # sequence to cluster map
            self.seq_to_clust[seq_id] = c

            # cluster to sequence set map
            if not c in self.clust_to_seq:
                self.clust_to_seq[c] = set()
            self.clust_to_seq[c].add( seq_id)
            #self.verify_distances(c)

        #self.verify_tree()

        # merge the smallest connected components each to one cluster
        while True:

            # dict of dicts of clusters to join
            cluster_pairs = self.linkage_fn()

            # we've merged all of the connected components.  Merge all
            # remaining clusters to a root node with maximum distance.
            if cluster_pairs is None:
                self.join( self.max_distance, self.clust_to_seq.keys())
                #self.verify_tree()
                break

            else:
                # each connected component can be merged to one cluster
                components = self.pair_components( cluster_pairs)
                # all entries in cluster_pairs share the same (minimal)
                # distance, so any one of them supplies the merge level
                dist = cluster_pairs.values()[0].values()[0]
                
                for comp in components:
                    self.join( dist, comp)
                    #self.verify_tree()

        print "Done -", time.strftime('%H:%M:%S (%d %h %Y)')
        return


    def join(self, level, clusters):
        """Merge clusters to a new cluster, updating self.clust_to_seq
        and self.seq_to_clust

        level:    distance at which the merge occurs
        clusters: iterable of cluster objects to merge
        Returns the newly created cluster object."""

        c = cluster_obj.cobj(distance=level,
                             cluster_id=self.next_cluster_id(),
                             items=clusters)

        # disabled debugging branch for a historical problem cluster
        if False and c.cluster_id() == 27124:
            children = c.items()
            c0items = self.get_leaves_recursive(children[0])
            c1items = self.get_leaves_recursive(children[1])

            print self.print_hierarchy(c)
            
            print "Cluster %d distance: %g" % (c.cluster_id(), c.distance())
            print "Child    Distance    NN_seq  NN_cluster  NN_dist:"
            for child in children:
                leaves = self.get_leaves_recursive(child)
                min_dist = self.min_distance( leaves)
                if min_dist is not None:
                    (seq_id_0, seq_id_1, dist) = min_dist
                    nn_id = self.seq_to_clust[ seq_id_1].cluster_id()
                else:
                    nn_id = None
                    seq_id_1 = None
                    
                # NOTE(review): 'dist' may be unbound here when
                # min_dist is None (dead code; branch is disabled)
                print child.cluster_id(), child.distance(), seq_id_1, nn_id, dist

            self.min_distance(c1items, c0items)
            c0children = children[0].items()
            c78leaves = self.get_leaves_recursive(c0children[0])
            c27110leaves = self.get_leaves_recursive(c0children[1])
            c75leaves = self.get_leaves_recursive(c0children[2])

            ipsh = IPShellEmbed()
            ipsh("Just created cluster %d" % c.cluster_id())
            return c

        # move sequence assignments from the set of clusters to the
        # single new cluster
        seq_set = set()
        nn_list = []
        for clust in clusters:
            seq_set.update( self.clust_to_seq.pop(clust))
            nn_list.append( self.clust_nn.pop( clust))
            
        self.clust_to_seq[c] = seq_set

        for seq_id in seq_set:
            self.seq_to_clust[seq_id] = c

        self.join_nn(c, seq_set, nn_list)

        #self.verify_distances(c)

        m = "Merged %d clusters at distance %g in cluster %d. "
        m += "(Distinct cluster count: %d)"
        print m % (len(clusters), level, self.last_clust_id,
                   len(self.clust_to_seq))

        return c


    def join_nn(self, c, seq_set, nn_list):
        """Derive the new nearest neighbor from a joined group of
        clusters.  Usually we'll rely on a new distance calculation,
        but single linkage can be faster.  That is, with single
        linkage, the shortest distance not within the merged cluster
        is still the best.

        Base implementation is a no-op hook; see single_linkage_sql
        for a real implementation."""

        # None if no hits
        # No key if needs a lookup
        return


class agglomerative_memory( agglomerative_base):
    """Agglomerative clustering with the full distance matrix held in
    memory as a dict of priority_dicts, following the priority-queue
    scheme of Day & Edelsbrunner.  Subclasses define the linkage by
    implementing update_distance()."""

    def build_distance(self):
        """Build the distance matrix, and priority queues described in
        Day & Edelsbrunner, step 0."""

        assert self.symmetric, "This function assumes symmetric scores"
        
        # distance matrix: cluster -> priority_dict {cluster: distance}
        self.diss = {}
        diss = self.diss
        hits = self.bq.fetch_hits_direct( br_id=self.br_id, nc_id=self.nc_id,
                                          stype=self.stype, 
                                          thresh=self.score_threshold,
                                          set_id_filter=self.set_id_filter,
                                          self_hits=False,
                                          correct_e_value=False,
                                          symmetric=self.symmetric)
        for (i,(seq_id_0, seq_id_1, score)) in enumerate(hits):
            # progress output for very large hit lists
            if i % 1000000 == 0: print "hit:", seq_id_0, i
            
            # here, there should already be exactly one cluster per
            # sequence
            c0 = self.seq_to_clust[seq_id_0]
            c1 = self.seq_to_clust[seq_id_1]

            # Make a distance from the similarity measure
            # (e-values are already distance-like and pass through)
            if self.stype in ('nc_score', 'bit_score'):
                score = self.max_distance - score

            if not c0 in diss: diss[c0] = priority_dict()
            if not c1 in diss: diss[c1] = priority_dict()
            if not c1 in diss[c0]:
                # assumes symmetric scores
                # FIXME: while complicating queries, half could be
                # stored, but this would break the priority dicts
                diss[c0][c1] = score
                diss[c1][c0] = score
        return


    def init_pqueue(self):
        """Day & Edelsbrunner, step 0. (Unused; Now accomplished in
        build_distance().)

        NOTE(review): if this method were revived, self.nextbest[c0]
        is never initialized before .push() is called, which would
        raise KeyError."""

        self.nextbest = {}

        for (i,c0) in enumerate(self.diss.keys()):
            print i, c0
            
            for (c1, dist) in self.diss[c0].items():
                if c0 == c1: continue
                self.nextbest[c0].push( (dist, c1))

        return

    def min_pair(self):
        """Day & Edelsbrunner, step 1: scan each cluster's priority
        queue for the globally smallest inter-cluster distance.

        Returns (c0, c1, dist), or None when no distances remain."""

        min_tup = None
        for c0 in self.diss.keys():
            if len(self.diss[c0]) == 0: continue
            
            c1 = self.diss[c0].smallest()
            dist = self.diss[c0][c1]
            if min_tup is None or dist < min_tup[2]:
                min_tup = (c0, c1, dist)

        # pop this minimum from the nextbest queue
        # (the symmetric entry diss[c1][c0] is discarded later, when
        # update_distance() drops both merged rows)
        if min_tup is not None:
            self.diss[ min_tup[0] ].pop_smallest()

        return min_tup

    def join(self, level, clusters):
        """Merge clusters to a new cluster, updating self.clust_to_seq
        and self.seq_to_clust.  Like agglomerative_base.join(), but
        without the nearest-neighbor bookkeeping or debug branch.
        Returns the new cluster object."""
        c = cluster_obj.cobj(distance=level,
                             cluster_id=self.next_cluster_id(),
                             items=clusters)

        # move sequence assignments from the set of clusters to the
        # single new cluster
        seq_set = set()
        for clust in clusters:
            seq_set.update( self.clust_to_seq.pop(clust))
            
        self.clust_to_seq[c] = seq_set

        for seq_id in seq_set:
            self.seq_to_clust[seq_id] = c

        m = "Merged %d clusters at distance %g in cluster %d. "
        m += "(Distinct cluster count: %d)"
        print m % (len(clusters), level, self.last_clust_id,
                   len(self.clust_to_seq))

        return c

 
    def cluster(self):
        """Day & Edelsbrunner clustering driver: build the in-memory
        distance matrix, then repeatedly merge the closest pair
        (steps 1-4) until no inter-cluster distances remain, finally
        joining the leftovers under a root at max_distance."""
        print "Beginning clustering - ", time.strftime('%H:%M:%S (%d %h %Y)')
        
        # create a single cluster for each sequence
        for seq_id in self.all_seqs:
            # normalize ids to plain ints for use as dict keys
            # (presumably they may arrive as longs/strings from the
            # DB -- TODO confirm)
            seq_id = int(seq_id)
            c = cluster_obj.cobj(distance=0,
                                 cluster_id=self.next_cluster_id(),
                                 items=[seq_id])

            # sequence to cluster map
            self.seq_to_clust[seq_id] = c

            # cluster to sequence set map
            if not c in self.clust_to_seq:
                self.clust_to_seq[c] = set()
            self.clust_to_seq[c].add( seq_id)


        print "Building distance matrix"
        self.build_distance()

        debug0=True
        print "Entering cluster loop"
        while True:

            # step 1
            cluster_pair = self.min_pair()
            #print "min pair:", cluster_pair

            #if debug0:
            #    ipsh = IPShellEmbed()
            #    ipsh("After min pair:")

            # all connected components already merged.  Merge the rest
            # into a root node with maximum distance
            if cluster_pair is None:
                self.join( self.max_distance, self.clust_to_seq.keys())
                #self.verify_tree()
                break

            (c0, c1, dist) = cluster_pair

            # step 2: merge the closest pair; on failure drop into an
            # embedded IPython shell for inspection
            try:
                c_new = self.join( dist, [c0,c1])
            except Exception, e:
                print e
                ipsh = IPShellEmbed()
                ipsh("join exception:")

            # step 3, 4: update D, priority queues
            try:
                self.update_distance(c_new, [c0, c1])
            except Exception, e:
                print e
                ipsh = IPShellEmbed()
                ipsh("update distance exception:")

        print "Done -", time.strftime('%H:%M:%S (%d %h %Y)')
        return

    def update_distance(*args):
        """Linkage-specific distance-matrix update; implemented by
        subclasses.  (Declared with *args, so 'self' arrives as
        args[0].)"""
        assert False, "update_distance() should be implemented in a subclass"


class single_linkage( agglomerative_base):
    """Single linkage agglomerative clustering from Sibson, 1972:
    SLINK: An optimally efficient algorithm for the single-link
    cluster method"""

    class res:
        # One entry of the SLINK pointer representation: dist is the
        # merge level; left (and later right) become child indices
        # when the dendrogram is assembled.
        def __init__(self, dist, left):
            self.dist = dist
            self.left = left
            self.right = None

        def __repr__(self):
            m = (self.dist, self.left, self.right).__str__()
            return m


    def __init__(self, br_id=None, nc_id=None, stype='e_value',
                 self_hits = True, 
                 symmetric=True,
                 set_id_filter=None,
                 score_threshold=None,
                 param_list=None):
        """Straight pass-through to the base class; cluster() is
        overridden below, so no linkage_fn is assigned."""
        agglomerative_base.__init__(self, br_id=br_id, nc_id=nc_id,
                                    stype=stype,
                                    self_hits=self_hits,
                                    symmetric=symmetric,
                                    set_id_filter=set_id_filter,
                                    score_threshold=score_threshold,
                                    param_list=param_list)

    def cluster(self):
        """Closely modeled after Cluster 3.0 pslcluster() function.
        Builds the SLINK pointer representation row by row (one DB
        query per sequence), then hands it to dendrogram()."""
        
        print "Beginning clustering - ", time.strftime('%H:%M:%S (%d %h %Y)')

        nodes = sorted(self.all_seqs)
        # seq_id -> dense row index
        n_hash = dict( zip(nodes, range(len(nodes))) )
        
        # 1) set PI(n+1) to n+1, Gamma(n+1) to inf
        # (1E100 serves as 'infinity': no connection yet)
        vector = range( len(nodes))       # PI
        result = [self.res(1E100, i) for i in range(len(nodes))] # GAMMA
        
        for (i,seq_id) in enumerate(nodes):
            if i % 1000==0: print "Query: ", i, seq_id

            # temp = M: distances from sequence i to all others
            temp = self.get_row( seq_id, n_hash)
            
            for j in range(i):
                # if j >= i: break

                k = vector[j]
                
                # if Gamma(i) >= M(i)
                if result[j].dist >= temp[j]:
                    # set M(PI(i)) = min{ M(PI(i)), Gamma(i) }
                    if result[j].dist < temp[k]: temp[k] = result[j].dist
                    result[j].dist = temp[j]
                    vector[j] = i
                    
                # set M(PI(i)) to min{ M(PI(i)), M(i) }
                elif temp[j] < temp[k]:
                    temp[k] = temp[j]

            for j in range(i):
                if result[j].dist >= result[ vector[j]].dist:
                    vector[j] = i

        root = self.dendrogram( result, vector, nodes)

        return
        
    def dendrogram(self, result, vector, nodes):
        """Convert the SLINK pointer representation into a tree of
        cluster_obj.cobj nodes, updating seq_to_clust/clust_to_seq.

        In 'clusters', non-negative keys denote leaves (indices into
        'nodes'); negative keys denote internal (merged) nodes.
        NOTE(review): returns None, so the 'root' variable assigned in
        cluster() is always None."""
        node_order = range(len(nodes))
        # process merges from smallest to largest distance
        node_order.sort(key=lambda i: result[i].dist)

        index = range(len(nodes))

        clusters = {}

        # rewrite left/right as indices (negative = internal node)
        for (ind,i) in enumerate(node_order):
            j = result[i].left
            k = vector[j]
            result[i].left = index[j]
            result[i].right = index[k]
            index[k] = -ind-1
        
        int_label = 0
        for i in node_order[:-1]:
            l = result[i].left
            r = result[i].right

            # Create leaf clusters
            for n in (l, r):
                if n not in clusters:
                    assert n >=0, "Impossible"
                    node = nodes[n]
                    
                    # create leaf
                    c = cluster_obj.cobj(distance = 0,
                                         cluster_id = self.next_cluster_id(),
                                         items = [node])
                    clusters[n] = c

                    self.seq_to_clust[node] = c
                    self.clust_to_seq[c] = set( (node,))

            c = cluster_obj.cobj(distance = result[i].dist,
                                 cluster_id = self.next_cluster_id(),
                                 items = [clusters[l], clusters[r]])
            
            # update the cluster and sequence maps
            seq_set = set()
            seq_set.update( self.clust_to_seq.pop( clusters[l]))
            seq_set.update( self.clust_to_seq.pop( clusters[r]))
            self.clust_to_seq[c] = seq_set
            for seq_id in seq_set: self.seq_to_clust[seq_id] = c

            int_label += -1
            clusters[int_label] = c

        return

    def get_row(self, seq_id, n_hash):
        """Return a dense list of distances from seq_id to every
        sequence, indexed per n_hash; 1E100 marks 'no hit'."""
        d = [1E100] * len(n_hash)
        
        hits = self.bq.fetch_hits_direct(
            br_id=self.br_id, nc_id=self.nc_id,
            stype=self.stype, self_hits=False,
            correct_e_value=False,
            symmetric=self.symmetric,
            query_seq_id = seq_id,
            set_id_filter=self.set_id_filter,
            thresh=self.score_threshold)

        for (seq_id_0, seq_id_1, score) in hits:

            # Make a distance from the similarity measure
            if self.stype in ('nc_score', 'bit_score'):
                score = self.max_distance - score

            d[ n_hash[ seq_id_1]] = score

        return d

class single_linkage_sql( agglomerative_base):
    """Single linkage agglomerative clustering using SQL.  Very low
    memory requirements, but can be slow.  Uses the generic
    agglomerative_base.cluster() driver with single_linkage() as the
    linkage function."""
    
    def __init__(self, br_id=None, nc_id=None, stype='e_value',
                 self_hits=True,
                 symmetric=True,
                 param_list=None):
        # NOTE(review): set_id_filter and score_threshold are not
        # accepted or forwarded here, unlike the other subclasses
        agglomerative_base.__init__(self, br_id, nc_id, stype, self_hits,
                                    symmetric = symmetric,
                                    param_list = param_list)
        
        # the generic cluster() driver calls this every round
        self.linkage_fn = self.single_linkage

    def join_nn(self, c, seq_set, nn_list):
        """Derive the new nearest neighbor from a joined group of
        clusters.  Usually we'll rely on a new distance calculation,
        but single linkage can be faster.  That is, with single
        linkage, the shortest distance not within the merged cluster
        is still the best."""

        # nn_list: list of (seq_id_0, seq_id_1, dist) tuples
        nn_min = None
        for tup in nn_list:
            if tup is not None:
                (seq_id_0, seq_id_1, dist) = tup
            else:
                continue
                
            assert seq_id_0 in seq_set, "seq_id_0 %d not found in seq_set" % seq_id_0
            
            # a neighbor absorbed into the merged cluster is no longer
            # an external nearest neighbor
            if seq_id_1 in seq_set: continue
            
            if nn_min is None or dist < nn_min[2]:
                nn_min = (seq_id_0, seq_id_1, dist)

        # if every cached neighbor fell inside the merged cluster,
        # leave c out of clust_nn so single_linkage() recomputes it
        if nn_min is not None:
            self.clust_nn[c] = nn_min
        else:
            return

    def single_linkage(self):
        """single: find the smallest distance from members of one
        cluster to any other member, then look up the cluster
        associated with that other member

        Returns a dict-of-dicts {c0: {c1: dist}} containing every pair
        at the minimal distance, or None when no inter-cluster hits
        remain (signals the driver to build the root)."""

        # find the nearest sequences not in the same cluster
        # return all minimal distance pairs
        min_pairs = None
        min_dist = None

        for c0 in self.clust_to_seq.keys():
            x = self.clust_to_seq[ c0]

            # use the cached nearest neighbor when present; fall back
            # to a SQL lookup otherwise
            if c0 in self.clust_nn:
                if self.clust_nn[c0] is None:  # no hits
                    continue
                else:
                    (seq_id_0, seq_id_1, dist) = self.clust_nn[c0]
            else:
                ret = self.min_distance( x, None)

                # retain the nearest (seq_id) neighbor for each cluster
                self.clust_nn[c0] = ret

                # No further scores for this cluster.  Wait for the top level
                if ret is None: continue

                (seq_id_0, seq_id_1, dist) = ret

            # lookup seq_id_1 cluster
            c1 = self.seq_to_clust[seq_id_1]

            if min_dist is None or dist < min_dist:
                # Initialize a new min.
                min_pairs = {}
                min_pairs[c0] = {}
                min_pairs[c0][c1] = dist
                min_dist = dist

            elif dist == min_dist:
                # Add equivalent distance to list
                if not min_pairs.has_key(c0):
                    min_pairs[c0] = {}
                min_pairs[c0][c1] = dist

        return min_pairs


    def min_distance( self, x, y=None):
        """Return the pair of sequences with minimal distance between
        groups of sequences x and y

        x, y: sequence-id collections (or single ints); y=None means
        'any sequence not in x'.  Returns a (seq_id_0, seq_id_1, dist)
        row, or None when no qualifying hit exists.

        NOTE(review): the query is built by string interpolation; this
        is only safe because all ids are formatted with %d (integers).
        """
        
        # allow scalar arguments
        if type(x) is int: x = [x]
        if type(y) is int: y = [y]

        q_blast = """SELECT seq_id_0, seq_id_1, e_value as dist
        FROM blast_hit_symmetric
        WHERE br_id='%(br_id)d'
        AND seq_id_0 in %(seq_id_0)s"""

        q_nc = """SELECT seq_id_0, seq_id_1, 1 - nc_score as dist
        FROM blast_hit_nc
        WHERE nc_id='%(nc_id)d'
        AND seq_id_0 in %(seq_id_0)s"""

        if y is None:
            cond = """ AND seq_id_1 not in %(seq_id_0)s
            GROUP by seq_id_0, seq_id_1
            ORDER BY dist
            LIMIT 1"""
        else:
            cond = """AND seq_id_1 in %(seq_id_1)s
            GROUP by seq_id_0, seq_id_1
            ORDER BY dist
            LIMIT 1"""

        if self.stype == 'e_value' and self.br_id is not None:
            q = q_blast + cond
        elif self.stype == 'nc_score' and self.nc_id is not None:
            q = q_nc + cond
        else:
            assert False, "Unknown stype or missing run id"

        # build "(id,id,...)" lists for the IN clauses; when y is
        # None, seq_id_str[1] remains "(" but is never referenced by
        # the chosen condition
        seq_id_str = [ "(", "("]
        for (i,seq_ids) in enumerate((x,y)):
            if seq_ids is None: continue
            
            for seq_id in seq_ids:
                seq_id_str[i] += "%d," % seq_id
            # remove trailing comma
            seq_id_str[i] = seq_id_str[i][:-1]
            seq_id_str[i] += ")"
            
        self.bq.dbw.execute( q % {
            'br_id': self.br_id,
            'nc_id': self.nc_id,
            'seq_id_0': seq_id_str[0],
            'seq_id_1': seq_id_str[1]})

        d = self.bq.dbw.fetchone()

        return d
        
    
class complete_linkage(agglomerative_memory):
    """Complete-linkage clustering, using an in-memory distance
    matrix.  Fast."""

    def update_distance(self, c_new, clusters):
        """Fold the distance-matrix rows of the merged 'clusters' into
        a single row for c_new, keeping for every other cluster the
        LARGEST distance to any merged cluster (complete linkage).

        c_new:    the cluster that replaced 'clusters'
        clusters: the clusters just merged into c_new
        """

        # the merged clusters no longer exist as rows of the matrix
        for old in clusters:
            del self.diss[old]

        new_row = priority_dict()
        self.diss[c_new] = new_row

        for other in self.diss:
            if other == c_new:
                continue

            # complete linkage: the distance between c_new and 'other'
            # is the worst (largest) distance from any merged cluster
            # to 'other'; stale entries are removed as we scan
            worst = None
            for old in clusters:
                if old not in self.diss[other]:
                    continue
                d = self.diss[other].pop(old)
                if worst is None or d > worst:
                    worst = d

            # link the two rows only if at least one distance existed
            if worst is not None:
                self.diss[other][c_new] = worst
                new_row[other] = worst

        return

class single_linkage_memory(agglomerative_memory):
    """Single-linkage clustering, using an in-memory distance matrix.
    Fast."""

    def update_distance(self, c_new, clusters):
        """Fold the distance-matrix rows of the merged 'clusters' into
        a single row for c_new, keeping for every other cluster the
        SMALLEST distance to any merged cluster (single linkage).

        c_new:    the cluster that replaced 'clusters'
        clusters: the clusters just merged into c_new
        """
        
        # remove clusters from distance, nextbest
        for c in clusters:
            del self.diss[c]
            
        self.diss[c_new] = priority_dict()
        
        for c0 in self.diss:
            if c0 == c_new: continue
            
            # the minimum distance between any of the merged clusters to
            # each other cluster is now the distance between c_new and
            # that cluster
            mindist = None
            for c1 in clusters:
                if c1 not in self.diss[c0]: continue
                
                dist = self.diss[c0].pop(c1)
                
                # BUG FIX: was 'mindist < dist', which kept the MAXIMUM
                # distance (complete linkage) instead of the minimum
                if mindist is None or dist < mindist:
                    mindist = dist
                    
            if mindist is not None:
                self.diss[c0][c_new] = mindist
                self.diss[c_new][c0] = mindist
                
        return

class average_linkage(agglomerative_memory):
    """Average-linkage clustering, using an in-memory distance matrix.
    Fast."""

    def update_distance(self, c_new, clusters):
        """Fold the distance-matrix rows of the merged 'clusters' into
        a single row for c_new holding, for every other cluster, the
        arithmetic mean of the merged clusters' distances to it.

        c_new:    the cluster that replaced 'clusters'
        clusters: the clusters just merged into c_new
        """

        # drop the rows of the clusters that were just merged
        for old in clusters:
            del self.diss[old]

        new_row = priority_dict()
        self.diss[c_new] = new_row

        for other in self.diss:
            if other == c_new:
                continue

            # accumulate each merged cluster's distance to 'other',
            # removing the stale entries as we scan
            total = 0.0
            count = 0
            for old in clusters:
                if old not in self.diss[other]:
                    continue
                total += self.diss[other].pop(old)
                count += 1

            # link the two rows only if at least one distance existed
            if count > 0:
                mean = total / count
                self.diss[other][c_new] = mean
                new_row[other] = mean

        return
