#!/usr/bin/env python

# Jacob Joseph
# 2008 June 4
# Perform a variety of clustering techniques

import os, argparse
from JJcluster import cluster_sql, agglomerative, mcupgma, mcl, spici

def perform_hierarchical(stype, run_id, 
                         set_id_filter=None,
                         symmetric = True,
                         score_threshold=None,
                         comment=None,
                         param_list=None,
                         cluster_fn=None):
    """Run a hierarchical clustering and store the resulting tree.

    Arguments:
      stype:           score type.  When it is 'nc_score', run_id is
                       passed on as nc_id; for any other value run_id is
                       passed on as br_id.
      run_id:          the nc_id or br_id, depending upon stype.
      set_id_filter:   forwarded unchanged to cluster_fn and to
                       cluster_sql.hcluster.
      symmetric:       forwarded unchanged to cluster_fn and to
                       cluster_sql.hcluster.
      score_threshold: forwarded unchanged to cluster_fn and to
                       cluster_sql.hcluster.
      comment:         forwarded to cluster_sql.hcluster only.
      param_list:      forwarded to cluster_fn as-is, and to
                       cluster_sql.hcluster as str(param_list).
      cluster_fn:      callable that builds the clustering object.  It is
                       called with the keywords br_id, nc_id, stype,
                       symmetric, set_id_filter, score_threshold and
                       param_list; the returned object must provide
                       cluster() and a clust_to_seq mapping.

    Returns:
      (clust, csql): the clustering object and the cluster_sql.hcluster
      object used to store it.

    Raises:
      ValueError: if cluster_fn is None.
    """
    # Validate up front with a real exception: an assert would be
    # stripped under "python -O".
    if cluster_fn is None:
        raise ValueError("cluster_fn must be defined")

    if stype == 'nc_score':
        nc_id = run_id
        br_id = None
    else:
        nc_id = None
        br_id = run_id

    clust = cluster_fn(br_id=br_id,
                       nc_id=nc_id,
                       stype=stype,
                       symmetric=symmetric,
                       set_id_filter=set_id_filter,
                       score_threshold=score_threshold,
                       param_list=param_list)
    clust.cluster()
    csql = cluster_sql.hcluster( stype=stype,
                                 br_id=br_id,
                                 nc_id=nc_id,
                                 symmetric=symmetric,
                                 set_id_filter=set_id_filter,
                                 score_threshold=score_threshold,
                                 comment=comment,
                                 params=str(param_list))

    # The first key of clust_to_seq is stored as the tree root
    # (presumably the single cluster remaining after clustering -- TODO
    # confirm).  list() keeps this working where keys() returns a
    # non-subscriptable view (Python 3), and still raises IndexError on
    # an empty mapping as before.
    root_cluster = list(clust.clust_to_seq.keys())[0]
    csql.store_tree(root_cluster)
    csql.nest_sets()

    return clust, csql
    
def perform_hierarchical_single(*args, **kwargs):
    """Agglomerative single-linkage hierarchical clustering.

    Delegates to perform_hierarchical with cluster_fn set to
    agglomerative.single_linkage; any cluster_fn supplied by the caller
    is overridden.  All other arguments pass through unchanged.
    """
    options = dict(kwargs, cluster_fn=agglomerative.single_linkage)
    return perform_hierarchical(*args, **options)

def perform_hierarchical_single_memory(*args, **kwargs):
    """Agglomerative single-linkage hierarchical clustering.

    Delegates to perform_hierarchical with cluster_fn set to
    agglomerative.single_linkage_memory; any cluster_fn supplied by the
    caller is overridden.  All other arguments pass through unchanged.
    """
    options = dict(kwargs, cluster_fn=agglomerative.single_linkage_memory)
    return perform_hierarchical(*args, **options)

def perform_hierarchical_complete(*args, **kwargs):
    """Agglomerative complete-linkage hierarchical clustering.

    Delegates to perform_hierarchical with cluster_fn set to
    agglomerative.complete_linkage; any cluster_fn supplied by the
    caller is overridden.  All other arguments pass through unchanged.
    """
    options = dict(kwargs, cluster_fn=agglomerative.complete_linkage)
    return perform_hierarchical(*args, **options)

def perform_hierarchical_average_memory( *args, **kwargs):
    """Agglomerative average-linkage hierarchical clustering, using an
    in-memory matrix.

    Delegates to perform_hierarchical with cluster_fn set to
    agglomerative.average_linkage; any cluster_fn supplied by the caller
    is overridden.  All other arguments pass through unchanged.
    """
    options = dict(kwargs, cluster_fn=agglomerative.average_linkage)
    return perform_hierarchical(*args, **options)

def perform_hierarchical_average( *args, **kwargs):
    """Agglomerative average-linkage hierarchical clustering via
    MCUPGMA (noted by the original author as very fast).

    Delegates to perform_hierarchical with cluster_fn set to
    mcupgma.mcupgma; any cluster_fn supplied by the caller is
    overridden.  All other arguments pass through unchanged.
    """
    options = dict(kwargs, cluster_fn=mcupgma.mcupgma)
    return perform_hierarchical(*args, **options)

def perform_mcl(stype, run_id, 
                set_id_filter=None,
                score_threshold=None,
                comment=None,
                param_list=None,
                symmetric=True):
    """Perform clustering using TribeMCL and store the flat clustering.

    Arguments:
      stype:           score type.  When it is 'nc_score', run_id is
                       passed on as nc_id; for any other value run_id is
                       passed on as br_id.
      run_id:          the nc_id or br_id, depending upon stype.
      set_id_filter:   forwarded to mcl.mcl and cluster_sql.flatcluster.
      score_threshold: forwarded to mcl.mcl and cluster_sql.flatcluster.
      comment:         forwarded to cluster_sql.flatcluster only.
      param_list:      forwarded to mcl.mcl as-is, and to
                       cluster_sql.flatcluster as str(param_list).
      symmetric:       accepted so this function can be called with the
                       same keyword set as the hierarchical methods (as
                       main() does).  It is NOT forwarded to mcl.mcl or
                       cluster_sql.flatcluster.

    Returns:
      (clust, csql): the mcl.mcl object and the cluster_sql.flatcluster
      object used to store its cluster map.
    """
    if stype == 'nc_score':
        nc_id = run_id
        br_id = None
    else:
        nc_id = None
        br_id = run_id

    clust = mcl.mcl( stype = stype,
                     br_id = br_id,
                     nc_id = nc_id,
                     score_threshold = score_threshold,
                     set_id_filter = set_id_filter,
                     param_list = param_list,
                     self_hits=True)

    clust.cluster()

    csql = cluster_sql.flatcluster( stype=stype,
                                    br_id=br_id,
                                    nc_id=nc_id,
                                    set_id_filter=set_id_filter,
                                    score_threshold=score_threshold,
                                    comment=comment,
                                    params=str(param_list))
    csql.store( clust.get_cluster_map() )

    # Return the same (clust, csql) pair as the other perform_* functions
    # so callers that unpack the result work for every method.
    return clust, csql

def perform_spici(stype, run_id, 
                  set_id_filter=None,
                  score_threshold=None,
                  comment=None,
                  param_list=None,
                  symmetric=True):
    """Perform clustering using SPICi and store the flat clustering.

    Arguments:
      stype:           score type.  run_id is recorded as br_id when
                       stype is 'e_value' or 'bit_score', and as nc_id
                       when stype is 'nc_score'.
      run_id:          the br_id or nc_id, depending upon stype.
      set_id_filter:   forwarded to spici.spici and
                       cluster_sql.flatcluster.
      score_threshold: forwarded to spici.spici and
                       cluster_sql.flatcluster.
      comment:         forwarded to cluster_sql.flatcluster only.
      param_list:      forwarded to spici.spici as-is, and to
                       cluster_sql.flatcluster as str(param_list).
      symmetric:       accepted so this function can be called with the
                       same keyword set as the hierarchical methods (as
                       main() does).  It is NOT forwarded to spici.spici
                       or cluster_sql.flatcluster.

    Returns:
      (clust, csql): the spici.spici object and the
      cluster_sql.flatcluster object used to store its cluster map.
    """

    # NOTE: the SPICi location is hard-coded relative to $HOME.
    clust = spici.spici( stype, run_id,
                         set_id_filter = set_id_filter,
                         score_threshold = score_threshold,
                         param_list = param_list,
                         spici_path=os.path.expandvars("$HOME/Durand/JJ/cluster_eval/SPICi/src/"))
    clust.cluster()

    csql = cluster_sql.flatcluster( stype = stype,
                                    br_id = run_id if stype in ('e_value', 'bit_score') else None,
                                    nc_id = run_id if stype=='nc_score' else None,
                                    set_id_filter = set_id_filter,
                                    score_threshold = score_threshold,
                                    params = str(param_list),
                                    comment = comment)

    csql.store( clust.get_cluster_map())
    return clust, csql


def parse_options(methods, argv=None):
    """Build the command-line parser and parse the arguments.

    Arguments:
      methods: the allowed values for the positional 'method' argument.
      argv:    optional list of argument strings to parse.  When None
               (the default), argparse reads sys.argv[1:].

    Returns:
      The argparse namespace, with attributes method, stype, run_id,
      score_threshold, set_id_filter, symmetric, comment and
      clustering_args.
    """

    stypes = ['e_value', 'bit_score','nc_score']

    parser = argparse.ArgumentParser()

    parser.add_argument('method',
                        choices = methods,
                        metavar='method',
                        help="The clustering method to use.  One of %s" % methods)
    parser.add_argument('stype',
                        choices = stypes,
                        metavar = 'stype',
                        help="Score type.  One of %s" % stypes)
    parser.add_argument('run_id',
                        type=int,
                        metavar = 'run_id',
                        help="A br_id or nc_id, depending upon the score type.")

    parser.add_argument("--score_threshold",
                        type=float,
                        help="""Specify the 'weakest' edge to be included in the input network.  For
E-value, this is a maximum value; for other stypes, it is a minimum.""")
    
    parser.add_argument("--set_id_filter",
                        dest="set_id_filter",
                        type=int,
                        help="""Filter the input network by an existing set_id. Only edges both from
    and incident to sequences in the set will be passed to the clustering
    algorithm.""")

    # parse_boolean is needed because argparse has no built-in
    # string-to-bool conversion.
    parser.add_argument("--symmetric",
                        default=True,
                        type=parse_boolean,
                        dest="symmetric",
                        help="""Use symmetric scores from the database.  Default=True (Applicable to Blast scores
    only.)""")

    # '%%' is doubled because argparse %-formats help strings itself.
    parser.add_argument("--comment",
                        dest="comment",
                        help="""Populate the comment field in the database.  Will be prepended with
    the clustering method, and string variables such as %%(run_id)s,
    %%(br_id)s, %%(nc_id)s, %%(stype)s, %%(set_id_filter)s will be
    substituted.""")

    parser.add_argument("--clustering_args",
                        dest="clustering_args",
                        type = parse_clustering_args,
                        help="""Arguments for the clustering algorithm, if applicable.  Must be a
    quoted string to not be rendered as arguments for *this* program.""")

    
    args = parser.parse_args(argv)

    return args

def parse_boolean(arg):
    """Convert a command-line string to a bool.

    See http://bugs.python.org/issue8538.

    Matching is case-insensitive and ignores surrounding whitespace.
    '0', 'f', 'false', 'no' and 'off' give False; '1', 't', 'true',
    'yes' and 'on' give True.

    Raises:
      ValueError: if arg is none of the recognized spellings.
    """
    normalized = arg.strip().lower()
    if normalized in ('0', 'f', 'false', 'no', 'off'):
        return False
    if normalized in ('1', 't', 'true', 'yes', 'on'):
        return True
    # Name the offending value so direct callers get a usable message.
    raise ValueError("invalid boolean value: %r" % (arg,))

def parse_clustering_args(arg_str):
    """Split a quoted argument string into a list of whitespace-separated
    tokens.  Leading and trailing whitespace is ignored; an empty or
    all-whitespace string gives an empty list."""
    # str.split() with no separator already discards leading/trailing
    # whitespace and collapses runs of whitespace.
    return arg_str.split()

def main():
    """Command-line entry point: parse options, build the database
    comment, and dispatch to the selected clustering method.

    The comment stored with the clustering is the method name alone when
    --comment is not given.  Otherwise it is "<method>: <comment>", with
    %(name)s placeholders substituted for method, stype, run_id, br_id,
    nc_id, set_id_filter, symmetric and score_threshold.  Because the
    comment is %-formatted, a literal percent sign must be written '%%'.

    Returns:
      (clust, csql) as returned by the selected perform_* function.
    """
    method_map = {'single_linkage': perform_hierarchical_single,
                  'single_linkage_memory': perform_hierarchical_single_memory,
                  'complete_linkage': perform_hierarchical_complete,
                  'average_linkage_memory': perform_hierarchical_average_memory,
                  'average_linkage': perform_hierarchical_average,
                  'mcl': perform_mcl,
                  'spici': perform_spici}
    
    args = parse_options(methods=sorted(method_map.keys()))

    if args.comment is not None:
        # run_id is an nc_id for 'nc_score' and a br_id otherwise; expose
        # both names, since the --comment help advertises %(br_id)s and
        # %(nc_id)s as valid placeholders.
        is_nc_score = (args.stype == 'nc_score')
        substitutions = {'method': args.method,
                         'stype': args.stype,
                         'run_id': args.run_id,
                         'br_id': None if is_nc_score else args.run_id,
                         'nc_id': args.run_id if is_nc_score else None,
                         'set_id_filter': args.set_id_filter,
                         'symmetric': args.symmetric,
                         'score_threshold': args.score_threshold}
        args.comment = ("%(method)s: " + args.comment) % substitutions
    else:
        args.comment = args.method

    clust, csql = method_map[args.method]( stype = args.stype,
                                          run_id = args.run_id,
                                          set_id_filter = args.set_id_filter,
                                          symmetric = args.symmetric,
                                          score_threshold = args.score_threshold,
                                          comment = args.comment,
                                          param_list = args.clustering_args)
    return clust, csql

if __name__ == "__main__":
    # Run the command-line entry point.  The results are bound to
    # module-level names (presumably so they can be inspected afterwards,
    # e.g. under "python -i" -- TODO confirm).
    clust, csql = main()
    #br_id = 95
    #nc_id = 736
    #br_id = 75
    #nc_id = 698

    #nc_id = 750 # cluster -> cluster
    #nc_id = 746 # cluster -> all
    #br_id = 97  # cluster -> cluster
    #set_id_filter = 105  # buggy cluster
    #set_id_filter = 107   # new (jan10) cluster
    #stype='nc_score'
    
    #set_id_filter = 108  # 12 genomes
    #set_id_filter = None
    #set_id_filter = 109 # human and mouse
    #set_id_filter = 111 # S cerevisiae
    #set_id_filter = 112  # human
    #br_id = 104 # 12 genomes, compositional
    #br_id = 105 # 12 genomes, not compositional
    #nc_id = 777 # non-symmetric. 12 genomes
    #nc_id = 779 # symmetric. 12 genomes
    #nc_id = 780 # non-symmetric, not compositional.  12 genomes

    #nc_id = 781 # Symmetric, not compositional. 12 genomes

    #br_id = 102 # full 600k set
    #nc_id = 789

    #clust, csql = perform_hierarchical_average(br_id, nc_id, stype,
    #                                           set_id_filter,
    #                                           param_dict={'-jobs': 3,
    #                                                       '-M': int(20e6)})

    #clust, csql = perform_hierarchical_single(br_id, nc_id, stype,
    #                                          set_id_filter)

    #perform_hierarchical(br_id, nc_id, 'e_value',
    #                     "single linkage.")

    #clust, csql = perform_hierarchical_single(br_id, nc_id, 'nc_score',
    #                                    comment='Sibson single linkage')


    #clust, csql = perform_hierarchical_complete(br_id, nc_id, 'nc_score',
    #                                            "complete linkage.")
    #clust, csql = perform_hierarchical_complete(br_id, nc_id, 'e_value',
    #                                            "complete linkage.")

    #clust, csql = perform_hierarchical_average(br_id, nc_id, 'e_value')
    #clust, csql = perform_hierarchical_average(br_id, nc_id, 'nc_score')
    
    #mcl_params = {'main_inflation': 3, 'threads': 4, 'scheme': 7}
    #mcl_params = {'main_inflation': 2, 'threads': 4}
    #mcl_params = {'main_inflation': 4, 'threads': 4, 'scheme': 7}
    #perform_mcl( br_id, nc_id, 'e_value', "", mcl_params)
    #perform_mcl( br_id, nc_id, 'nc_score', "", mcl_params)
    
    #mcl_params = {'main_inflation': 4.5, 'threads': 4, 'scheme': 7}
    #perform_mcl( br_id, nc_id, 'e_value', "", mcl_params)
    #perform_mcl( br_id, nc_id, 'nc_score', "", mcl_params)

    #mcl_params = {'main_inflation': 5, 'threads': 4, 'scheme': 7}
    #perform_mcl( br_id, nc_id, 'e_value', "", mcl_params)
    #perform_mcl( br_id, nc_id, 'nc_score', "", mcl_params)
    

    #clust, csql = perform_spici( run_id = 782, # human, mouse PPOD dataset
    #                             stype = 'nc_score',
    #                             comment = 'default',
    #                             #score_threshold = 0.8
    #                             param_dict={'-s': '1',   # minimum cluster size. default 2
    #                                         #'-d': '0.9', # minimum density threshold. default 0.5
    #                                         #'-g': '0.9'  # minimum increment ratio. default 0.5
    #                                         }
    #                             )
    
                                 
